From 25f9388cc27a782c46996616f687a3166768760c Mon Sep 17 00:00:00 2001 From: Olivier Teytaud Date: Tue, 20 Sep 2022 07:55:22 +0200 Subject: [PATCH 01/76] first_evolutionary_draft --- README.md | 12 +++++ minisd.py | 41 +++++++++++++++ multi_minisd.sh | 96 ++++++++++++++++++++++++++++++++++++ pipeline_stable_diffusion.py | 1 + 4 files changed, 150 insertions(+) create mode 100644 minisd.py create mode 100755 multi_minisd.sh create mode 120000 pipeline_stable_diffusion.py diff --git a/README.md b/README.md index c9e6c3bb1..f5a0e44c1 100644 --- a/README.md +++ b/README.md @@ -1,3 +1,15 @@ +# Modified evolutionary version. + +Install as usual (below). +Then use the "pipeline_stable_diffusion.py" in lieu of the original pipeline_stable_diffusion.py found at +<< python -c "import diffusers ; print(diffusers.__file__)" >>. +Then edit the prompt in minisd.py. +Then + ./multi_minisd.sh + + + + # Stable Diffusion *Stable Diffusion was made possible thanks to a collaboration with [Stability AI](https://stability.ai/) and [Runway](https://runwayml.com/) and builds upon our previous work:* diff --git a/minisd.py b/minisd.py new file mode 100644 index 000000000..5d1029f47 --- /dev/null +++ b/minisd.py @@ -0,0 +1,41 @@ +import random +import torch +from torch import autocast +from diffusers import StableDiffusionPipeline + +model_id = "CompVis/stable-diffusion-v1-4" +#device = "cuda" +device = "mps" #torch.device("mps") + + +pipe = StableDiffusionPipeline.from_pretrained(model_id, use_auth_token="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") +pipe = pipe.to(device) + +prompt = "a photo of an astronaut riding a horse on mars" +prompt = "a photo of a red panda with a hat playing table tennis" +prompt = "a photorealistic portrait of " + random.choice(["Mary Cury", "Scarlett Johansson", "Marilyn Monroe", "Poison Ivy", "Black Widow", "Medusa", "Batman", "Albert Einstein", "Louis XIV", "Tarzan"]) + random.choice([" with glasses", " with a hat", " with a cigarette", "with a scarf"]) +prompt = "a photorealistic portrait of " + random.choice(["Nelson Mandela", "Superman", "Superwoman", "Volodymyr Zelenskyy", "Tsai Ing-Wen", "Lzzy Hale", "Meg Myers"]) + random.choice([" with glasses", " with a hat", " with a cigarette", "with a scarf"]) +prompt = random.choice(["A woman with three eyes", "Meg Myers", "The rock band Ankor", "Miley Cyrus", "The man named Rahan", "A murder", "Rambo playing table tennis"]) +prompt = "Photo of a female Terminator." +prompt = random.choice([ + "Photo of Tarzan as a lawyer with a tie", + "Photo of Scarlett Johansson as a sumo-tori", + "Photo of the little mermaid as a young black girl", + "Photo of Schwarzy with tentacles", + "Photo of Meg Myers with an Egyptian dress", + "Photo of Schwarzy as a ballet dancer", + ]) + +name = random.choice(["Mark Zuckerbeg", "Zendaya", "Yann LeCun", "Scarlett Johansson", "Superman", "Meg Myers"]) +name = "Zendaya" +prompt = f"Photo of {name} as a sumo-tori." 
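+# Note: only the final assignment to `prompt` above takes effect; the earlier ones are overwritten.
+# The block below generates one image, saves it, and also saves the flattened latent that the
+# patched pipeline exports through the "latent_sd" environment variable, so that the driver
+# script can later label that latent as good or bad in goodbad.py.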
+with autocast("cuda"): + image = pipe(prompt, guidance_scale=7.5)["sample"][0] + +sentinel = random.randint(0,100000) +image.save(f"SD_{prompt.replace(' ','_')}_image_{sentinel}.png") +import os +latent = eval((os.environ["latent_sd"])) +with open(f"SD_{prompt.replace(' ','_')}_latent_{sentinel}.txt", 'w') as f: + f.write(f"{latent}") + diff --git a/multi_minisd.sh b/multi_minisd.sh new file mode 100755 index 000000000..a73a44538 --- /dev/null +++ b/multi_minisd.sh @@ -0,0 +1,96 @@ +#!/bin/bash + +set -e -x + +touch empty_file +rm empty_file +touch empty_file +touch SD_prout_${random}.png +touch SD_prout_${random}.txt +mv SD_*.png poubelle/ +mv SD_*.txt poubelle/ +# Initialization: run SD and create an image, with rank 1. +touch goodbad.py +rm goodbad.py +touch goodbad.py +echo "good = []" >> goodbad.py +echo "bad = []" >> goodbad.py +python minisd.py + #sentinel=${RANDOM} + #touch SD_image_${sentinel}.png + #touch SD_latent_${sentinel}.txt +mylist="`ls -ctr SD*_image_*.png | tail -n 1`" +myranks=1 + +for i in `seq 30` +do + # Now an iteration. + echo Current images = $mylist + echo Current ranks = $myranks + #sentinel=${RANDOM} + #touch SD_image_${sentinel}.png + #touch SD_latent_${sentinel}.txt + echo GENERATING FOUR IMAGES. + python minisd.py + python minisd.py + python minisd.py + python minisd.py + for img in `ls -ctr SD*_image_*.png | tail -n 4` + do + montage $mylist $img -mode Concatenate -tile 5x output.png + open --wait output.png + read -p "Rank of the last image ?" rank + mylist="$mylist $img` + mynewranks="" + for r in $myranks + do + [[ $r -ge $rank ]] && r=$(( $r + 1 )) + mynewranks="$mynewranks $r" + done + myranks="$mynewranks $rank" + #echo Before sorting =========================== + #echo $myranks + #echo $mylist + #sleep 2 + + # Now sorting + mynewlist="" + mynewranks="" + touch goodbad.py + rm goodbad.py + touch goodbad.py + echo "good = []" >> goodbad.py + echo "bad = []" >> goodbad.py + + for r in `seq 20` + do + for k in `seq 20` + do + [[ `echo $myranks | cut -d ' ' -f $k` -eq $r ]] && echo "FOUND $k for $r!" 
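+      # The k-th stored image currently holds rank r: the tests below copy it (and its rank)
+      # into the re-sorted lists and append its latent file to goodbad.py, as "good" for the
+      # top ranks and "bad" otherwise.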
+ my_image="`echo $mylist | cut -d ' ' -f $k`" + [[ `echo $myranks | cut -d ' ' -f $k` -eq $r ]] && mynewranks="$mynewranks `echo $myranks | cut -d ' ' -f $k`" + [[ `echo $myranks | cut -d ' ' -f $k` -eq $r ]] && mynewlist="$mynewlist `echo $mylist | cut -d ' ' -f $k`" + if [[ `echo $myranks | cut -d ' ' -f $k` -eq $r ]] + then + echo Found $my_image at rank $k for $r + if [[ $r -le 5 ]] + then + cat empty_file `echo $my_image | sed 's/image_[0-9]*.png/PROUTPROUT&/g' | sed 's/PROUTPROUTimage/latent/g' | sed 's/\.png/\.txt/g'` | sed "s/.*/good += [&]/g" >> goodbad.py + else + cat empty_file `echo $my_image | sed 's/image_[0-9]*.png/PROUTPROUT&/g' | sed 's/PROUTPROUTimage/latent/g' | sed 's/\.png/\.txt/g'` | sed "s/.*/bad += [&]/g" >> goodbad.py + fi + fi + echo "" >> goodbad.py + done + done + done + myranks=$mynewranks + mylist=$mynewlist + echo After sorting =========================== + echo $myranks + echo $mylist + #sleep 2 + +done + + diff --git a/pipeline_stable_diffusion.py b/pipeline_stable_diffusion.py new file mode 120000 index 000000000..c9710a945 --- /dev/null +++ b/pipeline_stable_diffusion.py @@ -0,0 +1 @@ +/opt/miniconda3/envs/sd_gpu/lib/python3.9/site-packages/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py \ No newline at end of file From d9cdb96e6742e45568d3b0bd836bf5d04ebc1122 Mon Sep 17 00:00:00 2001 From: Olivier Teytaud Date: Tue, 20 Sep 2022 16:00:31 +0200 Subject: [PATCH 02/76] fix --- pipeline_stable_diffusion.py | 334 ++++++++++++++++++++++++++++++++++- 1 file changed, 333 insertions(+), 1 deletion(-) mode change 120000 => 100644 pipeline_stable_diffusion.py diff --git a/pipeline_stable_diffusion.py b/pipeline_stable_diffusion.py deleted file mode 120000 index c9710a945..000000000 --- a/pipeline_stable_diffusion.py +++ /dev/null @@ -1 +0,0 @@ -/opt/miniconda3/envs/sd_gpu/lib/python3.9/site-packages/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py \ No newline at end of file diff --git a/pipeline_stable_diffusion.py b/pipeline_stable_diffusion.py new file mode 100644 index 000000000..3b52f0a99 --- /dev/null +++ b/pipeline_stable_diffusion.py @@ -0,0 +1,333 @@ +import inspect +import os +import random +import warnings +from typing import List, Optional, Union + +import torch + +from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer + +from ...models import AutoencoderKL, UNet2DConditionModel +from ...pipeline_utils import DiffusionPipeline +from ...schedulers import DDIMScheduler, LMSDiscreteScheduler, PNDMScheduler +from . import StableDiffusionPipelineOutput +from .safety_checker import StableDiffusionSafetyChecker + + +class StableDiffusionPipeline(DiffusionPipeline): + r""" + Pipeline for text-to-image generation using Stable Diffusion. + + This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the + library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.) + + Args: + vae ([`AutoencoderKL`]): + Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations. + text_encoder ([`CLIPTextModel`]): + Frozen text-encoder. Stable Diffusion uses the text portion of + [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel), specifically + the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant. 
+ tokenizer (`CLIPTokenizer`): + Tokenizer of class + [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer). + unet ([`UNet2DConditionModel`]): Conditional U-Net architecture to denoise the encoded image latents. + scheduler ([`SchedulerMixin`]): + A scheduler to be used in combination with `unet` to denoise the encoded image latens. Can be one of + [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`]. + safety_checker ([`StableDiffusionSafetyChecker`]): + Classification module that estimates whether generated images could be considered offsensive or harmful. + Please, refer to the [model card](https://huggingface.co/CompVis/stable-diffusion-v1-4) for details. + feature_extractor ([`CLIPFeatureExtractor`]): + Model that extracts features from generated images to be used as inputs for the `safety_checker`. + """ + + def __init__( + self, + vae: AutoencoderKL, + text_encoder: CLIPTextModel, + tokenizer: CLIPTokenizer, + unet: UNet2DConditionModel, + scheduler: Union[DDIMScheduler, PNDMScheduler, LMSDiscreteScheduler], + safety_checker: StableDiffusionSafetyChecker, + feature_extractor: CLIPFeatureExtractor, + ): + super().__init__() + scheduler = scheduler.set_format("pt") + self.register_modules( + vae=vae, + text_encoder=text_encoder, + tokenizer=tokenizer, + unet=unet, + scheduler=scheduler, + safety_checker=safety_checker, + feature_extractor=feature_extractor, + ) + + def enable_attention_slicing(self, slice_size: Optional[Union[str, int]] = "auto"): + r""" + Enable sliced attention computation. + + When this option is enabled, the attention module will split the input tensor in slices, to compute attention + in several steps. This is useful to save some memory in exchange for a small speed decrease. + + Args: + slice_size (`str` or `int`, *optional*, defaults to `"auto"`): + When `"auto"`, halves the input to the attention heads, so attention will be computed in two steps. If + a number is provided, uses as many slices as `attention_head_dim // slice_size`. In this case, + `attention_head_dim` must be a multiple of `slice_size`. + """ + if slice_size == "auto": + # half the attention head size is usually a good trade-off between + # speed and memory + slice_size = self.unet.config.attention_head_dim // 2 + self.unet.set_attention_slice(slice_size) + + def disable_attention_slicing(self): + r""" + Disable sliced attention computation. If `enable_attention_slicing` was previously invoked, this method will go + back to computing attention in one step. + """ + # set slice_size = `None` to disable `attention slicing` + self.enable_attention_slicing(None) + + @torch.no_grad() + def __call__( + self, + prompt: Union[str, List[str]], + height: Optional[int] = 512, + width: Optional[int] = 512, + num_inference_steps: Optional[int] = 50, + guidance_scale: Optional[float] = 7.5, + eta: Optional[float] = 0.0, + generator: Optional[torch.Generator] = None, + latents: Optional[torch.FloatTensor] = None, + output_type: Optional[str] = "pil", + return_dict: bool = True, + **kwargs, + ): + r""" + Function invoked when calling the pipeline for generation. + + Args: + prompt (`str` or `List[str]`): + The prompt or prompts to guide the image generation. + height (`int`, *optional*, defaults to 512): + The height in pixels of the generated image. + width (`int`, *optional*, defaults to 512): + The width in pixels of the generated image. + num_inference_steps (`int`, *optional*, defaults to 50): + The number of denoising steps. 
More denoising steps usually lead to a higher quality image at the + expense of slower inference. + guidance_scale (`float`, *optional*, defaults to 7.5): + Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598). + `guidance_scale` is defined as `w` of equation 2. of [Imagen + Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale > + 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`, + usually at the expense of lower image quality. + eta (`float`, *optional*, defaults to 0.0): + Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to + [`schedulers.DDIMScheduler`], will be ignored for others. + generator (`torch.Generator`, *optional*): + A [torch generator](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make generation + deterministic. + latents (`torch.FloatTensor`, *optional*): + Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image + generation. Can be used to tweak the same generation with different prompts. If not provided, a latents + tensor will ge generated by sampling using the supplied random `generator`. + output_type (`str`, *optional*, defaults to `"pil"`): + The output format of the generate image. Choose between + [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `nd.array`. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a + plain tuple. + + Returns: + [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`: + [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] if `return_dict` is True, otherwise a `tuple. + When returning a tuple, the first element is a list with the generated images, and the second element is a + list of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work" + (nsfw) content, according to the `safety_checker`. + """ + + if "torch_device" in kwargs: + device = kwargs.pop("torch_device") + warnings.warn( + "`torch_device` is deprecated as an input argument to `__call__` and will be removed in v0.3.0." + " Consider using `pipe.to(torch_device)` instead." + ) + + # Set device as before (to be removed in 0.3.0) + if device is None: + device = "cuda" if torch.cuda.is_available() else "cpu" + self.to(device) + + if isinstance(prompt, str): + batch_size = 1 + elif isinstance(prompt, list): + batch_size = len(prompt) + else: + raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") + + if height % 8 != 0 or width % 8 != 0: + raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.") + + # get prompt text embeddings + text_input = self.tokenizer( + prompt, + padding="max_length", + max_length=self.tokenizer.model_max_length, + truncation=True, + return_tensors="pt", + ) + text_embeddings = self.text_encoder(text_input.input_ids.to(self.device))[0] + + # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) + # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1` + # corresponds to doing no classifier free guidance. 
+ do_classifier_free_guidance = guidance_scale > 1.0 + # get unconditional embeddings for classifier free guidance + if do_classifier_free_guidance: + max_length = text_input.input_ids.shape[-1] + uncond_input = self.tokenizer( + [""] * batch_size, padding="max_length", max_length=max_length, return_tensors="pt" + ) + uncond_embeddings = self.text_encoder(uncond_input.input_ids.to(self.device))[0] + + # For classifier free guidance, we need to do two forward passes. + # Here we concatenate the unconditional and text embeddings into a single batch + # to avoid doing two forward passes + text_embeddings = torch.cat([uncond_embeddings, text_embeddings]) + + # get the initial random noise unless the user supplied it + + # Unlike in other pipelines, latents need to be generated in the target device + # for 1-to-1 results reproducibility with the CompVis implementation. + # However this currently doesn't work in `mps`. + latents_device = "cpu" if self.device.type == "mps" else self.device + latents_shape = (batch_size, self.unet.in_channels, height // 8, width // 8) + if latents is None: + latents = torch.randn( + latents_shape, + generator=generator, + device=latents_device, + ) + + from goodbad import good + from goodbad import bad + i_believe_in_evolution = len(good) > 0 and len(bad) > 0 + print(f"I believe in evolution = {i_believe_in_evolution}") + if i_believe_in_evolution: + from sklearn import tree + from sklearn.neural_network import MLPClassifier + #from sklearn.neighbors import NearestCentroid + from sklearn.linear_model import LogisticRegression + import numpy as np + #z = (np.random.randn(4*64*64)) + z = latents.cpu().numpy().flatten() + + #clf=tree.DecisionTreeClassifier()#min_samples_split=0.1) + clf = MLPClassifier(solver='lbfgs', alpha=1e-5, hidden_layer_sizes=(5, 2), random_state=1) + #clf = NearestCentroid() + #clf = LogisticRegression() + + + + X=good + bad + Y = [1] * len(good) + [0] * len(bad) + clf = clf.fit(X,Y) + epsilon = 0.0001 # for astronauts + epsilon = 1.0 + + def loss(x): + return clf.predict_proba([(1-epsilon)*z+epsilon*x])[0][0] # for astronauts + #return clf.predict_proba([z+epsilon*x])[0][0] + + + import nevergrad as ng + budget = 300 + z = np.array(random.choice(good)) + if i_believe_in_evolution: + nevergrad_optimizer = ng.optimizers.RandomSearch(len(z), budget) + #nevergrad_optimizer = ng.optimizers.RandomSearch(len(z), budget) + #nevergrad_optimizer = ng.optimizers.DiscreteLenglerOnePlusOne(len(z), 10000) + #nevergrad_optimizer = ng.optimizers.DiscreteOnePlusOne(len(z), 10000) + #nevergrad_optimizer.suggest(z) + + for i in range(budget): + x = nevergrad_optimizer.ask() + l = loss(x.value) + if np.log2(i+1) == int(np.log2(i+1)): + print(f"iteration {i} --> {l}") + print("var/variable = ", sum(((1-epsilon)*z + epsilon * x.value)**2)/len(x.value)) + if l < 0.0000001: + print(f"we find proba(bad)={l}") + break + else: + if latents.shape != latents_shape: + raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {latents_shape}") + os.environ["latent_sd"] = str(list(latents.flatten().cpu().numpy())) + latents = latents.to(self.device) + + # set timesteps + accepts_offset = "offset" in set(inspect.signature(self.scheduler.set_timesteps).parameters.keys()) + extra_set_kwargs = {} + if accepts_offset: + extra_set_kwargs["offset"] = 1 + + self.scheduler.set_timesteps(num_inference_steps, **extra_set_kwargs) + + # if we use LMSDiscreteScheduler, let's make sure latents are mulitplied by sigmas + if isinstance(self.scheduler, LMSDiscreteScheduler): 
+ latents = latents * self.scheduler.sigmas[0] + + # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature + # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers. + # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502 + # and should be between [0, 1] + accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys()) + extra_step_kwargs = {} + if accepts_eta: + extra_step_kwargs["eta"] = eta + + for i, t in enumerate(self.progress_bar(self.scheduler.timesteps)): + # expand the latents if we are doing classifier free guidance + latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents + if isinstance(self.scheduler, LMSDiscreteScheduler): + sigma = self.scheduler.sigmas[i] + # the model input needs to be scaled to match the continuous ODE formulation in K-LMS + latent_model_input = latent_model_input / ((sigma**2 + 1) ** 0.5) + + # predict the noise residual + noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=text_embeddings).sample + + # perform guidance + if do_classifier_free_guidance: + noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) + noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) + + # compute the previous noisy sample x_t -> x_t-1 + if isinstance(self.scheduler, LMSDiscreteScheduler): + latents = self.scheduler.step(noise_pred, i, latents, **extra_step_kwargs).prev_sample + else: + latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample + + # scale and decode the image latents with vae + latents = 1 / 0.18215 * latents + image = self.vae.decode(latents).sample + + image = (image / 2 + 0.5).clamp(0, 1) + image = image.cpu().permute(0, 2, 3, 1).numpy() + + # run safety checker + safety_cheker_input = self.feature_extractor(self.numpy_to_pil(image), return_tensors="pt").to(self.device) + image, has_nsfw_concept = self.safety_checker(images=image, clip_input=safety_cheker_input.pixel_values) + + if output_type == "pil": + image = self.numpy_to_pil(image) + + if not return_dict: + return (image, has_nsfw_concept) + + return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept) From dd1299faee47272544b5394e259f7962ce3bffcb Mon Sep 17 00:00:00 2001 From: Olivier Teytaud Date: Tue, 20 Sep 2022 16:20:08 +0200 Subject: [PATCH 03/76] fix --- minisd.py | 2 ++ multi_minisd.sh | 37 +++++++++++++++++++++++-------------- 2 files changed, 25 insertions(+), 14 deletions(-) diff --git a/minisd.py b/minisd.py index 5d1029f47..99d671fd3 100644 --- a/minisd.py +++ b/minisd.py @@ -29,6 +29,8 @@ name = random.choice(["Mark Zuckerbeg", "Zendaya", "Yann LeCun", "Scarlett Johansson", "Superman", "Meg Myers"]) name = "Zendaya" prompt = f"Photo of {name} as a sumo-tori." + +prompt = "A close up photographic portrait of a young woman." with autocast("cuda"): image = pipe(prompt, guidance_scale=7.5)["sample"][0] diff --git a/multi_minisd.sh b/multi_minisd.sh index a73a44538..db2b13371 100755 --- a/multi_minisd.sh +++ b/multi_minisd.sh @@ -1,6 +1,6 @@ #!/bin/bash -set -e -x +set -e touch empty_file rm empty_file @@ -31,16 +31,25 @@ do #touch SD_image_${sentinel}.png #touch SD_latent_${sentinel}.txt echo GENERATING FOUR IMAGES. 
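+    # awk '!x[$0]++' prints only the first occurrence of each line, i.e. it de-duplicates goodbad.py.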
+ cat goodbad.py | awk '!x[$0]++' > goodbad2.py + mv goodbad2.py goodbad.py python minisd.py python minisd.py python minisd.py python minisd.py - for img in `ls -ctr SD*_image_*.png | tail -n 4` + list_of_four_images="`ls -ctr SD*_image_*.png | tail -n 4`" + for img in $list_of_four_images do + echo We add image $img ======================= montage $mylist $img -mode Concatenate -tile 5x output.png open --wait output.png read -p "Rank of the last image ?" rank - mylist="$mylist $img` + echo "Provided rank: $rank" + mylist="$mylist $img" + if [[ $rank -le 0 ]] + then + read -p "Enter all ranks !!!!" myranks + else mynewranks="" for r in $myranks do @@ -48,20 +57,19 @@ do mynewranks="$mynewranks $r" done myranks="$mynewranks $rank" + fi #echo Before sorting =========================== #echo $myranks #echo $mylist - #sleep 2 + #sleep 5 # Now sorting mynewlist="" mynewranks="" - touch goodbad.py - rm goodbad.py - touch goodbad.py - echo "good = []" >> goodbad.py - echo "bad = []" >> goodbad.py - + sed -i.backup 's/good +=.*//g' goodbad.py + num_good=`cat goodbad.py | grep 'good +=' | wc -l ` + num_good=$(( $num_good / 2 + 5 )) + echo "We keep the $num_good best." for r in `seq 20` do for k in `seq 20` @@ -73,19 +81,20 @@ do if [[ `echo $myranks | cut -d ' ' -f $k` -eq $r ]] then echo Found $my_image at rank $k for $r - if [[ $r -le 5 ]] + if [[ $r -le $num_good ]] then cat empty_file `echo $my_image | sed 's/image_[0-9]*.png/PROUTPROUT&/g' | sed 's/PROUTPROUTimage/latent/g' | sed 's/\.png/\.txt/g'` | sed "s/.*/good += [&]/g" >> goodbad.py else cat empty_file `echo $my_image | sed 's/image_[0-9]*.png/PROUTPROUT&/g' | sed 's/PROUTPROUTimage/latent/g' | sed 's/\.png/\.txt/g'` | sed "s/.*/bad += [&]/g" >> goodbad.py fi + echo "" >> goodbad.py + break fi - echo "" >> goodbad.py done done + myranks=$mynewranks + mylist=$mynewlist done - myranks=$mynewranks - mylist=$mynewlist echo After sorting =========================== echo $myranks echo $mylist From e8d06f23405b5fa5e3da989f477ccbcf901acbe3 Mon Sep 17 00:00:00 2001 From: Olivier Teytaud Date: Tue, 20 Sep 2022 16:40:49 +0200 Subject: [PATCH 04/76] fix --- pipeline_stable_diffusion.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/pipeline_stable_diffusion.py b/pipeline_stable_diffusion.py index 3b52f0a99..8588be2cf 100644 --- a/pipeline_stable_diffusion.py +++ b/pipeline_stable_diffusion.py @@ -258,12 +258,16 @@ def loss(x): for i in range(budget): x = nevergrad_optimizer.ask() l = loss(x.value) + nevergrad_optimizer.tell(x, l) if np.log2(i+1) == int(np.log2(i+1)): print(f"iteration {i} --> {l}") print("var/variable = ", sum(((1-epsilon)*z + epsilon * x.value)**2)/len(x.value)) + x = nevergrad_optimizer.recommend().value + z = (1.-epsilon) * z + epsilon * x if l < 0.0000001: print(f"we find proba(bad)={l}") break + latents = torch.from_numpy(z.reshape(latents_shape)).half() else: if latents.shape != latents_shape: raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {latents_shape}") From 29672dbdf03dd961ce81bfa26d36971b1e02aa03 Mon Sep 17 00:00:00 2001 From: Olivier Teytaud Date: Tue, 20 Sep 2022 16:41:45 +0200 Subject: [PATCH 05/76] fix --- pipeline_stable_diffusion.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/pipeline_stable_diffusion.py b/pipeline_stable_diffusion.py index 8588be2cf..abb5e63d0 100644 --- a/pipeline_stable_diffusion.py +++ b/pipeline_stable_diffusion.py @@ -253,7 +253,7 @@ def loss(x): #nevergrad_optimizer = ng.optimizers.RandomSearch(len(z), budget) 
#nevergrad_optimizer = ng.optimizers.DiscreteLenglerOnePlusOne(len(z), 10000) #nevergrad_optimizer = ng.optimizers.DiscreteOnePlusOne(len(z), 10000) - #nevergrad_optimizer.suggest(z) + nevergrad_optimizer.suggest(z) for i in range(budget): x = nevergrad_optimizer.ask() @@ -264,9 +264,9 @@ def loss(x): print("var/variable = ", sum(((1-epsilon)*z + epsilon * x.value)**2)/len(x.value)) x = nevergrad_optimizer.recommend().value z = (1.-epsilon) * z + epsilon * x - if l < 0.0000001: - print(f"we find proba(bad)={l}") - break + #if l < 0.0000001: + # print(f"we find proba(bad)={l}") + # break latents = torch.from_numpy(z.reshape(latents_shape)).half() else: if latents.shape != latents_shape: From 895ad559eb7ba6ca15e7975a3e877fa1773338d3 Mon Sep 17 00:00:00 2001 From: Olivier Teytaud Date: Tue, 20 Sep 2022 17:16:50 +0200 Subject: [PATCH 06/76] fix --- multi_minisd.sh | 13 +++++++------ pipeline_stable_diffusion.py | 2 +- 2 files changed, 8 insertions(+), 7 deletions(-) diff --git a/multi_minisd.sh b/multi_minisd.sh index db2b13371..7eee1cd44 100755 --- a/multi_minisd.sh +++ b/multi_minisd.sh @@ -30,14 +30,15 @@ do #sentinel=${RANDOM} #touch SD_image_${sentinel}.png #touch SD_latent_${sentinel}.txt - echo GENERATING FOUR IMAGES. + lambda=7 + echo "GENERATING $lambda IMAGES ================================" cat goodbad.py | awk '!x[$0]++' > goodbad2.py mv goodbad2.py goodbad.py - python minisd.py - python minisd.py - python minisd.py - python minisd.py - list_of_four_images="`ls -ctr SD*_image_*.png | tail -n 4`" + for kk in `seq $lambda` + do + python minisd.py + done + list_of_four_images="`ls -ctr SD*_image_*.png | tail -n $lambda`" for img in $list_of_four_images do echo We add image $img ======================= diff --git a/pipeline_stable_diffusion.py b/pipeline_stable_diffusion.py index abb5e63d0..45e27b557 100644 --- a/pipeline_stable_diffusion.py +++ b/pipeline_stable_diffusion.py @@ -267,7 +267,7 @@ def loss(x): #if l < 0.0000001: # print(f"we find proba(bad)={l}") # break - latents = torch.from_numpy(z.reshape(latents_shape)).half() + latents = torch.from_numpy(z.reshape(latents_shape)).float() #.half() else: if latents.shape != latents_shape: raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {latents_shape}") From a977845c5d7ddbb2bdd1bada71cf37c5d9195f99 Mon Sep 17 00:00:00 2001 From: Olivier Teytaud Date: Tue, 20 Sep 2022 17:51:28 +0200 Subject: [PATCH 07/76] fix --- multi_minisd.sh | 2 ++ 1 file changed, 2 insertions(+) diff --git a/multi_minisd.sh b/multi_minisd.sh index 7eee1cd44..bd8bc52c0 100755 --- a/multi_minisd.sh +++ b/multi_minisd.sh @@ -32,6 +32,8 @@ do #touch SD_latent_${sentinel}.txt lambda=7 echo "GENERATING $lambda IMAGES ================================" + echo "`grep -c 'good +=' goodbad.py` positive examples" + echo "`grep -c 'bad +=' goodbad.py` negative examples" cat goodbad.py | awk '!x[$0]++' > goodbad2.py mv goodbad2.py goodbad.py for kk in `seq $lambda` From d376f3f43afc490a7508c3040a18839b217c9446 Mon Sep 17 00:00:00 2001 From: Olivier Teytaud Date: Tue, 20 Sep 2022 18:20:08 +0200 Subject: [PATCH 08/76] fix --- multi_minisd.sh | 10 ++++++++-- 1 file changed, 8 insertions(+), 2 deletions(-) diff --git a/multi_minisd.sh b/multi_minisd.sh index bd8bc52c0..46b5ae26a 100755 --- a/multi_minisd.sh +++ b/multi_minisd.sh @@ -32,15 +32,21 @@ do #touch SD_latent_${sentinel}.txt lambda=7 echo "GENERATING $lambda IMAGES ================================" - echo "`grep -c 'good +=' goodbad.py` positive examples" - echo "`grep -c 'bad +=' goodbad.py` 
negative examples" cat goodbad.py | awk '!x[$0]++' > goodbad2.py mv goodbad2.py goodbad.py + echo "`grep -c 'good +=' goodbad.py` positive examples" + echo "`grep -c 'bad +=' goodbad.py` negative examples" for kk in `seq $lambda` do python minisd.py done list_of_four_images="`ls -ctr SD*_image_*.png | tail -n $lambda`" + my_new_list="" + # We stop at 19 so that it becomes 20 with the new one + for k in `seq 19` + do + my_new_list="$my_new_list `echo $mylist | cut -d ' ' -f $k`" + done for img in $list_of_four_images do echo We add image $img ======================= From 6d2723e13c2ebcc22a28510430a9a150c7bbd3b1 Mon Sep 17 00:00:00 2001 From: Olivier Teytaud Date: Tue, 20 Sep 2022 18:29:34 +0200 Subject: [PATCH 09/76] fix --- pipeline_stable_diffusion.py | 23 ++++++++++++++++------- 1 file changed, 16 insertions(+), 7 deletions(-) diff --git a/pipeline_stable_diffusion.py b/pipeline_stable_diffusion.py index 45e27b557..83aebf928 100644 --- a/pipeline_stable_diffusion.py +++ b/pipeline_stable_diffusion.py @@ -245,15 +245,24 @@ def loss(x): #return clf.predict_proba([z+epsilon*x])[0][0] - import nevergrad as ng - budget = 300 - z = np.array(random.choice(good)) if i_believe_in_evolution: - nevergrad_optimizer = ng.optimizers.RandomSearch(len(z), budget) + import nevergrad as ng + budget = 300 #nevergrad_optimizer = ng.optimizers.RandomSearch(len(z), budget) - #nevergrad_optimizer = ng.optimizers.DiscreteLenglerOnePlusOne(len(z), 10000) - #nevergrad_optimizer = ng.optimizers.DiscreteOnePlusOne(len(z), 10000) - nevergrad_optimizer.suggest(z) + #nevergrad_optimizer = ng.optimizers.RandomSearch(len(z), budget) + nevergrad_optimizer = ng.optimizers.DiscreteLenglerOnePlusOne(len(z), budget) + #nevergrad_optimizer = ng.optimizers.DiscreteOnePlusOne(len(z), budget) + for k in range(5): + z1 = np.array(random.choice(good)) + z2 = np.array(random.choice(good)) + z3 = np.array(random.choice(good)) + z4 = np.array(random.choice(good)) + z5 = np.array(random.choice(good)) + z = z1 + for u in range(len(z)): + z[u] = random.choice([z1[u],z2[u],z3[u],z4[u],z5[u]]) + nevergrad_optimizer.suggest(z) + for i in range(budget): x = nevergrad_optimizer.ask() From 564aa36b5a9cb12bf174abe88deced0f7669630c Mon Sep 17 00:00:00 2001 From: Olivier Teytaud Date: Wed, 21 Sep 2022 08:34:10 +0200 Subject: [PATCH 10/76] fix --- view_history.sh | 2 ++ 1 file changed, 2 insertions(+) create mode 100755 view_history.sh diff --git a/view_history.sh b/view_history.sh new file mode 100755 index 000000000..26b54c393 --- /dev/null +++ b/view_history.sh @@ -0,0 +1,2 @@ +montage `ls -ctr SD*imag*.png` -mode concatenate -tile 7x history.png +open history.png From 728fb9399d12928f69e2ae920f666b0f3309970c Mon Sep 17 00:00:00 2001 From: Olivier Teytaud Date: Wed, 21 Sep 2022 08:40:28 +0200 Subject: [PATCH 11/76] fix --- minisd.py | 2 +- multi_minisd.sh | 23 ++++++++++++++++------- 2 files changed, 17 insertions(+), 8 deletions(-) diff --git a/minisd.py b/minisd.py index 99d671fd3..d1443d8c6 100644 --- a/minisd.py +++ b/minisd.py @@ -30,7 +30,7 @@ name = "Zendaya" prompt = f"Photo of {name} as a sumo-tori." -prompt = "A close up photographic portrait of a young woman." +prompt = "A close up photographic portrait of a young woman with colored hair." 
with autocast("cuda"): image = pipe(prompt, guidance_scale=7.5)["sample"][0] diff --git a/multi_minisd.sh b/multi_minisd.sh index 46b5ae26a..cc707627c 100755 --- a/multi_minisd.sh +++ b/multi_minisd.sh @@ -30,7 +30,7 @@ do #sentinel=${RANDOM} #touch SD_image_${sentinel}.png #touch SD_latent_${sentinel}.txt - lambda=7 + lambda=4 echo "GENERATING $lambda IMAGES ================================" cat goodbad.py | awk '!x[$0]++' > goodbad2.py mv goodbad2.py goodbad.py @@ -38,20 +38,27 @@ do echo "`grep -c 'bad +=' goodbad.py` negative examples" for kk in `seq $lambda` do + echo "generating image $kk / $lambda" python minisd.py done list_of_four_images="`ls -ctr SD*_image_*.png | tail -n $lambda`" - my_new_list="" - # We stop at 19 so that it becomes 20 with the new one - for k in `seq 19` - do - my_new_list="$my_new_list `echo $mylist | cut -d ' ' -f $k`" - done +# my_new_list="" +# my_new_ranks="" +# # We stop at 19 so that it becomes 20 with the new one +# for k in `seq 19` +# do +# my_new_list="$my_new_list `echo $mylist | cut -d ' ' -f $k`" +# my_new_ranks="$my_new_ranks `echo $myranks | cut -d ' ' -f $k`" +# done +# mylist=`echo $my_new_list | sed 's/[ ]*$//g'` +# myranks=`echo $my_new_ranks | sed 's/[ ]*$//g'` +# echo "After limiting to 19, we get $mylist and $myranks " for img in $list_of_four_images do echo We add image $img ======================= montage $mylist $img -mode Concatenate -tile 5x output.png open --wait output.png + # read -t 1 prout read -p "Rank of the last image ?" rank echo "Provided rank: $rank" mylist="$mylist $img" @@ -77,7 +84,9 @@ do mynewranks="" sed -i.backup 's/good +=.*//g' goodbad.py num_good=`cat goodbad.py | grep 'good +=' | wc -l ` + echo "Num goods in file: $num_good" num_good=$(( $num_good / 2 + 5 )) + echo "Num goods after update: $num_good" echo "We keep the $num_good best." for r in `seq 20` do From 51d4c05d1c287f985d395c18a38cb778010a536a Mon Sep 17 00:00:00 2001 From: Olivier Teytaud Date: Thu, 22 Sep 2022 16:34:41 +0200 Subject: [PATCH 12/76] fix --- archimulti_minisd.sh | 15 +++++ edit.sh | 3 + inoculate_evo_sd.sh | 131 +++++++++++++++++++++++++++++++++++++++++++ maketgz.sh | 2 + minisd.sh | 11 ++++ multi_minisd.sh | 5 +- multiminisd.sh | 40 +++++++++++++ view_history.sh | 2 +- 8 files changed, 207 insertions(+), 2 deletions(-) create mode 100755 archimulti_minisd.sh create mode 100755 edit.sh create mode 100755 inoculate_evo_sd.sh create mode 100755 maketgz.sh create mode 100755 minisd.sh create mode 100755 multiminisd.sh diff --git a/archimulti_minisd.sh b/archimulti_minisd.sh new file mode 100755 index 000000000..12e1b2180 --- /dev/null +++ b/archimulti_minisd.sh @@ -0,0 +1,15 @@ +#!/bin/bash + +for u in `seq 25` +do +export prompt="A photographic portrait of a young woman, tilted head, from below, red hair, green eyes, cleavage, smiling." +python minisd.py +export prompt="A photo of a cute armoured bloody red panda fighting off tentacles with daggers." +python minisd.py +export prompt="A photo of a woman fighting off tentacles with guns." +python minisd.py +export prompt="A cute armoured red panda fighting off zombies with karate." +python minisd.py +export prompt="An armored Mark Zuckerberg fighting off bloody tentacles in the jungle." 
+python minisd.py +done diff --git a/edit.sh b/edit.sh new file mode 100755 index 000000000..b8c49801b --- /dev/null +++ b/edit.sh @@ -0,0 +1,3 @@ +#vim /private/home/oteytaud/.conda/envs/sd/lib/python3.8/site-packages/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +vim /opt/miniconda3/envs/sd_gpu/lib/python3.9/site-packages/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +cp /opt/miniconda3/envs/sd_gpu/lib/python3.9/site-packages/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py . diff --git a/inoculate_evo_sd.sh b/inoculate_evo_sd.sh new file mode 100755 index 000000000..256995f68 --- /dev/null +++ b/inoculate_evo_sd.sh @@ -0,0 +1,131 @@ +#!/bin/bash + +echo Parametrization and initialization. +#export prompt="A close up photographic portrait of a young woman with uniformly colored hair." +export prompt="An armored Mark Zuckerberg fighting off bloody tentacles in the jungle." +lambda=18 +cp basic_inoculation_uniformly/SD*.* inoculations/ + + +set -e + +touch empty_file +rm empty_file +touch empty_file +touch SD_prout_${random}.png +touch SD_prout_${random}.txt +mv SD_*.png poubelle/ +mv SD_*.txt poubelle/ +# Initialization: run SD and create an image, with rank 1. +touch goodbad.py +rm goodbad.py +touch goodbad.py +echo "good = []" >> goodbad.py +echo "bad = []" >> goodbad.py +#python minisd.py +./minisd.sh + #sentinel=${RANDOM} + #touch SD_image_${sentinel}.png + #touch SD_latent_${sentinel}.txt +mylist="`ls -ctr SD*_image_*.png | tail -n 1`" +myranks=1 + +for i in `seq 30` +do + # Now an iteration. + echo Current images = $mylist + echo Current ranks = $myranks + #sentinel=${RANDOM} + #touch SD_image_${sentinel}.png + #touch SD_latent_${sentinel}.txt + echo "GENERATING $lambda IMAGES ================================" + cat goodbad.py | awk '!x[$0]++' > goodbad2.py + mv goodbad2.py goodbad.py + echo "`grep -c 'good +=' goodbad.py` positive examples" + echo "`grep -c 'bad +=' goodbad.py` negative examples" + for kk in `seq $lambda` + do + echo "generating image $kk / $lambda" + #python minisd.py + ./minisd.sh + done + list_of_four_images="`ls -ctr SD*_image_*.png | tail -n $lambda`" +# my_new_list="" +# my_new_ranks="" +# # We stop at 19 so that it becomes 20 with the new one +# for k in `seq 19` +# do +# my_new_list="$my_new_list `echo $mylist | cut -d ' ' -f $k`" +# my_new_ranks="$my_new_ranks `echo $myranks | cut -d ' ' -f $k`" +# done +# mylist=`echo $my_new_list | sed 's/[ ]*$//g'` +# myranks=`echo $my_new_ranks | sed 's/[ ]*$//g'` +# echo "After limiting to 19, we get $mylist and $myranks " + for img in $list_of_four_images + do + echo We add image $img ======================= + montage $mylist $img -mode Concatenate -tile 5x output.png + open --wait output.png + # read -t 1 prout + read -p "Rank of the last image ?" rank + echo "Provided rank: $rank" + mylist="$mylist $img" + if [[ $rank -le 0 ]] + then + read -p "Enter all ranks !!!!" myranks + else + mynewranks="" + for r in $myranks + do + [[ $r -ge $rank ]] && r=$(( $r + 1 )) + mynewranks="$mynewranks $r" + done + myranks="$mynewranks $rank" + fi + #echo Before sorting =========================== + #echo $myranks + #echo $mylist + #sleep 5 + + # Now sorting + mynewlist="" + mynewranks="" + sed -i.backup 's/good +=.*//g' goodbad.py + num_good=`cat goodbad.py | grep 'good +=' | wc -l ` + echo "Num goods in file: $num_good" + num_good=$(( $num_good / 2 + 5 )) + echo "Num goods after update: $num_good" + echo "We keep the $num_good best." 
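+    # Re-sorting pass: for each rank r, find the image currently holding that rank, rebuild the
+    # sorted lists, and append its latent file to goodbad.py as "good" (r <= $num_good) or "bad".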
+ for r in `seq 20` + do + for k in `seq 20` + do + [[ `echo $myranks | cut -d ' ' -f $k` -eq $r ]] && echo "FOUND $k for $r!" + my_image="`echo $mylist | cut -d ' ' -f $k`" + [[ `echo $myranks | cut -d ' ' -f $k` -eq $r ]] && mynewranks="$mynewranks `echo $myranks | cut -d ' ' -f $k`" + [[ `echo $myranks | cut -d ' ' -f $k` -eq $r ]] && mynewlist="$mynewlist `echo $mylist | cut -d ' ' -f $k`" + if [[ `echo $myranks | cut -d ' ' -f $k` -eq $r ]] + then + echo Found $my_image at rank $k for $r + if [[ $r -le $num_good ]] + then + cat empty_file `echo $my_image | sed 's/image_[0-9]*.png/PROUTPROUT&/g' | sed 's/PROUTPROUTimage/latent/g' | sed 's/\.png/\.txt/g'` | sed "s/.*/good += [&]/g" >> goodbad.py + else + cat empty_file `echo $my_image | sed 's/image_[0-9]*.png/PROUTPROUT&/g' | sed 's/PROUTPROUTimage/latent/g' | sed 's/\.png/\.txt/g'` | sed "s/.*/bad += [&]/g" >> goodbad.py + fi + echo "" >> goodbad.py + break + fi + done + done + myranks=$mynewranks + mylist=$mynewlist + done + echo After sorting =========================== + echo $myranks + echo $mylist + #sleep 2 + +done + + diff --git a/maketgz.sh b/maketgz.sh new file mode 100755 index 000000000..15e54d619 --- /dev/null +++ b/maketgz.sh @@ -0,0 +1,2 @@ +tar -zcvf ~/bigpack2.tgz `ls -ctrl | grep 'Sep.16' | grep '\.png' | sed 's/.* //g'` | wc -l +ls -ctlr ~/bigpack2.tgz diff --git a/minisd.sh b/minisd.sh new file mode 100755 index 000000000..f5f3da39d --- /dev/null +++ b/minisd.sh @@ -0,0 +1,11 @@ +#!/bin/bash +if compgen -G inoculations/SD*.png > /dev/null ; then +file=`ls -ctr inoculations/SD*.png | tail -n 1` +echo Transformation: +latent=`echo $file | sed 's/image_[0-9]*.png/PROUTPROUT&/g' | sed 's/PROUTPROUTimage/latent/g' | sed 's/\.png/\.txt/g'` +echo Image=$file +echo Latent=$latent +mv $file $latent . +else +python minisd.py +fi diff --git a/multi_minisd.sh b/multi_minisd.sh index cc707627c..a4b312a3e 100755 --- a/multi_minisd.sh +++ b/multi_minisd.sh @@ -2,6 +2,10 @@ set -e +#export prompt="A woman with many eyes." +export prompt="A close up photographic portrait of a young woman with uniformly colored hair." +lambda=14 + touch empty_file rm empty_file touch empty_file @@ -30,7 +34,6 @@ do #sentinel=${RANDOM} #touch SD_image_${sentinel}.png #touch SD_latent_${sentinel}.txt - lambda=4 echo "GENERATING $lambda IMAGES ================================" cat goodbad.py | awk '!x[$0]++' > goodbad2.py mv goodbad2.py goodbad.py diff --git a/multiminisd.sh b/multiminisd.sh new file mode 100755 index 000000000..cc5a283a5 --- /dev/null +++ b/multiminisd.sh @@ -0,0 +1,40 @@ +#!/bin/bash +touch SD.prout.${RANDOM} +mv SD*.* poubelle/ + +numimages=12 +for m in 5 #2 5 3 4 1 +do +export mu=$m +for d in 1 0.5 0 +do +export decay=$d +for ngo in OnePlusOne DiscreteOnePlusOne RandomSearch DiscreteLenglerOnePlusOne +do +export ngoptim=$ngo +for sl in tree nn logit +do +export skl=$sl +for es in True False +do +export earlystop=$es + + +export prompt="A close up photographic portrait of a young woman with uniformly colored hair." 
+directory=biased_rw_experiment${numimages}_images_${mu}_${ngoptim}_${earlystop}_${skl}_${decay} +mkdir $directory +for u in `seq $numimages` +do +cp goodbad_learnbluehair.py goodbad.py +python minisd.py +./view_history.sh +sleep 1 +done +cp history.png SD* *.py *.sh $directory + +mv SD*.* poubelle/ +done +done +done +done +done diff --git a/view_history.sh b/view_history.sh index 26b54c393..ac81b97b6 100755 --- a/view_history.sh +++ b/view_history.sh @@ -1,2 +1,2 @@ -montage `ls -ctr SD*imag*.png` -mode concatenate -tile 7x history.png +montage `ls -ctr SD*imag*.png | tail -n 18` -mode concatenate -tile 4x history.png open history.png From b4d74c5a0f85620c8269176000dc640870406fd0 Mon Sep 17 00:00:00 2001 From: Olivier Teytaud Date: Thu, 22 Sep 2022 16:46:00 +0200 Subject: [PATCH 13/76] fix --- archimulti_minisd.sh | 2 +- minisd.py | 17 ++++++++-- pipeline_stable_diffusion.py | 64 ++++++++++++++++++++++++------------ 3 files changed, 59 insertions(+), 24 deletions(-) diff --git a/archimulti_minisd.sh b/archimulti_minisd.sh index 12e1b2180..cd6a4a9eb 100755 --- a/archimulti_minisd.sh +++ b/archimulti_minisd.sh @@ -2,7 +2,7 @@ for u in `seq 25` do -export prompt="A photographic portrait of a young woman, tilted head, from below, red hair, green eyes, cleavage, smiling." +export prompt="A photographic portrait of a young woman, tilted head, from below, red hair, green eyes, smiling." python minisd.py export prompt="A photo of a cute armoured bloody red panda fighting off tentacles with daggers." python minisd.py diff --git a/minisd.py b/minisd.py index d1443d8c6..8f0675ca7 100644 --- a/minisd.py +++ b/minisd.py @@ -30,13 +30,26 @@ name = "Zendaya" prompt = f"Photo of {name} as a sumo-tori." -prompt = "A close up photographic portrait of a young woman with colored hair." +prompt = "Full length portrait of Mark Zuckerberg as a Sumo-Tori." +prompt = "Full length portrait of Scarlett Johansson as a Sumo-Tori." +prompt = "A close up photographic portrait of a young woman with uniformly colored hair." +prompt = "Zombies raising and worshipping a flying human." +prompt = "Zombies trying to kill Meg Myers." +prompt = "Meg Myers with an Egyptian dress killing a vampire with a gun." +prompt = "Meg Myers grabbing a vampire by the scruff of the neck." +prompt = "Mark Zuckerberg chokes a vampire to death." +prompt = "Mark Zuckerberg riding an animal." +prompt = "A giant cute animal worshipped by zombies." + + + +import os +prompt = os.environ.get("prompt", prompt) with autocast("cuda"): image = pipe(prompt, guidance_scale=7.5)["sample"][0] sentinel = random.randint(0,100000) image.save(f"SD_{prompt.replace(' ','_')}_image_{sentinel}.png") -import os latent = eval((os.environ["latent_sd"])) with open(f"SD_{prompt.replace(' ','_')}_latent_{sentinel}.txt", 'w') as f: f.write(f"{latent}") diff --git a/pipeline_stable_diffusion.py b/pipeline_stable_diffusion.py index 83aebf928..bbc5a6b41 100644 --- a/pipeline_stable_diffusion.py +++ b/pipeline_stable_diffusion.py @@ -1,5 +1,6 @@ import inspect import os +import numpy as np import random import warnings from typing import List, Optional, Union @@ -207,9 +208,10 @@ def __call__( # However this currently doesn't work in `mps`. 
latents_device = "cpu" if self.device.type == "mps" else self.device latents_shape = (batch_size, self.unet.in_channels, height // 8, width // 8) + latents_intermediate_shape = (batch_size, self.unet.in_channels, height // 8, width // 8) if latents is None: latents = torch.randn( - latents_shape, + latents_intermediate_shape, generator=generator, device=latents_device, ) @@ -217,20 +219,22 @@ def __call__( from goodbad import good from goodbad import bad i_believe_in_evolution = len(good) > 0 and len(bad) > 0 + #i_believe_in_evolution = False print(f"I believe in evolution = {i_believe_in_evolution}") if i_believe_in_evolution: from sklearn import tree from sklearn.neural_network import MLPClassifier #from sklearn.neighbors import NearestCentroid from sklearn.linear_model import LogisticRegression - import numpy as np #z = (np.random.randn(4*64*64)) z = latents.cpu().numpy().flatten() - - #clf=tree.DecisionTreeClassifier()#min_samples_split=0.1) - clf = MLPClassifier(solver='lbfgs', alpha=1e-5, hidden_layer_sizes=(5, 2), random_state=1) + if os.environ.get("skl", "tree") == "tree": + clf = tree.DecisionTreeClassifier()#min_samples_split=0.1) + elif os.environ.get("skl", "tree") == "logit": + clf = LogisticRegression() + else: + clf = MLPClassifier(solver='lbfgs', alpha=1e-5, hidden_layer_sizes=(5, 2), random_state=1) #clf = NearestCentroid() - #clf = LogisticRegression() @@ -241,7 +245,8 @@ def __call__( epsilon = 1.0 def loss(x): - return clf.predict_proba([(1-epsilon)*z+epsilon*x])[0][0] # for astronauts + return clf.predict_proba([x])[0][0] # for astronauts + #return clf.predict_proba([(1-epsilon)*z+epsilon*x])[0][0] # for astronauts #return clf.predict_proba([z+epsilon*x])[0][0] @@ -250,7 +255,9 @@ def loss(x): budget = 300 #nevergrad_optimizer = ng.optimizers.RandomSearch(len(z), budget) #nevergrad_optimizer = ng.optimizers.RandomSearch(len(z), budget) - nevergrad_optimizer = ng.optimizers.DiscreteLenglerOnePlusOne(len(z), budget) + optim_class = ng.optimizers.registry[os.environ.get("ngoptim", "DiscreteLenglerOnePlusOne")] + #nevergrad_optimizer = ng.optimizers.DiscreteLenglerOnePlusOne(len(z), budget) + nevergrad_optimizer = optim_class(len(z), budget) #nevergrad_optimizer = ng.optimizers.DiscreteOnePlusOne(len(z), budget) for k in range(5): z1 = np.array(random.choice(good)) @@ -258,29 +265,44 @@ def loss(x): z3 = np.array(random.choice(good)) z4 = np.array(random.choice(good)) z5 = np.array(random.choice(good)) - z = z1 - for u in range(len(z)): - z[u] = random.choice([z1[u],z2[u],z3[u],z4[u],z5[u]]) + #z = 0.99 * z1 + 0.01 * (z2+z3+z4+z5)/4. + z = 0.2 * (z1 + z2 + z3 + z4 + z5) + mu = int(os.environ.get("mu", "5")) + parents = [z1, z2, z3, z4, z5] + weights = [np.exp(np.random.randn() - i * float(os.environ.get("decay", "1."))) for i in range(5)] + z = weights[0] * z1 + for u in range(mu): + if u > 0: + z += weights[u] * parents[u] + z = (1. 
/ sum(weights[:mu])) * z + z = np.sqrt(len(z)) * z / np.linalg.norm(z) + + #for u in range(len(z)): + # z[u] = random.choice([z1[u],z2[u],z3[u],z4[u],z5[u]]) nevergrad_optimizer.suggest(z) for i in range(budget): x = nevergrad_optimizer.ask() - l = loss(x.value) + z = x.value * np.sqrt(len(x.value)) / np.linalg.norm(x.value) + l = loss(z) nevergrad_optimizer.tell(x, l) if np.log2(i+1) == int(np.log2(i+1)): print(f"iteration {i} --> {l}") - print("var/variable = ", sum(((1-epsilon)*z + epsilon * x.value)**2)/len(x.value)) - x = nevergrad_optimizer.recommend().value - z = (1.-epsilon) * z + epsilon * x - #if l < 0.0000001: - # print(f"we find proba(bad)={l}") - # break - latents = torch.from_numpy(z.reshape(latents_shape)).float() #.half() + print("var/variable = ", sum(z**2)/len(z)) + #z = (1.-epsilon) * z + epsilon * x / np.sqrt(np.sum(x ** 2)) + if l < 0.0000001 and os.environ.get("earlystop") in ["true", "True"]: + print(f"we find proba(bad)={l}") + break + x = nevergrad_optimizer.recommend().value + z = x * np.sqrt(len(x)) / np.linalg.norm(x) + latents = torch.from_numpy(z.reshape(latents_intermediate_shape)).float() #.half() else: - if latents.shape != latents_shape: - raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {latents_shape}") + if latents.shape != latents_intermediate_shape: + raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {latents_intermediate_shape}") os.environ["latent_sd"] = str(list(latents.flatten().cpu().numpy())) + for i in [2, 3]: + latents = torch.repeat_interleave(latents, repeats=latents_shape[i] // latents_intermediate_shape[i], dim=i) #/ np.sqrt(np.sqrt(latents_shape[i] // latents_intermediate_shape[i])) latents = latents.to(self.device) # set timesteps From 4cf48321b7429e06f5756e64e930e7d0397e8199 Mon Sep 17 00:00:00 2001 From: Olivier Teytaud Date: Thu, 22 Sep 2022 16:47:55 +0200 Subject: [PATCH 14/76] fix --- pipeline_stable_diffusion.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pipeline_stable_diffusion.py b/pipeline_stable_diffusion.py index bbc5a6b41..1ff3f863d 100644 --- a/pipeline_stable_diffusion.py +++ b/pipeline_stable_diffusion.py @@ -291,7 +291,7 @@ def loss(x): print(f"iteration {i} --> {l}") print("var/variable = ", sum(z**2)/len(z)) #z = (1.-epsilon) * z + epsilon * x / np.sqrt(np.sum(x ** 2)) - if l < 0.0000001 and os.environ.get("earlystop") in ["true", "True"]: + if l < 0.0000001 and os.environ.get("earlystop", "True") in ["true", "True"]: print(f"we find proba(bad)={l}") break x = nevergrad_optimizer.recommend().value From 82253b9aadc459afa289c6be3733a8af38d48eeb Mon Sep 17 00:00:00 2001 From: Olivier Teytaud Date: Thu, 22 Sep 2022 16:55:14 +0200 Subject: [PATCH 15/76] fix --- README.md | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/README.md b/README.md index f5a0e44c1..052121b9d 100644 --- a/README.md +++ b/README.md @@ -3,10 +3,13 @@ Install as usual (below). Then use the "pipeline_stable_diffusion.py" in lieu of the original pipeline_stable_diffusion.py found at << python -c "import diffusers ; print(diffusers.__file__)" >>. -Then edit the prompt in minisd.py. +Then edit the prompt in multi_minisd.sh. Then ./multi_minisd.sh +and follow requests. +The code is not very user-friendly, if there are users I'll do better. 
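For reference, the latent-search step that pipeline_stable_diffusion.py implements at this point boils down to the sketch below. This is a condensed, illustrative paraphrase rather than the exact patched code: the function name evolve_latent and its default values are invented here, good/bad stand for the lists of flattened latents collected in goodbad.py, and z is the starting latent (in the patches, a recombination of the best-ranked parents).

import numpy as np
import nevergrad as ng
from sklearn.tree import DecisionTreeClassifier


def evolve_latent(z, good, bad, budget=300, epsilon=0.01):
    # Surrogate model: class 1 = latents the user ranked well, class 0 = the rest.
    clf = DecisionTreeClassifier().fit(good + bad, [1] * len(good) + [0] * len(bad))

    def loss(x):
        # Probability that the candidate latent is "bad", as estimated by the surrogate.
        return clf.predict_proba([x])[0][0]

    # Search for a small perturbation of z that the surrogate classifies as "good".
    opt = ng.optimizers.DiscreteLenglerOnePlusOne(parametrization=len(z), budget=budget)
    for _ in range(budget):
        cand = opt.ask()
        x = z + epsilon * cand.value
        x = np.sqrt(len(x)) * x / np.linalg.norm(x)  # keep roughly unit variance per coordinate
        opt.tell(cand, loss(x))
    best = z + epsilon * opt.recommend().value
    return np.sqrt(len(best)) * best / np.linalg.norm(best)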
+ From ab7daf582c8de5ceb68c1e862f860816dd39ac3e Mon Sep 17 00:00:00 2001 From: Olivier Teytaud Date: Thu, 22 Sep 2022 16:59:31 +0200 Subject: [PATCH 16/76] fix --- README.md | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/README.md b/README.md index 052121b9d..ba8bd860c 100644 --- a/README.md +++ b/README.md @@ -10,6 +10,14 @@ Then and follow requests. The code is not very user-friendly, if there are users I'll do better. +Scripts: +- archimulti_minisd.sh runs a selection of prompts, for testing purpose. +- edit.sh will not work on your install: this is a convenience script for editing the code inside the conda environment, you have to update the path. +- inoculate_evo_sd.sh same as multi_minisd.sh, but not from scratch. You have to check the code for understanding, or ping me. +- minisd.sh runs stable diffusion, +- multi_minisd.sh main script. Run stable diffusion multiple times, and asks for your feedback. +- multiminisd.sh: run plenty of tests with various genetic methods. Used for tuning. Takes forever to run. +- view_history.sh: view what is in progress and put the last generated images in a single output.png From ec1a0803f5439c8e275d82cd895e497f2d48b10f Mon Sep 17 00:00:00 2001 From: Olivier Teytaud Date: Thu, 22 Sep 2022 17:09:53 +0200 Subject: [PATCH 17/76] fix --- multi_minisd.sh | 9 +++++++++ multiminisd.sh | 9 ++++++--- pipeline_stable_diffusion.py | 10 ++++++---- 3 files changed, 21 insertions(+), 7 deletions(-) diff --git a/multi_minisd.sh b/multi_minisd.sh index a4b312a3e..37addf3b1 100755 --- a/multi_minisd.sh +++ b/multi_minisd.sh @@ -4,7 +4,16 @@ set -e #export prompt="A woman with many eyes." export prompt="A close up photographic portrait of a young woman with uniformly colored hair." +# We generate groups of lambda images. lambda=14 +# 4 parents are selected. +export mu=4 +# Do we use early stopping in Nevergrad ? +export earlystop=True +# Which Scikit-Learn surrogate model ? +export skl=tree +# Which initial range for modifications ? +export epsilon=0.0001 touch empty_file rm empty_file diff --git a/multiminisd.sh b/multiminisd.sh index cc5a283a5..726a8fff9 100755 --- a/multiminisd.sh +++ b/multiminisd.sh @@ -15,13 +15,15 @@ export ngoptim=$ngo for sl in tree nn logit do export skl=$sl -for es in True False +for es in False True do export earlystop=$es - +for eps in 0.0001 +do +export epsilon=$eps export prompt="A close up photographic portrait of a young woman with uniformly colored hair." 
-directory=biased_rw_experiment${numimages}_images_${mu}_${ngoptim}_${earlystop}_${skl}_${decay} +directory=biased_${epsilon}_rw_experiment${numimages}_images_${mu}_${ngoptim}_${earlystop}_${skl}_${decay} mkdir $directory for u in `seq $numimages` do @@ -38,3 +40,4 @@ done done done done +done diff --git a/pipeline_stable_diffusion.py b/pipeline_stable_diffusion.py index 1ff3f863d..4ba3fdd56 100644 --- a/pipeline_stable_diffusion.py +++ b/pipeline_stable_diffusion.py @@ -281,21 +281,23 @@ def loss(x): # z[u] = random.choice([z1[u],z2[u],z3[u],z4[u],z5[u]]) nevergrad_optimizer.suggest(z) - + z0 = z for i in range(budget): x = nevergrad_optimizer.ask() - z = x.value * np.sqrt(len(x.value)) / np.linalg.norm(x.value) + z = z0 + float(os.environ.get("epsilon", "0.001")) * x.value + z = np.sqrt(len(z)) * z / np.linalg.norm(z) l = loss(z) nevergrad_optimizer.tell(x, l) if np.log2(i+1) == int(np.log2(i+1)): print(f"iteration {i} --> {l}") print("var/variable = ", sum(z**2)/len(z)) #z = (1.-epsilon) * z + epsilon * x / np.sqrt(np.sum(x ** 2)) - if l < 0.0000001 and os.environ.get("earlystop", "True") in ["true", "True"]: + if l < 0.0000001 and os.environ.get("earlystop", "False") in ["true", "True"]: print(f"we find proba(bad)={l}") break x = nevergrad_optimizer.recommend().value - z = x * np.sqrt(len(x)) / np.linalg.norm(x) + z = z0 + float(os.environ.get("epsilon", "0.001")) * x + z = np.sqrt(len(z)) * z / np.linalg.norm(z) latents = torch.from_numpy(z.reshape(latents_intermediate_shape)).float() #.half() else: if latents.shape != latents_intermediate_shape: From 5702317418366de74c07df05d62df7e611551bd2 Mon Sep 17 00:00:00 2001 From: Olivier Teytaud Date: Fri, 23 Sep 2022 13:19:39 +0200 Subject: [PATCH 18/76] fix --- multi_minisd.sh | 14 ++++++++++---- view_history.sh | 5 ++++- 2 files changed, 14 insertions(+), 5 deletions(-) diff --git a/multi_minisd.sh b/multi_minisd.sh index 37addf3b1..6e9675e1a 100755 --- a/multi_minisd.sh +++ b/multi_minisd.sh @@ -3,17 +3,22 @@ set -e #export prompt="A woman with many eyes." -export prompt="A close up photographic portrait of a young woman with uniformly colored hair." +#export prompt="A close up photographic portrait of a young woman with uniformly colored hair." +#export prompt="Yann Lecun with a bloody armor and a sword, fighting tentacles in the jungle." +#export prompt="Many tentacles attacking Yann Lecun, equipped with a bloody armor and a sword." +#export prompt="An_armored_Mark_Zuckerberg_fighting_off_bloody_tentacles_in_the_jungle." +#export prompt="A cute monster in a city." +export prompt="A scary woman with tatoos, many arms, many weapons, flashy hair, sitting on a throne." # We generate groups of lambda images. -lambda=14 +lambda=20 # 4 parents are selected. export mu=4 # Do we use early stopping in Nevergrad ? -export earlystop=True +export earlystop=False # Which Scikit-Learn surrogate model ? export skl=tree # Which initial range for modifications ? 
-export epsilon=0.0001 +export epsilon=0.01 touch empty_file rm empty_file @@ -52,6 +57,7 @@ do do echo "generating image $kk / $lambda" python minisd.py + ./view_history.sh done list_of_four_images="`ls -ctr SD*_image_*.png | tail -n $lambda`" # my_new_list="" diff --git a/view_history.sh b/view_history.sh index ac81b97b6..bc2e7d8ac 100755 --- a/view_history.sh +++ b/view_history.sh @@ -1,2 +1,5 @@ -montage `ls -ctr SD*imag*.png | tail -n 18` -mode concatenate -tile 4x history.png +#montage `ls -ctr SD*imag*.png | head -n 15 | tail -n 14` -mode concatenate -tile 7x zuck1.png +#montage `ls -ctr SD*imag*.png | head -n 29 | tail -n 14` -mode concatenate -tile 7x zuck2.png +montage `ls -ctr SD*imag*.png | tail -n 28` -mode concatenate -tile 7x history.png open history.png +#cp history.png zuck3.png From 8480363b8b9024c93bfbe276ce1555ad59c4f57e Mon Sep 17 00:00:00 2001 From: Olivier Teytaud Date: Fri, 23 Sep 2022 15:19:53 +0200 Subject: [PATCH 19/76] voronoi_start --- README.md | 34 ++++----- archimulti_minisd.sh | 16 +--- edit.sh | 1 - inoculate_evo_sd.sh | 131 +------------------------------- maketgz.sh | 3 +- minisd.py | 126 +++++++++++++++++++++++++++++-- minisd.sh | 11 +-- multi_minisd.sh | 140 +---------------------------------- multiminisd.sh | 42 +---------- pipeline_stable_diffusion.py | 59 ++++++++------- 10 files changed, 172 insertions(+), 391 deletions(-) diff --git a/README.md b/README.md index ba8bd860c..6f0dc5804 100644 --- a/README.md +++ b/README.md @@ -1,24 +1,22 @@ # Modified evolutionary version. -Install as usual (below). -Then use the "pipeline_stable_diffusion.py" in lieu of the original pipeline_stable_diffusion.py found at -<< python -c "import diffusers ; print(diffusers.__file__)" >>. -Then edit the prompt in multi_minisd.sh. -Then - ./multi_minisd.sh - -and follow requests. -The code is not very user-friendly, if there are users I'll do better. - -Scripts: -- archimulti_minisd.sh runs a selection of prompts, for testing purpose. -- edit.sh will not work on your install: this is a convenience script for editing the code inside the conda environment, you have to update the path. -- inoculate_evo_sd.sh same as multi_minisd.sh, but not from scratch. You have to check the code for understanding, or ping me. -- minisd.sh runs stable diffusion, -- multi_minisd.sh main script. Run stable diffusion multiple times, and asks for your feedback. -- multiminisd.sh: run plenty of tests with various genetic methods. Used for tuning. Takes forever to run. -- view_history.sh: view what is in progress and put the last generated images in a single output.png +1. Install StableDiffusion as usual, plus a few more stuff. Basically: +conda env create -f environment.yaml +conda activate ldm # you can change that name in the environment.yaml file... +conda install pytorch torchvision -c pytorch +pip install transformers diffusers invisible-watermark +pip install pygame +pip install -e . + + +2. Then use the file "pipeline_stable_diffusion.py" in lieu of the original pipeline_stable_diffusion.py found at +<< python -c "import diffusers ; print(diffusers.__file__)" >>. This is done as follows: +cp pipeline_stable_diffusion.py <>/pipeline_stable_diffusion.py + +3. Then edit the prompt in minisd.py, and possibly other variables. + +Then run << python minisd.py >>. 
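A minimal Python sketch can print the exact target path for the copy step above (assuming the standard diffusers package layout; the path varies with each install, so treat this as a convenience check rather than part of the patch):

```
# Sketch: print where the installed pipeline_stable_diffusion.py lives, so the
# modified version from this repository can be copied over it (back up the
# original first). Assumes the standard diffusers layout.
import os
import diffusers

target = os.path.join(
    os.path.dirname(diffusers.__file__),
    "pipelines", "stable_diffusion", "pipeline_stable_diffusion.py",
)
print(target)
```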
# Stable Diffusion diff --git a/archimulti_minisd.sh b/archimulti_minisd.sh index cd6a4a9eb..e730b58d7 100755 --- a/archimulti_minisd.sh +++ b/archimulti_minisd.sh @@ -1,15 +1 @@ -#!/bin/bash - -for u in `seq 25` -do -export prompt="A photographic portrait of a young woman, tilted head, from below, red hair, green eyes, smiling." -python minisd.py -export prompt="A photo of a cute armoured bloody red panda fighting off tentacles with daggers." -python minisd.py -export prompt="A photo of a woman fighting off tentacles with guns." -python minisd.py -export prompt="A cute armoured red panda fighting off zombies with karate." -python minisd.py -export prompt="An armored Mark Zuckerberg fighting off bloody tentacles in the jungle." -python minisd.py -done +echo deprecated. diff --git a/edit.sh b/edit.sh index b8c49801b..325be12b4 100755 --- a/edit.sh +++ b/edit.sh @@ -1,3 +1,2 @@ -#vim /private/home/oteytaud/.conda/envs/sd/lib/python3.8/site-packages/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py vim /opt/miniconda3/envs/sd_gpu/lib/python3.9/site-packages/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py cp /opt/miniconda3/envs/sd_gpu/lib/python3.9/site-packages/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py . diff --git a/inoculate_evo_sd.sh b/inoculate_evo_sd.sh index 256995f68..26919fa5e 100755 --- a/inoculate_evo_sd.sh +++ b/inoculate_evo_sd.sh @@ -1,131 +1,2 @@ #!/bin/bash - -echo Parametrization and initialization. -#export prompt="A close up photographic portrait of a young woman with uniformly colored hair." -export prompt="An armored Mark Zuckerberg fighting off bloody tentacles in the jungle." -lambda=18 -cp basic_inoculation_uniformly/SD*.* inoculations/ - - -set -e - -touch empty_file -rm empty_file -touch empty_file -touch SD_prout_${random}.png -touch SD_prout_${random}.txt -mv SD_*.png poubelle/ -mv SD_*.txt poubelle/ -# Initialization: run SD and create an image, with rank 1. -touch goodbad.py -rm goodbad.py -touch goodbad.py -echo "good = []" >> goodbad.py -echo "bad = []" >> goodbad.py -#python minisd.py -./minisd.sh - #sentinel=${RANDOM} - #touch SD_image_${sentinel}.png - #touch SD_latent_${sentinel}.txt -mylist="`ls -ctr SD*_image_*.png | tail -n 1`" -myranks=1 - -for i in `seq 30` -do - # Now an iteration. 
- echo Current images = $mylist - echo Current ranks = $myranks - #sentinel=${RANDOM} - #touch SD_image_${sentinel}.png - #touch SD_latent_${sentinel}.txt - echo "GENERATING $lambda IMAGES ================================" - cat goodbad.py | awk '!x[$0]++' > goodbad2.py - mv goodbad2.py goodbad.py - echo "`grep -c 'good +=' goodbad.py` positive examples" - echo "`grep -c 'bad +=' goodbad.py` negative examples" - for kk in `seq $lambda` - do - echo "generating image $kk / $lambda" - #python minisd.py - ./minisd.sh - done - list_of_four_images="`ls -ctr SD*_image_*.png | tail -n $lambda`" -# my_new_list="" -# my_new_ranks="" -# # We stop at 19 so that it becomes 20 with the new one -# for k in `seq 19` -# do -# my_new_list="$my_new_list `echo $mylist | cut -d ' ' -f $k`" -# my_new_ranks="$my_new_ranks `echo $myranks | cut -d ' ' -f $k`" -# done -# mylist=`echo $my_new_list | sed 's/[ ]*$//g'` -# myranks=`echo $my_new_ranks | sed 's/[ ]*$//g'` -# echo "After limiting to 19, we get $mylist and $myranks " - for img in $list_of_four_images - do - echo We add image $img ======================= - montage $mylist $img -mode Concatenate -tile 5x output.png - open --wait output.png - # read -t 1 prout - read -p "Rank of the last image ?" rank - echo "Provided rank: $rank" - mylist="$mylist $img" - if [[ $rank -le 0 ]] - then - read -p "Enter all ranks !!!!" myranks - else - mynewranks="" - for r in $myranks - do - [[ $r -ge $rank ]] && r=$(( $r + 1 )) - mynewranks="$mynewranks $r" - done - myranks="$mynewranks $rank" - fi - #echo Before sorting =========================== - #echo $myranks - #echo $mylist - #sleep 5 - - # Now sorting - mynewlist="" - mynewranks="" - sed -i.backup 's/good +=.*//g' goodbad.py - num_good=`cat goodbad.py | grep 'good +=' | wc -l ` - echo "Num goods in file: $num_good" - num_good=$(( $num_good / 2 + 5 )) - echo "Num goods after update: $num_good" - echo "We keep the $num_good best." - for r in `seq 20` - do - for k in `seq 20` - do - [[ `echo $myranks | cut -d ' ' -f $k` -eq $r ]] && echo "FOUND $k for $r!" 
- my_image="`echo $mylist | cut -d ' ' -f $k`" - [[ `echo $myranks | cut -d ' ' -f $k` -eq $r ]] && mynewranks="$mynewranks `echo $myranks | cut -d ' ' -f $k`" - [[ `echo $myranks | cut -d ' ' -f $k` -eq $r ]] && mynewlist="$mynewlist `echo $mylist | cut -d ' ' -f $k`" - if [[ `echo $myranks | cut -d ' ' -f $k` -eq $r ]] - then - echo Found $my_image at rank $k for $r - if [[ $r -le $num_good ]] - then - cat empty_file `echo $my_image | sed 's/image_[0-9]*.png/PROUTPROUT&/g' | sed 's/PROUTPROUTimage/latent/g' | sed 's/\.png/\.txt/g'` | sed "s/.*/good += [&]/g" >> goodbad.py - else - cat empty_file `echo $my_image | sed 's/image_[0-9]*.png/PROUTPROUT&/g' | sed 's/PROUTPROUTimage/latent/g' | sed 's/\.png/\.txt/g'` | sed "s/.*/bad += [&]/g" >> goodbad.py - fi - echo "" >> goodbad.py - break - fi - done - done - myranks=$mynewranks - mylist=$mynewlist - done - echo After sorting =========================== - echo $myranks - echo $mylist - #sleep 2 - -done - - +echo deprecated diff --git a/maketgz.sh b/maketgz.sh index 15e54d619..d0b91883a 100755 --- a/maketgz.sh +++ b/maketgz.sh @@ -1,2 +1 @@ -tar -zcvf ~/bigpack2.tgz `ls -ctrl | grep 'Sep.16' | grep '\.png' | sed 's/.* //g'` | wc -l -ls -ctlr ~/bigpack2.tgz +echo deprecated diff --git a/minisd.py b/minisd.py index 8f0675ca7..593d2c09b 100644 --- a/minisd.py +++ b/minisd.py @@ -1,5 +1,7 @@ import random +import os import torch +import numpy as np from torch import autocast from diffusers import StableDiffusionPipeline @@ -7,6 +9,13 @@ #device = "cuda" device = "mps" #torch.device("mps") +os.environ["skl"] = "nn" +os.environ["epsilon"] = "0.005" +os.environ["decay"] = "0." +os.environ["ngoptim"] = "DiscreteLenglerOnePlusOne" +os.environ["forcedlatent"] = "" +os.environ["good"] = "[]" +os.environ["bad"] = "[]" pipe = StableDiffusionPipeline.from_pretrained(model_id, use_auth_token="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") pipe = pipe.to(device) @@ -44,13 +53,114 @@ import os -prompt = os.environ.get("prompt", prompt) -with autocast("cuda"): - image = pipe(prompt, guidance_scale=7.5)["sample"][0] +import pygame +from os import listdir +from os.path import isfile, join -sentinel = random.randint(0,100000) -image.save(f"SD_{prompt.replace(' ','_')}_image_{sentinel}.png") -latent = eval((os.environ["latent_sd"])) -with open(f"SD_{prompt.replace(' ','_')}_latent_{sentinel}.txt", 'w') as f: - f.write(f"{latent}") +sentinel = str(random.randint(0,100000)) + "XX" + str(random.randint(0,100000)) + +all_files = [] + +llambda = 2 + +assert llambda < 16, "lambda < 16 for convenience in pygame." +bad = [] +for iteration in range(30): + onlyfiles = [] + latent = [] + for k in range(llambda): + os.environ["earlystop"] = "False" if k > 0 else "True" + os.environ["epsilon"] = str(0. 
if k == 0 else 0.1 / k) + os.environ["budget"] = str(300 if k > 0 else 3) + with autocast("cuda"): + image = pipe(prompt, guidance_scale=7.5)["sample"][0] + filename = f"SD_{prompt.replace(' ','_')}_image_{sentinel}_{k}.png" + image.save(filename) + onlyfiles += [filename] + str_latent = eval((os.environ["latent_sd"])) + array_latent = eval(f"np.array(str_latent).reshape(4, 64, 64)") + print(f"array_latent sumsq/var {sum(array_latent.flatten() ** 2) / len(array_latent.flatten())}") + latent += [array_latent] + with open(f"SD_{prompt.replace(' ','_')}_latent_{sentinel}_{k}.txt", 'w') as f: + f.write(f"{latent}") + + # importing required library + + #mypath = "./" + #onlyfiles = [f for f in listdir(mypath) if isfile(join(mypath, f))] + #onlyfiles = [str(f) for f in onlyfiles if "SD_" in str(f) and ".png" in str(f) and str(f) not in all_files and sentinel in str(f)] + #print() + # activate the pygame library . + pygame.init() + X = 1500 + Y = 900 + + # create the display surface object + # of specific dimension..e(X, Y). + scrn = pygame.display.set_mode((X, Y)) + + for idx in range(min(15, len(onlyfiles))): + # set the pygame window name + pygame.display.set_caption('images') + + # create a surface object, image is drawn on it. + imp = pygame.transform.scale(pygame.image.load(onlyfiles[idx]).convert(), (300, 300)) + + # Using blit to copy content from one surface to other + scrn.blit(imp, (300 * (idx // 3), 300 * (idx % 3))) + + # paint screen one time + pygame.display.flip() + status = True + indices = [] + good = [] + while (status): + + # iterate over the list of Event objects + # that was returned by pygame.event.get() method. + for i in pygame.event.get(): + if i.type == pygame.MOUSEBUTTONUP: + pos = pygame.mouse.get_pos() + print(pos) + index = 3 * (pos[0] // 300) + (pos[1] // 300) + indices += [[index, (pos[0] - (pos[0] // 300) * 300) / 300, (pos[1] - (pos[1] // 300) * 300) / 300]] + good += [list(latent[index].flatten())] + + # if event object type is QUIT + # then quitting the pygame + # and program both. + if i.type == pygame.QUIT: + status = False + + # deactivates the pygame library + pygame.quit() + print(indices) + os.environ["mu"] = str(len(indices)) + forcedlatent = np.zeros((4, 64, 64)) + bad += [list(latent[u].flatten()) for u in range(len(onlyfiles)) if u not in [i[0] for i in indices]] + os.environ["good"] = str(good) + os.environ["bad"] = str(bad) + for i in range(64): + x = i / 63. + for j in range(64): + y = j / 63 + mindistances = 10000000000. 
+ for u in range(len(indices)): + distance = np.linalg.norm( np.array((x, y)) - np.array((indices[u][1], indices[u][2])) ) + if distance < mindistances: + mindistances = distance + uu = indices[u][0] + for k in range(4): + assert k < len(forcedlatent), k + assert i < len(forcedlatent[k]), i + assert j < len(forcedlatent[k][i]), j + assert uu < len(latent) + assert k < len(latent[uu]), k + assert i < len(latent[uu][k]), i + assert j < len(latent[uu][k][i]), j + forcedlatent[k][i][j] = latent[uu][k][i][j] + os.environ["forcedlatent"] = str(list(forcedlatent.flatten())) + #for uu in range(len(latent)): + # print(f"--> latent[{uu}] sum of sq / variable = {np.sum(latent[uu].flatten()**2) / len(latent[uu].flatten())}") + diff --git a/minisd.sh b/minisd.sh index f5f3da39d..26919fa5e 100755 --- a/minisd.sh +++ b/minisd.sh @@ -1,11 +1,2 @@ #!/bin/bash -if compgen -G inoculations/SD*.png > /dev/null ; then -file=`ls -ctr inoculations/SD*.png | tail -n 1` -echo Transformation: -latent=`echo $file | sed 's/image_[0-9]*.png/PROUTPROUT&/g' | sed 's/PROUTPROUTimage/latent/g' | sed 's/\.png/\.txt/g'` -echo Image=$file -echo Latent=$latent -mv $file $latent . -else -python minisd.py -fi +echo deprecated diff --git a/multi_minisd.sh b/multi_minisd.sh index 6e9675e1a..8699b6cc4 100755 --- a/multi_minisd.sh +++ b/multi_minisd.sh @@ -1,141 +1,3 @@ #!/bin/bash -set -e - -#export prompt="A woman with many eyes." -#export prompt="A close up photographic portrait of a young woman with uniformly colored hair." -#export prompt="Yann Lecun with a bloody armor and a sword, fighting tentacles in the jungle." -#export prompt="Many tentacles attacking Yann Lecun, equipped with a bloody armor and a sword." -#export prompt="An_armored_Mark_Zuckerberg_fighting_off_bloody_tentacles_in_the_jungle." -#export prompt="A cute monster in a city." -export prompt="A scary woman with tatoos, many arms, many weapons, flashy hair, sitting on a throne." -# We generate groups of lambda images. -lambda=20 -# 4 parents are selected. -export mu=4 -# Do we use early stopping in Nevergrad ? -export earlystop=False -# Which Scikit-Learn surrogate model ? -export skl=tree -# Which initial range for modifications ? -export epsilon=0.01 - -touch empty_file -rm empty_file -touch empty_file -touch SD_prout_${random}.png -touch SD_prout_${random}.txt -mv SD_*.png poubelle/ -mv SD_*.txt poubelle/ -# Initialization: run SD and create an image, with rank 1. -touch goodbad.py -rm goodbad.py -touch goodbad.py -echo "good = []" >> goodbad.py -echo "bad = []" >> goodbad.py -python minisd.py - #sentinel=${RANDOM} - #touch SD_image_${sentinel}.png - #touch SD_latent_${sentinel}.txt -mylist="`ls -ctr SD*_image_*.png | tail -n 1`" -myranks=1 - -for i in `seq 30` -do - # Now an iteration. 
- echo Current images = $mylist - echo Current ranks = $myranks - #sentinel=${RANDOM} - #touch SD_image_${sentinel}.png - #touch SD_latent_${sentinel}.txt - echo "GENERATING $lambda IMAGES ================================" - cat goodbad.py | awk '!x[$0]++' > goodbad2.py - mv goodbad2.py goodbad.py - echo "`grep -c 'good +=' goodbad.py` positive examples" - echo "`grep -c 'bad +=' goodbad.py` negative examples" - for kk in `seq $lambda` - do - echo "generating image $kk / $lambda" - python minisd.py - ./view_history.sh - done - list_of_four_images="`ls -ctr SD*_image_*.png | tail -n $lambda`" -# my_new_list="" -# my_new_ranks="" -# # We stop at 19 so that it becomes 20 with the new one -# for k in `seq 19` -# do -# my_new_list="$my_new_list `echo $mylist | cut -d ' ' -f $k`" -# my_new_ranks="$my_new_ranks `echo $myranks | cut -d ' ' -f $k`" -# done -# mylist=`echo $my_new_list | sed 's/[ ]*$//g'` -# myranks=`echo $my_new_ranks | sed 's/[ ]*$//g'` -# echo "After limiting to 19, we get $mylist and $myranks " - for img in $list_of_four_images - do - echo We add image $img ======================= - montage $mylist $img -mode Concatenate -tile 5x output.png - open --wait output.png - # read -t 1 prout - read -p "Rank of the last image ?" rank - echo "Provided rank: $rank" - mylist="$mylist $img" - if [[ $rank -le 0 ]] - then - read -p "Enter all ranks !!!!" myranks - else - mynewranks="" - for r in $myranks - do - [[ $r -ge $rank ]] && r=$(( $r + 1 )) - mynewranks="$mynewranks $r" - done - myranks="$mynewranks $rank" - fi - #echo Before sorting =========================== - #echo $myranks - #echo $mylist - #sleep 5 - - # Now sorting - mynewlist="" - mynewranks="" - sed -i.backup 's/good +=.*//g' goodbad.py - num_good=`cat goodbad.py | grep 'good +=' | wc -l ` - echo "Num goods in file: $num_good" - num_good=$(( $num_good / 2 + 5 )) - echo "Num goods after update: $num_good" - echo "We keep the $num_good best." - for r in `seq 20` - do - for k in `seq 20` - do - [[ `echo $myranks | cut -d ' ' -f $k` -eq $r ]] && echo "FOUND $k for $r!" 
- my_image="`echo $mylist | cut -d ' ' -f $k`" - [[ `echo $myranks | cut -d ' ' -f $k` -eq $r ]] && mynewranks="$mynewranks `echo $myranks | cut -d ' ' -f $k`" - [[ `echo $myranks | cut -d ' ' -f $k` -eq $r ]] && mynewlist="$mynewlist `echo $mylist | cut -d ' ' -f $k`" - if [[ `echo $myranks | cut -d ' ' -f $k` -eq $r ]] - then - echo Found $my_image at rank $k for $r - if [[ $r -le $num_good ]] - then - cat empty_file `echo $my_image | sed 's/image_[0-9]*.png/PROUTPROUT&/g' | sed 's/PROUTPROUTimage/latent/g' | sed 's/\.png/\.txt/g'` | sed "s/.*/good += [&]/g" >> goodbad.py - else - cat empty_file `echo $my_image | sed 's/image_[0-9]*.png/PROUTPROUT&/g' | sed 's/PROUTPROUTimage/latent/g' | sed 's/\.png/\.txt/g'` | sed "s/.*/bad += [&]/g" >> goodbad.py - fi - echo "" >> goodbad.py - break - fi - done - done - myranks=$mynewranks - mylist=$mynewlist - done - echo After sorting =========================== - echo $myranks - echo $mylist - #sleep 2 - -done - - +echo deprecated diff --git a/multiminisd.sh b/multiminisd.sh index 726a8fff9..8699b6cc4 100755 --- a/multiminisd.sh +++ b/multiminisd.sh @@ -1,43 +1,3 @@ #!/bin/bash -touch SD.prout.${RANDOM} -mv SD*.* poubelle/ -numimages=12 -for m in 5 #2 5 3 4 1 -do -export mu=$m -for d in 1 0.5 0 -do -export decay=$d -for ngo in OnePlusOne DiscreteOnePlusOne RandomSearch DiscreteLenglerOnePlusOne -do -export ngoptim=$ngo -for sl in tree nn logit -do -export skl=$sl -for es in False True -do -export earlystop=$es -for eps in 0.0001 -do -export epsilon=$eps - -export prompt="A close up photographic portrait of a young woman with uniformly colored hair." -directory=biased_${epsilon}_rw_experiment${numimages}_images_${mu}_${ngoptim}_${earlystop}_${skl}_${decay} -mkdir $directory -for u in `seq $numimages` -do -cp goodbad_learnbluehair.py goodbad.py -python minisd.py -./view_history.sh -sleep 1 -done -cp history.png SD* *.py *.sh $directory - -mv SD*.* poubelle/ -done -done -done -done -done -done +echo deprecated diff --git a/pipeline_stable_diffusion.py b/pipeline_stable_diffusion.py index 4ba3fdd56..225216ac7 100644 --- a/pipeline_stable_diffusion.py +++ b/pipeline_stable_diffusion.py @@ -215,9 +215,9 @@ def __call__( generator=generator, device=latents_device, ) - - from goodbad import good - from goodbad import bad + good = eval(os.environ["good"]) + bad = eval(os.environ["bad"]) + print(f"{len(good)} good and {len(bad)} bad") i_believe_in_evolution = len(good) > 0 and len(bad) > 0 #i_believe_in_evolution = False print(f"I believe in evolution = {i_believe_in_evolution}") @@ -252,36 +252,40 @@ def loss(x): if i_believe_in_evolution: import nevergrad as ng - budget = 300 + budget = int(os.environ.get("budget", "300")) #nevergrad_optimizer = ng.optimizers.RandomSearch(len(z), budget) #nevergrad_optimizer = ng.optimizers.RandomSearch(len(z), budget) optim_class = ng.optimizers.registry[os.environ.get("ngoptim", "DiscreteLenglerOnePlusOne")] #nevergrad_optimizer = ng.optimizers.DiscreteLenglerOnePlusOne(len(z), budget) nevergrad_optimizer = optim_class(len(z), budget) #nevergrad_optimizer = ng.optimizers.DiscreteOnePlusOne(len(z), budget) - for k in range(5): - z1 = np.array(random.choice(good)) - z2 = np.array(random.choice(good)) - z3 = np.array(random.choice(good)) - z4 = np.array(random.choice(good)) - z5 = np.array(random.choice(good)) - #z = 0.99 * z1 + 0.01 * (z2+z3+z4+z5)/4. 
- z = 0.2 * (z1 + z2 + z3 + z4 + z5) - mu = int(os.environ.get("mu", "5")) - parents = [z1, z2, z3, z4, z5] - weights = [np.exp(np.random.randn() - i * float(os.environ.get("decay", "1."))) for i in range(5)] - z = weights[0] * z1 - for u in range(mu): - if u > 0: - z += weights[u] * parents[u] - z = (1. / sum(weights[:mu])) * z - z = np.sqrt(len(z)) * z / np.linalg.norm(z) - - #for u in range(len(z)): - # z[u] = random.choice([z1[u],z2[u],z3[u],z4[u],z5[u]]) - nevergrad_optimizer.suggest(z) - - z0 = z +# for k in range(5): +# z1 = np.array(random.choice(good)) +# z2 = np.array(random.choice(good)) +# z3 = np.array(random.choice(good)) +# z4 = np.array(random.choice(good)) +# z5 = np.array(random.choice(good)) +# #z = 0.99 * z1 + 0.01 * (z2+z3+z4+z5)/4. +# z = 0.2 * (z1 + z2 + z3 + z4 + z5) +# mu = int(os.environ.get("mu", "5")) +# parents = [z1, z2, z3, z4, z5] +# weights = [np.exp(np.random.randn() - i * float(os.environ.get("decay", "1."))) for i in range(5)] +# z = weights[0] * z1 +# for u in range(mu): +# if u > 0: +# z += weights[u] * parents[u] +# z = (1. / sum(weights[:mu])) * z +# z = np.sqrt(len(z)) * z / np.linalg.norm(z) +# +# #for u in range(len(z)): +# # z[u] = random.choice([z1[u],z2[u],z3[u],z4[u],z5[u]]) +# nevergrad_optimizer.suggest + if len(os.environ["forcedlatent"]) > 0: + print("we get a forcing for the latent z.") + z0 = eval(os.environ["forcedlatent"]) + #nevergrad_optimizer.suggest(eval(os.environ["forcedlatent"])) + else: + z0 = z for i in range(budget): x = nevergrad_optimizer.ask() z = z0 + float(os.environ.get("epsilon", "0.001")) * x.value @@ -302,6 +306,7 @@ def loss(x): else: if latents.shape != latents_intermediate_shape: raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {latents_intermediate_shape}") + print(f"latent ==> {sum(latents.flatten()**2) / len(latents.flatten())}") os.environ["latent_sd"] = str(list(latents.flatten().cpu().numpy())) for i in [2, 3]: latents = torch.repeat_interleave(latents, repeats=latents_shape[i] // latents_intermediate_shape[i], dim=i) #/ np.sqrt(np.sqrt(latents_shape[i] // latents_intermediate_shape[i])) From d2ecd20eb922f90557d8df040a8bb1718254cd49 Mon Sep 17 00:00:00 2001 From: Olivier Teytaud Date: Fri, 23 Sep 2022 21:24:53 +0200 Subject: [PATCH 20/76] fix --- README.md | 1 + minisd.py | 89 ++++++++++++++++++++++++++++++++----------------- view_history.sh | 4 ++- 3 files changed, 62 insertions(+), 32 deletions(-) diff --git a/README.md b/README.md index 6f0dc5804..4297a663c 100644 --- a/README.md +++ b/README.md @@ -7,6 +7,7 @@ conda activate ldm # you can change that name in the environment.yaml file... conda install pytorch torchvision -c pytorch pip install transformers diffusers invisible-watermark pip install pygame +pip install nevergrad pip install -e . diff --git a/minisd.py b/minisd.py index 593d2c09b..d2e63bbbb 100644 --- a/minisd.py +++ b/minisd.py @@ -35,6 +35,7 @@ "Photo of Schwarzy as a ballet dancer", ]) + name = random.choice(["Mark Zuckerbeg", "Zendaya", "Yann LeCun", "Scarlett Johansson", "Superman", "Meg Myers"]) name = "Zendaya" prompt = f"Photo of {name} as a sumo-tori." @@ -51,6 +52,13 @@ prompt = "A giant cute animal worshipped by zombies." +prompt = "Several faces." + +prompt = "An armoured Yann LeCun fighting tentacles in the jungle." +prompt = "Tentacles everywhere." +prompt = "A photo of a smiling Medusa." +prompt = "Medusa." +prompt = "Meg Myers in bloody armor fending off tentacles with a sword." 
import os import pygame @@ -61,21 +69,31 @@ all_files = [] -llambda = 2 +llambda = 15 assert llambda < 16, "lambda < 16 for convenience in pygame." bad = [] +five_best = [] +latent = [] +images = [] for iteration in range(30): onlyfiles = [] - latent = [] + latent = [latent[f] for f in five_best] + images = [images[f] for f in five_best] for k in range(llambda): - os.environ["earlystop"] = "False" if k > 0 else "True" - os.environ["epsilon"] = str(0. if k == 0 else 0.1 / k) - os.environ["budget"] = str(300 if k > 0 else 3) + if k < len(five_best): + continue + os.environ["earlystop"] = "False" if k > len(five_best) else "True" + os.environ["epsilon"] = str(0. if k == len(five_best) else (k - len(five_best)) / llambda) + os.environ["budget"] = str(300 if k > len(five_best) else 2) + os.environ["skl"] = {0: "nn", 1: "tree", 2: "logit"}[k % 3] + if iteration > 0: + os.environ["forcedlatent"] = str(list(forcedlatents[k].flatten())) with autocast("cuda"): image = pipe(prompt, guidance_scale=7.5)["sample"][0] - filename = f"SD_{prompt.replace(' ','_')}_image_{sentinel}_{k}.png" + images += [image] + filename = f"SD_{prompt.replace(' ','_')}_image_{sentinel}_{iteration}_{k}.png" image.save(filename) onlyfiles += [filename] str_latent = eval((os.environ["latent_sd"])) @@ -100,12 +118,13 @@ # of specific dimension..e(X, Y). scrn = pygame.display.set_mode((X, Y)) - for idx in range(min(15, len(onlyfiles))): + for idx in range(llambda): # set the pygame window name pygame.display.set_caption('images') # create a surface object, image is drawn on it. - imp = pygame.transform.scale(pygame.image.load(onlyfiles[idx]).convert(), (300, 300)) + imp = pygame.transform.scale(images[idx].convert(), (300, 300)) + #imp = pygame.transform.scale(pygame.image.load(onlyfiles[idx]).convert(), (300, 300)) # Using blit to copy content from one surface to other scrn.blit(imp, (300 * (idx // 3), 300 * (idx % 3))) @@ -115,6 +134,7 @@ status = True indices = [] good = [] + five_best = [] while (status): # iterate over the list of Event objects @@ -124,6 +144,8 @@ pos = pygame.mouse.get_pos() print(pos) index = 3 * (pos[0] // 300) + (pos[1] // 300) + if index not in five_best and len(five_best) < 5: + five_best += [index] indices += [[index, (pos[0] - (pos[0] // 300) * 300) / 300, (pos[1] - (pos[1] // 300) * 300) / 300]] good += [list(latent[index].flatten())] @@ -137,30 +159,35 @@ pygame.quit() print(indices) os.environ["mu"] = str(len(indices)) - forcedlatent = np.zeros((4, 64, 64)) + forcedlatents = [] bad += [list(latent[u].flatten()) for u in range(len(onlyfiles)) if u not in [i[0] for i in indices]] - os.environ["good"] = str(good) - os.environ["bad"] = str(bad) - for i in range(64): - x = i / 63. - for j in range(64): - y = j / 63 - mindistances = 10000000000. 
- for u in range(len(indices)): - distance = np.linalg.norm( np.array((x, y)) - np.array((indices[u][1], indices[u][2])) ) - if distance < mindistances: - mindistances = distance - uu = indices[u][0] - for k in range(4): - assert k < len(forcedlatent), k - assert i < len(forcedlatent[k]), i - assert j < len(forcedlatent[k][i]), j - assert uu < len(latent) - assert k < len(latent[uu]), k - assert i < len(latent[uu][k]), i - assert j < len(latent[uu][k][i]), j - forcedlatent[k][i][j] = latent[uu][k][i][j] - os.environ["forcedlatent"] = str(list(forcedlatent.flatten())) + for a in range(llambda): + forcedlatent = np.zeros((4, 64, 64)) + os.environ["good"] = str(good) + os.environ["bad"] = str(bad) + coefficients = np.zeros(len(indices)) + for i in range(len(indices)): + coefficients[i] = np.exp(np.random.randn()) + for i in range(64): + x = i / 63. + for j in range(64): + y = j / 63 + mindistances = 10000000000. + for u in range(len(indices)): + distance = coefficients[u] * np.linalg.norm( np.array((x, y)) - np.array((indices[u][1], indices[u][2])) ) + if distance < mindistances: + mindistances = distance + uu = indices[u][0] + for k in range(4): + assert k < len(forcedlatent), k + assert i < len(forcedlatent[k]), i + assert j < len(forcedlatent[k][i]), j + assert uu < len(latent) + assert k < len(latent[uu]), k + assert i < len(latent[uu][k]), i + assert j < len(latent[uu][k][i]), j + forcedlatent[k][i][j] = latent[uu][k][i][j] + forcedlatents += [forcedlatent] #for uu in range(len(latent)): # print(f"--> latent[{uu}] sum of sq / variable = {np.sum(latent[uu].flatten()**2) / len(latent[uu].flatten())}") diff --git a/view_history.sh b/view_history.sh index bc2e7d8ac..82cf3f4e5 100755 --- a/view_history.sh +++ b/view_history.sh @@ -1,5 +1,7 @@ +#!/bin/bash #montage `ls -ctr SD*imag*.png | head -n 15 | tail -n 14` -mode concatenate -tile 7x zuck1.png #montage `ls -ctr SD*imag*.png | head -n 29 | tail -n 14` -mode concatenate -tile 7x zuck2.png -montage `ls -ctr SD*imag*.png | tail -n 28` -mode concatenate -tile 7x history.png +#montage `ls -ctr SD*imag*.png | tail -n 28` -mode concatenate -tile 7x history.png +montage $( ls *`ls -ctr SD*.png | sed 's/.*image_//g' | tail -n 1 | sed 's/_.*//g'`*.png | sort | tail -n 15 | sort ) -mode concatenate -tile 5x history.png open history.png #cp history.png zuck3.png From 674aa8e428347b7e96a501a57782f208477b8aa9 Mon Sep 17 00:00:00 2001 From: Olivier Teytaud Date: Mon, 26 Sep 2022 10:45:09 +0200 Subject: [PATCH 21/76] fix --- minisd.py | 111 ++++++++++++++++++++++++++++++++++++++++++------ view_history.sh | 3 +- 2 files changed, 100 insertions(+), 14 deletions(-) diff --git a/minisd.py b/minisd.py index d2e63bbbb..15c302b2d 100644 --- a/minisd.py +++ b/minisd.py @@ -4,16 +4,23 @@ import numpy as np from torch import autocast from diffusers import StableDiffusionPipeline +import webbrowser + model_id = "CompVis/stable-diffusion-v1-4" #device = "cuda" device = "mps" #torch.device("mps") +white = (255, 255, 255) +green = (0, 255, 0) +blue = (0, 0, 128) + os.environ["skl"] = "nn" os.environ["epsilon"] = "0.005" os.environ["decay"] = "0." os.environ["ngoptim"] = "DiscreteLenglerOnePlusOne" os.environ["forcedlatent"] = "" +os.environ["enforcedlatent"] = "" os.environ["good"] = "[]" os.environ["bad"] = "[]" @@ -59,6 +66,14 @@ prompt = "A photo of a smiling Medusa." prompt = "Medusa." prompt = "Meg Myers in bloody armor fending off tentacles with a sword." +prompt = "A red-haired woman with red hair. Her head is tilted." 
+prompt = "A bloody heavy-metal zombie with a chainsaw." +prompt = "Tentacles attacking a bloody Meg Myers in Eyptian dress. Meg Myers has a chainsaw." +prompt = "Bizarre art." +print(f"The prompt is {prompt}") +user_prompt = input("Enter a new prompt if you prefer\n") +if len(user_prompt) > 2: + prompt = user_prompt import os import pygame @@ -77,19 +92,35 @@ five_best = [] latent = [] images = [] +onlyfiles = [] + +# activate the pygame library . +pygame.init() +X = 1900 # > 1500 = buttons +Y = 900 +scrn = pygame.display.set_mode((X, Y)) +font = pygame.font.Font('freesansbold.ttf', 32) + + for iteration in range(30): - onlyfiles = [] latent = [latent[f] for f in five_best] images = [images[f] for f in five_best] + onlyfiles = [onlyfiles[f] for f in five_best] for k in range(llambda): if k < len(five_best): continue + text0 = font.render(f'Please wait !!! {k} / {llambda}', True, green, blue) + scrn.blit(text0, ((X*3/4)/2, Y/2)) + pygame.display.flip() os.environ["earlystop"] = "False" if k > len(five_best) else "True" os.environ["epsilon"] = str(0. if k == len(five_best) else (k - len(five_best)) / llambda) os.environ["budget"] = str(300 if k > len(five_best) else 2) os.environ["skl"] = {0: "nn", 1: "tree", 2: "logit"}[k % 3] if iteration > 0: os.environ["forcedlatent"] = str(list(forcedlatents[k].flatten())) + enforcedlatent = os.environ.get("enforcedlatent", "") + if len(enforcedlatent) > 2: + os.environ["forcedlatent"] = enforcedlatent with autocast("cuda"): image = pipe(prompt, guidance_scale=7.5)["sample"][0] images += [image] @@ -98,33 +129,53 @@ onlyfiles += [filename] str_latent = eval((os.environ["latent_sd"])) array_latent = eval(f"np.array(str_latent).reshape(4, 64, 64)") - print(f"array_latent sumsq/var {sum(array_latent.flatten() ** 2) / len(array_latent.flatten())}") + print(f"Debug info: array_latent sumsq/var {sum(array_latent.flatten() ** 2) / len(array_latent.flatten())}") latent += [array_latent] with open(f"SD_{prompt.replace(' ','_')}_latent_{sentinel}_{k}.txt", 'w') as f: f.write(f"{latent}") + # Stop the forcing from disk! + os.environ["enforcedlatent"] = "" # importing required library #mypath = "./" #onlyfiles = [f for f in listdir(mypath) if isfile(join(mypath, f))] #onlyfiles = [str(f) for f in onlyfiles if "SD_" in str(f) and ".png" in str(f) and str(f) not in all_files and sentinel in str(f)] #print() - # activate the pygame library . - pygame.init() - X = 1500 - Y = 900 # create the display surface object # of specific dimension..e(X, Y). - scrn = pygame.display.set_mode((X, Y)) - + + # Button for loading a starting point + text1 = font.render('Load image', True, green, blue) + text1 = pygame.transform.rotate(text1, 90) + scrn.blit(text1, (X*3/4, 0)) + text1 = font.render('& latent', True, green, blue) + text1 = pygame.transform.rotate(text1, 90) + scrn.blit(text1, (X*3/4+X/16, 0)) + # Button for creating a meme + text2 = font.render('Create', True, green, blue) + text2 = pygame.transform.rotate(text2, 90) + scrn.blit(text2, (X*3/4, Y/3)) + text2 = font.render('a meme', True, green, blue) + text2 = pygame.transform.rotate(text2, 90) + scrn.blit(text2, (X*3/4+X/16, Y/3)) + # Button for new generation + text3 = font.render(f"I don't want to", True, green, blue) + text3 = pygame.transform.rotate(text3, 90) + scrn.blit(text3, (X*3/4, Y*2/3)) + text3 = font.render(f"select images! 
Just rerun.", True, green, blue) + text3 = pygame.transform.rotate(text3, 90) + scrn.blit(text3, (X*3/4+X/16, Y*2/3)) + for idx in range(llambda): # set the pygame window name pygame.display.set_caption('images') # create a surface object, image is drawn on it. - imp = pygame.transform.scale(images[idx].convert(), (300, 300)) - #imp = pygame.transform.scale(pygame.image.load(onlyfiles[idx]).convert(), (300, 300)) + #imp = pygame.transform.scale(images[idx], (300, 300)) + #imp = pygame.transform.scale(images[idx].convert(), (300, 300)) # TypeError: argument 1 must be pygame.Surface, not Image + imp = pygame.transform.scale(pygame.image.load(onlyfiles[idx]).convert(), (300, 300)) # Using blit to copy content from one surface to other scrn.blit(imp, (300 * (idx // 3), 300 * (idx % 3))) @@ -135,6 +186,9 @@ indices = [] good = [] five_best = [] + for i in pygame.event.get(): + if i.type == pygame.MOUSEBUTTONUP: + print("too early for clicking !!!!") while (status): # iterate over the list of Event objects @@ -142,11 +196,37 @@ for i in pygame.event.get(): if i.type == pygame.MOUSEBUTTONUP: pos = pygame.mouse.get_pos() - print(pos) + print(f"Click at {pos}") + if pos[0] > 1500: # Not in the images. + if pos[1] < Y/3: + filename = input("Filename (please provide the latent file, of the format SD*latent*.txt) ?\n") + status = False + with open(filename, 'r') as f: + latent = f.read() + break + if pos[1] < 2*Y/3: + url = 'https://imgflip.com/memegenerator' + onlyfiles = [f for f in listdir(mypath) if isfile(join(mypath, f))] + onlyfiles = [str(f) for f in onlyfiles if "SD_" in str(f) and ".png" in str(f) and str(f) not in all_files and sentinel in str(f)] + print("Your generated images:") + print(onlyfiles) + webbrowser.open(url) + exit() + status = False + break index = 3 * (pos[0] // 300) + (pos[1] // 300) if index not in five_best and len(five_best) < 5: five_best += [index] indices += [[index, (pos[0] - (pos[0] // 300) * 300) / 300, (pos[1] - (pos[1] // 300) * 300) / 300]] + # Update the button for new generation. + text3 = font.render(f" I have chosen {len(indices)} images:", True, green, blue) + text3 = pygame.transform.rotate(text3, 90) + scrn.blit(text3, (X*3/4, Y*2/3)) + text3 = font.render(f" New generation!", True, green, blue) + text3 = pygame.transform.rotate(text3, 90) + scrn.blit(text3, (X*3/4+X/16, Y*2/3)) + #text3Rect = text3.get_rect() + #text3Rect.center = (750+750*3/4, 1000) good += [list(latent[index].flatten())] # if event object type is QUIT @@ -156,11 +236,15 @@ status = False # deactivates the pygame library - pygame.quit() - print(indices) + if len(indices) == 0: + print("The user did not like anything! 
Rerun :-(") + continue + print(f"Clicks at {indices}") os.environ["mu"] = str(len(indices)) forcedlatents = [] bad += [list(latent[u].flatten()) for u in range(len(onlyfiles)) if u not in [i[0] for i in indices]] + if len(bad) > 200: + bad = bad[(len(bad) - 200):] for a in range(llambda): forcedlatent = np.zeros((4, 64, 64)) os.environ["good"] = str(good) @@ -191,3 +275,4 @@ #for uu in range(len(latent)): # print(f"--> latent[{uu}] sum of sq / variable = {np.sum(latent[uu].flatten()**2) / len(latent[uu].flatten())}") +pygame.quit() diff --git a/view_history.sh b/view_history.sh index 82cf3f4e5..c7f4cc715 100755 --- a/view_history.sh +++ b/view_history.sh @@ -2,6 +2,7 @@ #montage `ls -ctr SD*imag*.png | head -n 15 | tail -n 14` -mode concatenate -tile 7x zuck1.png #montage `ls -ctr SD*imag*.png | head -n 29 | tail -n 14` -mode concatenate -tile 7x zuck2.png #montage `ls -ctr SD*imag*.png | tail -n 28` -mode concatenate -tile 7x history.png -montage $( ls *`ls -ctr SD*.png | sed 's/.*image_//g' | tail -n 1 | sed 's/_.*//g'`*.png | sort | tail -n 15 | sort ) -mode concatenate -tile 5x history.png +montage $( ls *`ls -ctr SD*.png | sed 's/.*image_//g' | tail -n 1 | sed 's/_.*//g'`*.png | sort ) -mode concatenate -tile 9x history.png +#montage $( ls *`ls -ctr SD*.png | sed 's/.*image_//g' | tail -n 1 | sed 's/_.*//g'`*.png | sort | tail -n 60 | sort ) -mode concatenate -tile 5x history.png open history.png #cp history.png zuck3.png From 844c3ae25ef57f9225566965fdd954c69465ee27 Mon Sep 17 00:00:00 2001 From: Olivier Teytaud Date: Mon, 26 Sep 2022 10:50:15 +0200 Subject: [PATCH 22/76] fix --- README.md | 14 ++++++++------ 1 file changed, 8 insertions(+), 6 deletions(-) diff --git a/README.md b/README.md index 4297a663c..792327cb9 100644 --- a/README.md +++ b/README.md @@ -1,24 +1,26 @@ -# Modified evolutionary version. +# Genetic Stable Diffusion -1. Install StableDiffusion as usual, plus a few more stuff. Basically: +## Install StableDiffusion as usual, plus a few more stuff. Basically: +``` conda env create -f environment.yaml conda activate ldm # you can change that name in the environment.yaml file... conda install pytorch torchvision -c pytorch pip install transformers diffusers invisible-watermark pip install pygame +pip install webbrowser pip install nevergrad pip install -e . +``` - +## Hack diffusers (yes I should do that differently... only solution for now). 2. Then use the file "pipeline_stable_diffusion.py" in lieu of the original pipeline_stable_diffusion.py found at << python -c "import diffusers ; print(diffusers.__file__)" >>. This is done as follows: cp pipeline_stable_diffusion.py <>/pipeline_stable_diffusion.py -3. Then edit the prompt in minisd.py, and possibly other variables. - -Then run << python minisd.py >>. +## Then run << python minisd.py >>. +## Send feedback to [**Nevergrad Users**](https://www.facebook.com/groups/nevergradusers/)
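The forcedlatents construction in minisd.py above splits the 64x64 latent grid among the clicked images: each cell is copied from the parent whose click is nearest, with a random per-parent weight so that every child in the next generation gets a different partition. A vectorized sketch of that recombination (the function name is illustrative; latents are assumed to be arrays of shape (4, 64, 64) and clicks are (parent_index, x, y) with x, y in [0, 1]):

```
import numpy as np

def voronoi_crossover(latents, clicks, rng=None):
    """Blend parent latents over a weighted-Voronoi partition of the 64x64 grid."""
    rng = rng if rng is not None else np.random.default_rng()
    coeffs = np.exp(rng.standard_normal(len(clicks)))   # one random weight per clicked parent
    axis = np.linspace(0.0, 1.0, 64)
    gx, gy = np.meshgrid(axis, axis, indexing="ij")     # gx[i, j] = i / 63, gy[i, j] = j / 63
    # weighted distance from every grid cell to every clicked point
    dist = np.stack([c * np.hypot(gx - x, gy - y)
                     for c, (_, x, y) in zip(coeffs, clicks)])
    winner = dist.argmin(axis=0)                        # (64, 64) map of winning parents
    child = np.zeros((4, 64, 64))
    for u, (idx, _, _) in enumerate(clicks):
        child[:, winner == u] = latents[idx][:, winner == u]
    return child

# Example: two selected images, clicked near opposite corners of their tiles.
parents = [np.random.randn(4, 64, 64) for _ in range(2)]
child = voronoi_crossover(parents, clicks=[(0, 0.1, 0.1), (1, 0.9, 0.9)])
```

Because the weights are resampled for each child, some children copy most of the grid from a single parent while others mix the parents more evenly.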
# Stable Diffusion *Stable Diffusion was made possible thanks to a collaboration with [Stability AI](https://stability.ai/) and [Runway](https://runwayml.com/) and builds upon our previous work:* From b4faf7baf51413b68073952539ca2570d44ad51d Mon Sep 17 00:00:00 2001 From: Olivier Teytaud Date: Mon, 26 Sep 2022 10:52:43 +0200 Subject: [PATCH 23/76] fix --- README.md | 18 ++++++++++++++++-- 1 file changed, 16 insertions(+), 2 deletions(-) diff --git a/README.md b/README.md index 792327cb9..587692789 100644 --- a/README.md +++ b/README.md @@ -14,9 +14,23 @@ pip install -e . ``` ## Hack diffusers (yes I should do that differently... only solution for now). -2. Then use the file "pipeline_stable_diffusion.py" in lieu of the original pipeline_stable_diffusion.py found at -<< python -c "import diffusers ; print(diffusers.__file__)" >>. This is done as follows: +Copy the file "pipeline_stable_diffusion.py" in lieu of the original pipeline_stable_diffusion.py. + +How to do this ? + First, find where ``diffusers'' is: +``` + python -c "import diffusers ; print(diffusers.__file__)" +``` +or inside python +``` +import diffusers +print(diffusers.__file__) +``` + + Then copy the local file there as follows: +``` cp pipeline_stable_diffusion.py <>/pipeline_stable_diffusion.py +``` ## Then run << python minisd.py >>. From a0fe17b464c0433f3669c8482c165486385f51b3 Mon Sep 17 00:00:00 2001 From: Olivier Teytaud Date: Mon, 26 Sep 2022 10:54:31 +0200 Subject: [PATCH 24/76] fix --- README.md | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/README.md b/README.md index 587692789..f974f8eca 100644 --- a/README.md +++ b/README.md @@ -1,5 +1,13 @@ # Genetic Stable Diffusion +This fork of Stable Diffusion uses genetic stuff and a graphical user interface. +It generates many images. +It should work directly on Mac M1. +It should be easy to adap to a machine with GPU. +Without GPU it will be more complicated. +Ping us at the Nevergrad user group if you need help, I'll do my best. +[**Nevergrad Users**](https://www.facebook.com/groups/nevergradusers/)
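For reference, the selection window added to minisd.py lays thumbnails out as 300x300 tiles, three per column, and maps each mouse click back to an image index plus a position inside the tile; that in-tile position is what drives the Voronoi recombination above. A small sketch of the mapping (helper name is illustrative):

```
TILE = 300   # thumbnails are blitted as 300x300 tiles
ROWS = 3     # three tiles per column, as in minisd.py

def click_to_selection(pos):
    """Map a pygame mouse position to (image_index, x, y) with x, y in [0, 1]."""
    px, py = pos
    index = ROWS * (px // TILE) + (py // TILE)   # column-major index into the grid
    x = (px % TILE) / TILE                       # relative position inside the tile
    y = (py % TILE) / TILE
    return index, x, y

# A click at (640, 120) lands on the third column, first row: image index 6.
print(click_to_selection((640, 120)))
```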
+ ## Install StableDiffusion as usual, plus a few more stuff. Basically: ``` From a7b95d7f49bde8c78ded04f0d2111ac74fa98bfa Mon Sep 17 00:00:00 2001 From: Olivier Teytaud Date: Mon, 26 Sep 2022 10:57:14 +0200 Subject: [PATCH 25/76] fix --- README.md | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/README.md b/README.md index f974f8eca..b1fdbfc61 100644 --- a/README.md +++ b/README.md @@ -39,6 +39,12 @@ print(diffusers.__file__) ``` cp pipeline_stable_diffusion.py <>/pipeline_stable_diffusion.py ``` +You can also do a symbolic link: +``` +pushd <> +mv pipeline_stable_diffusion.py backup_pipeline_stable_diffusion.py +ln -s <>/pipeline_stable_diffusion.py . +``` ## Then run << python minisd.py >>. From 0744f86b7eb15ba86a6b030958b53a44faaf9e19 Mon Sep 17 00:00:00 2001 From: Olivier Teytaud Date: Mon, 26 Sep 2022 11:02:52 +0200 Subject: [PATCH 26/76] fix --- README.md | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/README.md b/README.md index b1fdbfc61..2b6392290 100644 --- a/README.md +++ b/README.md @@ -29,6 +29,7 @@ How to do this ? ``` python -c "import diffusers ; print(diffusers.__file__)" ``` +and pipeline_stable_diffusion should at this location + "/pipelines/stable_diffusion/pipeline_stable_diffusion.py" or inside python ``` import diffusers @@ -37,7 +38,7 @@ print(diffusers.__file__) Then copy the local file there as follows: ``` -cp pipeline_stable_diffusion.py <>/pipeline_stable_diffusion.py +cp pipeline_stable_diffusion.py <>/pipeline_stable_diffusion.py ``` You can also do a symbolic link: ``` @@ -47,6 +48,7 @@ ln -s <>/pipeline_stable_diffusion.py . ``` ## Then run << python minisd.py >>. +You should be asked for a prompt (just <> if you like the proposed hardcoded prompt), and then a window should be opened. ## Send feedback to [**Nevergrad Users**](https://www.facebook.com/groups/nevergradusers/)
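To recap the pipeline_stable_diffusion.py changes from earlier in this series: when good and bad latents are available, the pipeline searches around a starting latent z0 by asking Nevergrad for a perturbation, scaling it by epsilon, and projecting the result back onto the sphere of radius sqrt(d) before scoring it. A condensed sketch of that loop, with a toy quadratic loss standing in for the learned good/bad surrogate:

```
import numpy as np
import nevergrad as ng

def perturbation_search(z0, loss, budget=300, epsilon=0.01,
                        optimizer="DiscreteLenglerOnePlusOne"):
    """Search near z0, as in the modified pipeline: z = z0 + epsilon * x, renormalized."""
    d = len(z0)
    opt = ng.optimizers.registry[optimizer](d, budget)
    for _ in range(budget):
        x = opt.ask()
        z = z0 + epsilon * x.value
        z = np.sqrt(d) * z / np.linalg.norm(z)   # back to the sphere of radius sqrt(d)
        opt.tell(x, loss(z))
    best = z0 + epsilon * opt.recommend().value
    return np.sqrt(d) * best / np.linalg.norm(best)

# Toy usage: pull the latent toward an arbitrary target.
d = 4 * 64 * 64
target = np.random.randn(d)
z0 = np.random.randn(d)
z = perturbation_search(z0, lambda v: float(np.mean((v - target) ** 2)), budget=50)
```

Small values of epsilon keep the children close to the forced latent built from the selected images, and the renormalization keeps the mean squared coordinate near 1, which is the quantity the pipeline prints as a sanity check.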
From 6d5d1c0ba0f5ac54630813614c33409a9853e051 Mon Sep 17 00:00:00 2001 From: Olivier Teytaud Date: Mon, 26 Sep 2022 11:11:20 +0200 Subject: [PATCH 27/76] fix --- minisd.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/minisd.py b/minisd.py index 15c302b2d..61ff4239a 100644 --- a/minisd.py +++ b/minisd.py @@ -127,6 +127,9 @@ filename = f"SD_{prompt.replace(' ','_')}_image_{sentinel}_{iteration}_{k}.png" image.save(filename) onlyfiles += [filename] + imp = pygame.transform.scale(pygame.image.load(onlyfiles[-1]).convert(), (300, 300)) + # Using blit to copy content from one surface to other + scrn.blit(imp, (300 * (k // 3), 300 * (k % 3))) str_latent = eval((os.environ["latent_sd"])) array_latent = eval(f"np.array(str_latent).reshape(4, 64, 64)") print(f"Debug info: array_latent sumsq/var {sum(array_latent.flatten() ** 2) / len(array_latent.flatten())}") From 3d1458995deb06f7ae8a65ec75adf912447a9e9c Mon Sep 17 00:00:00 2001 From: Olivier Teytaud Date: Mon, 26 Sep 2022 11:24:02 +0200 Subject: [PATCH 28/76] fix --- README.md | 4 +++- minisd.py | 4 +++- 2 files changed, 6 insertions(+), 2 deletions(-) diff --git a/README.md b/README.md index 2b6392290..dc5c02852 100644 --- a/README.md +++ b/README.md @@ -1,6 +1,7 @@ -# Genetic Stable Diffusion +# Genetic Stable Diffusion. This fork of Stable Diffusion uses genetic stuff and a graphical user interface. +It also works in many languages (tested: French and German, should be ok for many more). It generates many images. It should work directly on Mac M1. It should be easy to adap to a machine with GPU. @@ -18,6 +19,7 @@ pip install transformers diffusers invisible-watermark pip install pygame pip install webbrowser pip install nevergrad +pip install deep-translator pip install -e . ``` diff --git a/minisd.py b/minisd.py index 61ff4239a..589795810 100644 --- a/minisd.py +++ b/minisd.py @@ -5,6 +5,7 @@ from torch import autocast from diffusers import StableDiffusionPipeline import webbrowser +from deep_translator import GoogleTranslator model_id = "CompVis/stable-diffusion-v1-4" @@ -122,7 +123,8 @@ if len(enforcedlatent) > 2: os.environ["forcedlatent"] = enforcedlatent with autocast("cuda"): - image = pipe(prompt, guidance_scale=7.5)["sample"][0] + english_prompt = GoogleTranslator(source='auto', target='en').translate + image = pipe(english_prompt, guidance_scale=7.5)["sample"][0] images += [image] filename = f"SD_{prompt.replace(' ','_')}_image_{sentinel}_{iteration}_{k}.png" image.save(filename) From 36894fa3e803a22cc68c1fe97251b53a2973b501 Mon Sep 17 00:00:00 2001 From: Olivier Teytaud Date: Mon, 26 Sep 2022 11:33:30 +0200 Subject: [PATCH 29/76] fix --- minisd.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/minisd.py b/minisd.py index 589795810..f970589a1 100644 --- a/minisd.py +++ b/minisd.py @@ -76,6 +76,10 @@ if len(user_prompt) > 2: prompt = user_prompt +# On the fly translation. 
+english_prompt = GoogleTranslator(source='auto', target='en').translate + + import os import pygame from os import listdir @@ -123,7 +127,6 @@ if len(enforcedlatent) > 2: os.environ["forcedlatent"] = enforcedlatent with autocast("cuda"): - english_prompt = GoogleTranslator(source='auto', target='en').translate image = pipe(english_prompt, guidance_scale=7.5)["sample"][0] images += [image] filename = f"SD_{prompt.replace(' ','_')}_image_{sentinel}_{iteration}_{k}.png" From 6750d9496307eadee3ec820ccbf50ec7943b50d5 Mon Sep 17 00:00:00 2001 From: Olivier Teytaud Date: Mon, 26 Sep 2022 11:46:44 +0200 Subject: [PATCH 30/76] fix --- minisd.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/minisd.py b/minisd.py index f970589a1..c8f66abb0 100644 --- a/minisd.py +++ b/minisd.py @@ -157,21 +157,21 @@ # Button for loading a starting point text1 = font.render('Load image', True, green, blue) text1 = pygame.transform.rotate(text1, 90) - scrn.blit(text1, (X*3/4, 0)) + scrn.blit(text1, (X*3/4+X/32, 0)) text1 = font.render('& latent', True, green, blue) text1 = pygame.transform.rotate(text1, 90) scrn.blit(text1, (X*3/4+X/16, 0)) # Button for creating a meme text2 = font.render('Create', True, green, blue) text2 = pygame.transform.rotate(text2, 90) - scrn.blit(text2, (X*3/4, Y/3)) + scrn.blit(text2, (X*3/4+X/32, Y/3)) text2 = font.render('a meme', True, green, blue) text2 = pygame.transform.rotate(text2, 90) scrn.blit(text2, (X*3/4+X/16, Y/3)) # Button for new generation text3 = font.render(f"I don't want to", True, green, blue) text3 = pygame.transform.rotate(text3, 90) - scrn.blit(text3, (X*3/4, Y*2/3)) + scrn.blit(text3, (X*3/4+X/32, Y*2/3)) text3 = font.render(f"select images! Just rerun.", True, green, blue) text3 = pygame.transform.rotate(text3, 90) scrn.blit(text3, (X*3/4+X/16, Y*2/3)) From 9cacbb758583d8b8bbe931be8175a0a48148bfa2 Mon Sep 17 00:00:00 2001 From: Olivier Teytaud Date: Mon, 26 Sep 2022 11:57:58 +0200 Subject: [PATCH 31/76] fix --- minisd.py | 26 ++++++++++++++++---------- 1 file changed, 16 insertions(+), 10 deletions(-) diff --git a/minisd.py b/minisd.py index c8f66abb0..d0775ce6d 100644 --- a/minisd.py +++ b/minisd.py @@ -101,9 +101,9 @@ # activate the pygame library . pygame.init() -X = 1900 # > 1500 = buttons +X = 2000 # > 1500 = buttons Y = 900 -scrn = pygame.display.set_mode((X, Y)) +scrn = pygame.display.set_mode((1600, Y)) font = pygame.font.Font('freesansbold.ttf', 32) @@ -113,6 +113,10 @@ onlyfiles = [onlyfiles[f] for f in five_best] for k in range(llambda): if k < len(five_best): + imp = pygame.transform.scale(pygame.image.load(onlyfiles[k]).convert(), (300, 300)) + # Using blit to copy content from one surface to other + scrn.blit(imp, (300 * (k // 3), 300 * (k % 3))) + pygame.display.flip() continue text0 = font.render(f'Please wait !!! 
{k} / {llambda}', True, green, blue) scrn.blit(text0, ((X*3/4)/2, Y/2)) @@ -135,6 +139,7 @@ imp = pygame.transform.scale(pygame.image.load(onlyfiles[-1]).convert(), (300, 300)) # Using blit to copy content from one surface to other scrn.blit(imp, (300 * (k // 3), 300 * (k % 3))) + pygame.display.flip() str_latent = eval((os.environ["latent_sd"])) array_latent = eval(f"np.array(str_latent).reshape(4, 64, 64)") print(f"Debug info: array_latent sumsq/var {sum(array_latent.flatten() ** 2) / len(array_latent.flatten())}") @@ -157,24 +162,24 @@ # Button for loading a starting point text1 = font.render('Load image', True, green, blue) text1 = pygame.transform.rotate(text1, 90) - scrn.blit(text1, (X*3/4+X/32, 0)) + scrn.blit(text1, (X*3/4+X/16, 0)) text1 = font.render('& latent', True, green, blue) text1 = pygame.transform.rotate(text1, 90) - scrn.blit(text1, (X*3/4+X/16, 0)) + scrn.blit(text1, (X*3/4+X/16+X/32, 0)) # Button for creating a meme text2 = font.render('Create', True, green, blue) text2 = pygame.transform.rotate(text2, 90) - scrn.blit(text2, (X*3/4+X/32, Y/3)) + scrn.blit(text2, (X*3/4+X/16, Y/3)) text2 = font.render('a meme', True, green, blue) text2 = pygame.transform.rotate(text2, 90) - scrn.blit(text2, (X*3/4+X/16, Y/3)) + scrn.blit(text2, (X*3/4+X/16+X/32, Y/3)) # Button for new generation text3 = font.render(f"I don't want to", True, green, blue) text3 = pygame.transform.rotate(text3, 90) - scrn.blit(text3, (X*3/4+X/32, Y*2/3)) + scrn.blit(text3, (X*3/4+X/16, Y*2/3)) text3 = font.render(f"select images! Just rerun.", True, green, blue) text3 = pygame.transform.rotate(text3, 90) - scrn.blit(text3, (X*3/4+X/16, Y*2/3)) + scrn.blit(text3, (X*3/4+X/16+X/32, Y*2/3)) for idx in range(llambda): # set the pygame window name @@ -229,10 +234,11 @@ # Update the button for new generation. text3 = font.render(f" I have chosen {len(indices)} images:", True, green, blue) text3 = pygame.transform.rotate(text3, 90) - scrn.blit(text3, (X*3/4, Y*2/3)) + scrn.blit(text3, (X*3/4+X/16, Y*2/3)) text3 = font.render(f" New generation!", True, green, blue) text3 = pygame.transform.rotate(text3, 90) - scrn.blit(text3, (X*3/4+X/16, Y*2/3)) + scrn.blit(text3, (X*3/4+X/16+X/32, Y*2/3)) + pygame.display.flip() #text3Rect = text3.get_rect() #text3Rect.center = (750+750*3/4, 1000) good += [list(latent[index].flatten())] From f75e32f80cea81b32dd6e3b84d0792de4876fda5 Mon Sep 17 00:00:00 2001 From: Olivier Teytaud Date: Mon, 26 Sep 2022 12:00:13 +0200 Subject: [PATCH 32/76] fix --- minisd.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/minisd.py b/minisd.py index d0775ce6d..d3945d8f0 100644 --- a/minisd.py +++ b/minisd.py @@ -77,7 +77,7 @@ prompt = user_prompt # On the fly translation. 
-english_prompt = GoogleTranslator(source='auto', target='en').translate +english_prompt = GoogleTranslator(source='auto', target='en').translate(prompt) import os From 7049273287f4a3e772a99938630a7710deca31e1 Mon Sep 17 00:00:00 2001 From: Olivier Teytaud Date: Mon, 26 Sep 2022 12:08:06 +0200 Subject: [PATCH 33/76] fix --- minisd.py | 11 ++++++++++- 1 file changed, 10 insertions(+), 1 deletion(-) diff --git a/minisd.py b/minisd.py index d3945d8f0..62acce88a 100644 --- a/minisd.py +++ b/minisd.py @@ -249,7 +249,16 @@ if i.type == pygame.QUIT: status = False - # deactivates the pygame library + # Using draw.rect module of + # pygame to draw the solid circle + for _ in range(123): + x = np.random.randint(1500) + y = np.random.randint(900) + pygame.draw.circle(scrn, (0, 255, 0), + [x, y], 17, 0) + + # Draws the surface object to the screen. + pygame.display.update() if len(indices) == 0: print("The user did not like anything! Rerun :-(") continue From 54af2a55b5aa4c0c6c70bcbb356806f03df4e2de Mon Sep 17 00:00:00 2001 From: Olivier Teytaud Date: Mon, 26 Sep 2022 17:33:34 +0200 Subject: [PATCH 34/76] fix --- minisd.py | 52 +++++++++++++++++++++--------------- pipeline_stable_diffusion.py | 2 ++ view_history.sh | 6 +++-- 3 files changed, 36 insertions(+), 24 deletions(-) diff --git a/minisd.py b/minisd.py index 62acce88a..d9bfb3a2e 100644 --- a/minisd.py +++ b/minisd.py @@ -24,6 +24,8 @@ os.environ["enforcedlatent"] = "" os.environ["good"] = "[]" os.environ["bad"] = "[]" +num_iterations = 50 +gs = 7.5 pipe = StableDiffusionPipeline.from_pretrained(model_id, use_auth_token="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") pipe = pipe.to(device) @@ -103,8 +105,8 @@ pygame.init() X = 2000 # > 1500 = buttons Y = 900 -scrn = pygame.display.set_mode((1600, Y)) -font = pygame.font.Font('freesansbold.ttf', 32) +scrn = pygame.display.set_mode((1700, Y)) +font = pygame.font.Font('freesansbold.ttf', 22) for iteration in range(30): @@ -119,7 +121,7 @@ pygame.display.flip() continue text0 = font.render(f'Please wait !!! {k} / {llambda}', True, green, blue) - scrn.blit(text0, ((X*3/4)/2, Y/2)) + scrn.blit(text0, ((X*3/4)/2 - X/32, Y/2)) pygame.display.flip() os.environ["earlystop"] = "False" if k > len(five_best) else "True" os.environ["epsilon"] = str(0. if k == len(five_best) else (k - len(five_best)) / llambda) @@ -131,7 +133,7 @@ if len(enforcedlatent) > 2: os.environ["forcedlatent"] = enforcedlatent with autocast("cuda"): - image = pipe(english_prompt, guidance_scale=7.5)["sample"][0] + image = pipe(english_prompt, guidance_scale=gs, num_inference_steps=num_iterations)["sample"][0] images += [image] filename = f"SD_{prompt.replace(' ','_')}_image_{sentinel}_{iteration}_{k}.png" image.save(filename) @@ -160,37 +162,34 @@ # of specific dimension..e(X, Y). 
# Button for loading a starting point - text1 = font.render('Load image', True, green, blue) + text1 = font.render('Load image ', True, green, blue) text1 = pygame.transform.rotate(text1, 90) - scrn.blit(text1, (X*3/4+X/16, 0)) - text1 = font.render('& latent', True, green, blue) + scrn.blit(text1, (X*3/4+X/16 - X/32, 0)) + text1 = font.render('& latent ', True, green, blue) text1 = pygame.transform.rotate(text1, 90) - scrn.blit(text1, (X*3/4+X/16+X/32, 0)) + scrn.blit(text1, (X*3/4+X/16+X/32 - X/32, 0)) # Button for creating a meme text2 = font.render('Create', True, green, blue) text2 = pygame.transform.rotate(text2, 90) - scrn.blit(text2, (X*3/4+X/16, Y/3)) + scrn.blit(text2, (X*3/4+X/16 - X/32, Y/3)) text2 = font.render('a meme', True, green, blue) text2 = pygame.transform.rotate(text2, 90) - scrn.blit(text2, (X*3/4+X/16+X/32, Y/3)) + scrn.blit(text2, (X*3/4+X/16+X/32 - X/32, Y/3)) # Button for new generation text3 = font.render(f"I don't want to", True, green, blue) text3 = pygame.transform.rotate(text3, 90) - scrn.blit(text3, (X*3/4+X/16, Y*2/3)) + scrn.blit(text3, (X*3/4+X/16 - X/32, Y*2/3)) text3 = font.render(f"select images! Just rerun.", True, green, blue) text3 = pygame.transform.rotate(text3, 90) - scrn.blit(text3, (X*3/4+X/16+X/32, Y*2/3)) + scrn.blit(text3, (X*3/4+X/16+X/32 - X/32, Y*2/3)) + text4 = font.render(f"Modify parameters !", True, green, blue) + scrn.blit(text4, (300, Y + 30)) + pygame.display.flip() for idx in range(llambda): # set the pygame window name pygame.display.set_caption('images') - - # create a surface object, image is drawn on it. - #imp = pygame.transform.scale(images[idx], (300, 300)) - #imp = pygame.transform.scale(images[idx].convert(), (300, 300)) # TypeError: argument 1 must be pygame.Surface, not Image imp = pygame.transform.scale(pygame.image.load(onlyfiles[idx]).convert(), (300, 300)) - - # Using blit to copy content from one surface to other scrn.blit(imp, (300 * (idx // 3), 300 * (idx % 3))) # paint screen one time @@ -210,7 +209,16 @@ if i.type == pygame.MOUSEBUTTONUP: pos = pygame.mouse.get_pos() print(f"Click at {pos}") - if pos[0] > 1500: # Not in the images. + if pos[1] > Y: + text4 = font.render(f"ok, go to shell !", True, green, blue) + scrn.blit(text4, (300, Y + 30)) + pygame.display.flip() + num_iterations = int(input(f"Number of iterations ? (current = {num_iterations})\n")) + gs = float(input(f"Guidance scale ? (current = {gs})\n")) + text4 = font.render(f"Ok! parameters changed!", True, green, blue) + scrn.blit(text4, (300, Y + 30)) + pygame.display.flip() + elif pos[0] > 1500: # Not in the images. if pos[1] < Y/3: filename = input("Filename (please provide the latent file, of the format SD*latent*.txt) ?\n") status = False @@ -234,10 +242,10 @@ # Update the button for new generation. 
text3 = font.render(f" I have chosen {len(indices)} images:", True, green, blue) text3 = pygame.transform.rotate(text3, 90) - scrn.blit(text3, (X*3/4+X/16, Y*2/3)) - text3 = font.render(f" New generation!", True, green, blue) + scrn.blit(text3, (X*3/4+X/16 - X/32, Y*2/3)) + text3 = font.render(f" New generation!", True, green, blue) text3 = pygame.transform.rotate(text3, 90) - scrn.blit(text3, (X*3/4+X/16+X/32, Y*2/3)) + scrn.blit(text3, (X*3/4+X/16+X/32 - X/32, Y*2/3)) pygame.display.flip() #text3Rect = text3.get_rect() #text3Rect.center = (750+750*3/4, 1000) diff --git a/pipeline_stable_diffusion.py b/pipeline_stable_diffusion.py index 225216ac7..4ba09aa39 100644 --- a/pipeline_stable_diffusion.py +++ b/pipeline_stable_diffusion.py @@ -1,3 +1,5 @@ +# Modification of the original file by O. Teytaud for facilitating genetic stable diffusion. + import inspect import os import numpy as np diff --git a/view_history.sh b/view_history.sh index c7f4cc715..3f801d2e1 100755 --- a/view_history.sh +++ b/view_history.sh @@ -2,7 +2,9 @@ #montage `ls -ctr SD*imag*.png | head -n 15 | tail -n 14` -mode concatenate -tile 7x zuck1.png #montage `ls -ctr SD*imag*.png | head -n 29 | tail -n 14` -mode concatenate -tile 7x zuck2.png #montage `ls -ctr SD*imag*.png | tail -n 28` -mode concatenate -tile 7x history.png -montage $( ls *`ls -ctr SD*.png | sed 's/.*image_//g' | tail -n 1 | sed 's/_.*//g'`*.png | sort ) -mode concatenate -tile 9x history.png #montage $( ls *`ls -ctr SD*.png | sed 's/.*image_//g' | tail -n 1 | sed 's/_.*//g'`*.png | sort | tail -n 60 | sort ) -mode concatenate -tile 5x history.png -open history.png +#montage $( ls *`ls -ctr SD*.png | sed 's/.*image_//g' | tail -n 1 | sed 's/_.*//g'`_0_11.png | sort ) $( ls *`ls -ctr SD*.png | sed 's/.*image_//g' | tail -n 1 | sed 's/_.*//g'`_0_4.png | sort ) $( ls *`ls -ctr SD*.png | sed 's/.*image_//g' | tail -n 1 | sed 's/_.*//g'`_1_?.png | sort -n ) $( ls *`ls -ctr SD*.png | sed 's/.*image_//g' | tail -n 1 | sed 's/_.*//g'`_1_??.png | sort -n ) -mode concatenate -tile 5x history.png +#montage $( ls *`ls -ctr SD*.png | sed 's/.*image_//g' | tail -n 1 | sed 's/_.*//g'`*.png | sort ) -mode concatenate -tile 5x history.png +#open history.png +open $( ls *`ls -ctr SD*.png | sed 's/.*image_//g' | tail -n 1 | sed 's/_.*//g'`*.png | sort ) #cp history.png zuck3.png From 20803eebfbe5733d510ab08f757a9fb3aabec9f5 Mon Sep 17 00:00:00 2001 From: Olivier Teytaud Date: Tue, 27 Sep 2022 07:46:48 +0200 Subject: [PATCH 35/76] fix --- README.md | 1 + minisd.py | 44 ++++++++++++++++++++++++++------------------ 2 files changed, 27 insertions(+), 18 deletions(-) diff --git a/README.md b/README.md index dc5c02852..4270b74e8 100644 --- a/README.md +++ b/README.md @@ -19,6 +19,7 @@ pip install transformers diffusers invisible-watermark pip install pygame pip install webbrowser pip install nevergrad +pip install langdetect pip install deep-translator pip install -e . ``` diff --git a/minisd.py b/minisd.py index d9bfb3a2e..5d0dd23c9 100644 --- a/minisd.py +++ b/minisd.py @@ -6,7 +6,7 @@ from diffusers import StableDiffusionPipeline import webbrowser from deep_translator import GoogleTranslator - +from langdetect import detect model_id = "CompVis/stable-diffusion-v1-4" #device = "cuda" @@ -73,14 +73,22 @@ prompt = "A bloody heavy-metal zombie with a chainsaw." prompt = "Tentacles attacking a bloody Meg Myers in Eyptian dress. Meg Myers has a chainsaw." prompt = "Bizarre art." + +prompt = "Beautiful bizarre woman." 
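The two libraries added here (langdetect, deep-translator) drive the on-the-fly translation that follows: the prompt language is detected once, the prompt is translated to English for the model, and UI strings are translated back to the user's language. A minimal round-trip sketch (needs network access; the example string is arbitrary):

```python
from langdetect import detect
from deep_translator import GoogleTranslator

prompt = "Un chat en armure joue de la batterie."   # any non-English prompt
language = detect(prompt)                           # e.g. 'fr'
english_prompt = GoogleTranslator(source='auto', target='en').translate(prompt)

def to_native(text):
    """Translate an English UI string back into the detected language."""
    return GoogleTranslator(source='en', target=language).translate(text)

print(english_prompt)            # roughly: "A cat in armor plays the drums."
print(to_native("Please wait"))  # rendered in the user's language
```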
print(f"The prompt is {prompt}") user_prompt = input("Enter a new prompt if you prefer\n") if len(user_prompt) > 2: prompt = user_prompt # On the fly translation. +language = detect(prompt) english_prompt = GoogleTranslator(source='auto', target='en').translate(prompt) +define to_native(stri): + return GoogleTranslator(source='en', target=language).translate(stri) + +print(f"Working on {english_prompt}, a.k.a {prompt}.") + import os import pygame @@ -120,7 +128,7 @@ scrn.blit(imp, (300 * (k // 3), 300 * (k % 3))) pygame.display.flip() continue - text0 = font.render(f'Please wait !!! {k} / {llambda}', True, green, blue) + text0 = font.render(to_native(f'Please wait !!! {k} / {llambda}'), True, green, blue) scrn.blit(text0, ((X*3/4)/2 - X/32, Y/2)) pygame.display.flip() os.environ["earlystop"] = "False" if k > len(five_best) else "True" @@ -169,26 +177,26 @@ text1 = pygame.transform.rotate(text1, 90) scrn.blit(text1, (X*3/4+X/16+X/32 - X/32, 0)) # Button for creating a meme - text2 = font.render('Create', True, green, blue) + text2 = font.render(to_native('Create'), True, green, blue) text2 = pygame.transform.rotate(text2, 90) scrn.blit(text2, (X*3/4+X/16 - X/32, Y/3)) - text2 = font.render('a meme', True, green, blue) + text2 = font.render(to_native('a meme'), True, green, blue) text2 = pygame.transform.rotate(text2, 90) scrn.blit(text2, (X*3/4+X/16+X/32 - X/32, Y/3)) # Button for new generation - text3 = font.render(f"I don't want to", True, green, blue) + text3 = font.render(to_native(f"I don't want to"), True, green, blue) text3 = pygame.transform.rotate(text3, 90) scrn.blit(text3, (X*3/4+X/16 - X/32, Y*2/3)) - text3 = font.render(f"select images! Just rerun.", True, green, blue) + text3 = font.render(to_native(f"select images! Just rerun."), True, green, blue) text3 = pygame.transform.rotate(text3, 90) scrn.blit(text3, (X*3/4+X/16+X/32 - X/32, Y*2/3)) - text4 = font.render(f"Modify parameters !", True, green, blue) + text4 = font.render(to_native(f"Modify parameters !"), True, green, blue) scrn.blit(text4, (300, Y + 30)) pygame.display.flip() for idx in range(llambda): # set the pygame window name - pygame.display.set_caption('images') + pygame.display.set_caption(prompt) imp = pygame.transform.scale(pygame.image.load(onlyfiles[idx]).convert(), (300, 300)) scrn.blit(imp, (300 * (idx // 3), 300 * (idx % 3))) @@ -200,7 +208,7 @@ five_best = [] for i in pygame.event.get(): if i.type == pygame.MOUSEBUTTONUP: - print("too early for clicking !!!!") + print(to_native("too early for clicking !!!!")) while (status): # iterate over the list of Event objects @@ -210,26 +218,26 @@ pos = pygame.mouse.get_pos() print(f"Click at {pos}") if pos[1] > Y: - text4 = font.render(f"ok, go to shell !", True, green, blue) + text4 = font.render(to_native(f"ok, go to text window!"), True, green, blue) scrn.blit(text4, (300, Y + 30)) pygame.display.flip() - num_iterations = int(input(f"Number of iterations ? (current = {num_iterations})\n")) - gs = float(input(f"Guidance scale ? (current = {gs})\n")) - text4 = font.render(f"Ok! parameters changed!", True, green, blue) + num_iterations = int(input(to_native(f"Number of iterations ? (current = {num_iterations})\n"))) + gs = float(input(to_native(f"Guidance scale ? (current = {gs})\n"))) + text4 = font.render(to_native(f"Ok! parameters changed!"), True, green, blue) scrn.blit(text4, (300, Y + 30)) pygame.display.flip() elif pos[0] > 1500: # Not in the images. 
if pos[1] < Y/3: - filename = input("Filename (please provide the latent file, of the format SD*latent*.txt) ?\n") + filename = input(to_native("Filename (please provide the latent file, of the format SD*latent*.txt) ?\n")) status = False with open(filename, 'r') as f: latent = f.read() break if pos[1] < 2*Y/3: url = 'https://imgflip.com/memegenerator' - onlyfiles = [f for f in listdir(mypath) if isfile(join(mypath, f))] + onlyfiles = [f for f in listdir(".") if isfile(join(mypath, f))] onlyfiles = [str(f) for f in onlyfiles if "SD_" in str(f) and ".png" in str(f) and str(f) not in all_files and sentinel in str(f)] - print("Your generated images:") + print(to_native("Your generated images:")) print(onlyfiles) webbrowser.open(url) exit() @@ -240,10 +248,10 @@ five_best += [index] indices += [[index, (pos[0] - (pos[0] // 300) * 300) / 300, (pos[1] - (pos[1] // 300) * 300) / 300]] # Update the button for new generation. - text3 = font.render(f" I have chosen {len(indices)} images:", True, green, blue) + text3 = font.render(to_native(f" I have chosen {len(indices)} images:"), True, green, blue) text3 = pygame.transform.rotate(text3, 90) scrn.blit(text3, (X*3/4+X/16 - X/32, Y*2/3)) - text3 = font.render(f" New generation!", True, green, blue) + text3 = font.render(to_native(f" New generation!"), True, green, blue) text3 = pygame.transform.rotate(text3, 90) scrn.blit(text3, (X*3/4+X/16+X/32 - X/32, Y*2/3)) pygame.display.flip() From 2bbd905f5c30927d4fb26a1cf681d26975c28d70 Mon Sep 17 00:00:00 2001 From: Olivier Teytaud Date: Tue, 27 Sep 2022 07:55:47 +0200 Subject: [PATCH 36/76] fix --- minisd.py | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/minisd.py b/minisd.py index 5d0dd23c9..c4717e80a 100644 --- a/minisd.py +++ b/minisd.py @@ -84,7 +84,7 @@ language = detect(prompt) english_prompt = GoogleTranslator(source='auto', target='en').translate(prompt) -define to_native(stri): +def to_native(stri): return GoogleTranslator(source='en', target=language).translate(stri) print(f"Working on {english_prompt}, a.k.a {prompt}.") @@ -282,6 +282,10 @@ os.environ["mu"] = str(len(indices)) forcedlatents = [] bad += [list(latent[u].flatten()) for u in range(len(onlyfiles)) if u not in [i[0] for i in indices]] + sauron = 0 * latent[0] + for u in [u for u in range(len(onlyfiles)) if u not in [i[0] for i in indices]]: + sauron += latent[u] + sauron = (1 / len([u for u in range(len(onlyfiles)) if u not in [i[0] for i in indices]])) * sauron if len(bad) > 200: bad = bad[(len(bad) - 200):] for a in range(llambda): @@ -310,6 +314,8 @@ assert i < len(latent[uu][k]), i assert j < len(latent[uu][k][i]), j forcedlatent[k][i][j] = latent[uu][k][i][j] + if a % 2 == 0: + forcedlatent -= np.random.rand() * sauron forcedlatents += [forcedlatent] #for uu in range(len(latent)): # print(f"--> latent[{uu}] sum of sq / variable = {np.sum(latent[uu].flatten()**2) / len(latent[uu].flatten())}") From 29fc38d7b94a713901a42cce22d79686963f3a0c Mon Sep 17 00:00:00 2001 From: Olivier Teytaud Date: Tue, 27 Sep 2022 07:57:10 +0200 Subject: [PATCH 37/76] fix --- minisd.py | 1 + 1 file changed, 1 insertion(+) diff --git a/minisd.py b/minisd.py index c4717e80a..48ec414a5 100644 --- a/minisd.py +++ b/minisd.py @@ -75,6 +75,7 @@ prompt = "Bizarre art." prompt = "Beautiful bizarre woman." +prompt = "Yann LeCun as the grim reaper: bizarre art." 
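The `sauron` vector added in this patch is simply the mean latent of the images the user did not select; after the spatial crossover of the selected parents, every other child is pushed away from it. A rough numpy sketch of that recombination idea, assuming `parents` and `rejected` are lists of (4, 64, 64) latents; the real code additionally weights each parent by where the user clicked inside the image, which is omitted here:

```python
import numpy as np

def recombine(parents, rejected, repel=True):
    """Blockwise crossover of selected latents, optionally repelled from
    the mean of the rejected ones (the 'sauron' vector)."""
    shape = parents[0].shape                      # (4, 64, 64) in this pipeline
    coeffs = np.exp(2.0 * np.random.randn(len(parents)))
    child = np.empty(shape)
    for i in range(shape[1]):
        for j in range(shape[2]):
            # per spatial location, pick a donor parent, biased by its coefficient
            donor = int(np.argmax(coeffs * np.random.rand(len(parents))))
            child[:, i, j] = parents[donor][:, i, j]
    if repel and len(rejected) > 0:
        sauron = np.mean(rejected, axis=0)
        child = child - np.random.rand() * sauron
    return child
```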
print(f"The prompt is {prompt}") user_prompt = input("Enter a new prompt if you prefer\n") if len(user_prompt) > 2: From 5b428cb088c95f7d18a4ad46648eaef892f69221 Mon Sep 17 00:00:00 2001 From: Olivier Teytaud Date: Tue, 27 Sep 2022 09:02:53 +0200 Subject: [PATCH 38/76] fix --- minisd.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/minisd.py b/minisd.py index 48ec414a5..9df57ae8c 100644 --- a/minisd.py +++ b/minisd.py @@ -24,7 +24,7 @@ os.environ["enforcedlatent"] = "" os.environ["good"] = "[]" os.environ["bad"] = "[]" -num_iterations = 50 +num_iterations = 15 gs = 7.5 pipe = StableDiffusionPipeline.from_pretrained(model_id, use_auth_token="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") @@ -76,6 +76,7 @@ prompt = "Beautiful bizarre woman." prompt = "Yann LeCun as the grim reaper: bizarre art." +prompt = "A star with flashy colors." print(f"The prompt is {prompt}") user_prompt = input("Enter a new prompt if you prefer\n") if len(user_prompt) > 2: @@ -114,7 +115,7 @@ def to_native(stri): pygame.init() X = 2000 # > 1500 = buttons Y = 900 -scrn = pygame.display.set_mode((1700, Y)) +scrn = pygame.display.set_mode((1700, Y + 100)) font = pygame.font.Font('freesansbold.ttf', 22) From 60827c662a4cbe7bcf9f0dedca188471bda9aaf2 Mon Sep 17 00:00:00 2001 From: Olivier Teytaud Date: Tue, 27 Sep 2022 09:05:40 +0200 Subject: [PATCH 39/76] fix --- minisd.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/minisd.py b/minisd.py index 9df57ae8c..fda6d13c5 100644 --- a/minisd.py +++ b/minisd.py @@ -24,7 +24,7 @@ os.environ["enforcedlatent"] = "" os.environ["good"] = "[]" os.environ["bad"] = "[]" -num_iterations = 15 +num_iterations = 50 gs = 7.5 pipe = StableDiffusionPipeline.from_pretrained(model_id, use_auth_token="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") From 9a5719a2bd93f4cd7213bfb19c3cf6a5b0c37048 Mon Sep 17 00:00:00 2001 From: Olivier Teytaud Date: Tue, 27 Sep 2022 12:23:52 +0200 Subject: [PATCH 40/76] fix --- minisd.py | 33 +++++++++++++++++++++++++++------ 1 file changed, 27 insertions(+), 6 deletions(-) diff --git a/minisd.py b/minisd.py index fda6d13c5..72ef821ce 100644 --- a/minisd.py +++ b/minisd.py @@ -2,6 +2,7 @@ import os import torch import numpy as np +import shutil from torch import autocast from diffusers import StableDiffusionPipeline import webbrowser @@ -14,7 +15,9 @@ white = (255, 255, 255) green = (0, 255, 0) +red = (255, 0, 0) blue = (0, 0, 128) +black = (0, 0, 0) os.environ["skl"] = "nn" os.environ["epsilon"] = "0.005" @@ -77,8 +80,12 @@ prompt = "Beautiful bizarre woman." prompt = "Yann LeCun as the grim reaper: bizarre art." prompt = "A star with flashy colors." +prompt = "Un chat en sang et en armure joue de la batterie." +prompt = "Judith beheading Holofernes." 
print(f"The prompt is {prompt}") -user_prompt = input("Enter a new prompt if you prefer\n") + +print(f"Francais: Proposez un nouveau texte si vous ne voulez pas dessiner << {prompt} >>.\n") +user_prompt = input(f"English: Enter a new prompt if you prefer something else than << {prompt} >>.\n") if len(user_prompt) > 2: prompt = user_prompt @@ -120,12 +127,14 @@ def to_native(stri): for iteration in range(30): + scrn.fill(black) latent = [latent[f] for f in five_best] images = [images[f] for f in five_best] onlyfiles = [onlyfiles[f] for f in five_best] for k in range(llambda): if k < len(five_best): imp = pygame.transform.scale(pygame.image.load(onlyfiles[k]).convert(), (300, 300)) + shutil.copyfile(onlyfiles[k], to_native("Selected") + onlyfiles[k]) # Using blit to copy content from one surface to other scrn.blit(imp, (300 * (k // 3), 300 * (k % 3))) pygame.display.flip() @@ -145,7 +154,7 @@ def to_native(stri): with autocast("cuda"): image = pipe(english_prompt, guidance_scale=gs, num_inference_steps=num_iterations)["sample"][0] images += [image] - filename = f"SD_{prompt.replace(' ','_')}_image_{sentinel}_{iteration}_{k}.png" + filename = f"SD_{prompt.replace(' ','_')}_image_{sentinel}_{iteration:05d}_{k:05d}.png" image.save(filename) onlyfiles += [filename] imp = pygame.transform.scale(pygame.image.load(onlyfiles[-1]).convert(), (300, 300)) @@ -171,6 +180,12 @@ def to_native(stri): # create the display surface object # of specific dimension..e(X, Y). + # Add rectangles + pygame.draw.rect(scrn, red, pygame.Rect(X*3/4, 0, X*3/4+X/16+X/32, Y/3), 2) + pygame.draw.rect(scrn, red, pygame.Rect(X*3/4, Y/3, X*3/4+X/16+X/32, 2*Y/3), 2) + pygame.draw.rect(scrn, red, pygame.Rect(X*3/4, 2*Y/3, X*3/4+X/16+X/32, Y), 2) + pygame.draw.rect(scrn, red, pygame.Rect(0, Y, X/2, Y+100), 2) + # Button for loading a starting point text1 = font.render('Load image ', True, green, blue) text1 = pygame.transform.rotate(text1, 90) @@ -178,6 +193,7 @@ def to_native(stri): text1 = font.render('& latent ', True, green, blue) text1 = pygame.transform.rotate(text1, 90) scrn.blit(text1, (X*3/4+X/16+X/32 - X/32, 0)) + # Button for creating a meme text2 = font.render(to_native('Create'), True, green, blue) text2 = pygame.transform.rotate(text2, 90) @@ -186,13 +202,13 @@ def to_native(stri): text2 = pygame.transform.rotate(text2, 90) scrn.blit(text2, (X*3/4+X/16+X/32 - X/32, Y/3)) # Button for new generation - text3 = font.render(to_native(f"I don't want to"), True, green, blue) + text3 = font.render(to_native(f"I don't want to select images"), True, green, blue) text3 = pygame.transform.rotate(text3, 90) scrn.blit(text3, (X*3/4+X/16 - X/32, Y*2/3)) - text3 = font.render(to_native(f"select images! Just rerun."), True, green, blue) + text3 = font.render(to_native(f"Just rerun."), True, green, blue) text3 = pygame.transform.rotate(text3, 90) scrn.blit(text3, (X*3/4+X/16+X/32 - X/32, Y*2/3)) - text4 = font.render(to_native(f"Modify parameters !"), True, green, blue) + text4 = font.render(to_native(f"Modify parameters or text!"), True, green, blue) scrn.blit(text4, (300, Y + 30)) pygame.display.flip() @@ -225,6 +241,11 @@ def to_native(stri): pygame.display.flip() num_iterations = int(input(to_native(f"Number of iterations ? (current = {num_iterations})\n"))) gs = float(input(to_native(f"Guidance scale ? 
(current = {gs})\n"))) + new_prompt = str(input(to_native(f"Enter a text if you want to change from ") + prompt)) + if len(new_prompt) > 2: + prompt = new_prompt + language = detect(prompt) + english_prompt = GoogleTranslator(source='auto', target='en').translate(prompt) text4 = font.render(to_native(f"Ok! parameters changed!"), True, green, blue) scrn.blit(text4, (300, Y + 30)) pygame.display.flip() @@ -296,7 +317,7 @@ def to_native(stri): os.environ["bad"] = str(bad) coefficients = np.zeros(len(indices)) for i in range(len(indices)): - coefficients[i] = np.exp(np.random.randn()) + coefficients[i] = np.exp(2. * np.random.randn()) for i in range(64): x = i / 63. for j in range(64): From 9d9572af7543adb50ebcf4f25e04a7554b98b8c2 Mon Sep 17 00:00:00 2001 From: Olivier Teytaud Date: Tue, 27 Sep 2022 14:04:11 +0200 Subject: [PATCH 41/76] fix --- minisd.py | 33 ++++++++++++++++++++++++++------- 1 file changed, 26 insertions(+), 7 deletions(-) diff --git a/minisd.py b/minisd.py index 72ef821ce..8b266f61a 100644 --- a/minisd.py +++ b/minisd.py @@ -30,6 +30,8 @@ num_iterations = 50 gs = 7.5 +forcedlatents = [] + pipe = StableDiffusionPipeline.from_pretrained(model_id, use_auth_token="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") pipe = pipe.to(device) @@ -82,6 +84,7 @@ prompt = "A star with flashy colors." prompt = "Un chat en sang et en armure joue de la batterie." prompt = "Judith beheading Holofernes." +prompt = "Woman beheading a man: cyberpunk style." print(f"The prompt is {prompt}") print(f"Francais: Proposez un nouveau texte si vous ne voulez pas dessiner << {prompt} >>.\n") @@ -127,10 +130,11 @@ def to_native(stri): for iteration in range(30): - scrn.fill(black) + #scrn.fill(black) latent = [latent[f] for f in five_best] images = [images[f] for f in five_best] onlyfiles = [onlyfiles[f] for f in five_best] + early_stop = [] for k in range(llambda): if k < len(five_best): imp = pygame.transform.scale(pygame.image.load(onlyfiles[k]).convert(), (300, 300)) @@ -139,14 +143,20 @@ def to_native(stri): scrn.blit(imp, (300 * (k // 3), 300 * (k % 3))) pygame.display.flip() continue + if len(early_stop) > 0: + break text0 = font.render(to_native(f'Please wait !!! {k} / {llambda}'), True, green, blue) - scrn.blit(text0, ((X*3/4)/2 - X/32, Y/2)) + scrn.blit(text0, ((X*3/4)/2 - X/32, Y/2-Y/8)) + text0 = font.render(to_native(f'Or click on an image (then don''t move the mouse until click received!),'), True, green, blue) + scrn.blit(text0, ((X*3/4)/3 - X/32, Y/2)) + text0 = font.render(to_native(f'for rerunning on a specific image.'), True, green, blue) + scrn.blit(text0, ((X*3/4)/2 - X/32, Y/2+Y/8)) pygame.display.flip() os.environ["earlystop"] = "False" if k > len(five_best) else "True" os.environ["epsilon"] = str(0. if k == len(five_best) else (k - len(five_best)) / llambda) os.environ["budget"] = str(300 if k > len(five_best) else 2) os.environ["skl"] = {0: "nn", 1: "tree", 2: "logit"}[k % 3] - if iteration > 0: + if iteration > 0 and len(forcedlatents) > 0: os.environ["forcedlatent"] = str(list(forcedlatents[k].flatten())) enforcedlatent = os.environ.get("enforcedlatent", "") if len(enforcedlatent) > 2: @@ -167,6 +177,15 @@ def to_native(stri): latent += [array_latent] with open(f"SD_{prompt.replace(' ','_')}_latent_{sentinel}_{k}.txt", 'w') as f: f.write(f"{latent}") + # In case of early stopping. 
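Everything that minisd.py hands to the patched pipeline travels through environment variables: a candidate latent is flattened, serialized as a Python list literal into `forcedlatent`, and parsed back on the pipeline side (the same trick is used for `good`, `bad` and `latent_sd`). A toy-sized sketch of that round trip:

```python
import os
import numpy as np

# Toy-sized stand-in; the latent actually handed to the pipeline is (4, 64, 64).
latent = np.random.randn(4, 8, 8)

# Producer side (the driver script): flatten and store as a Python list literal.
os.environ["forcedlatent"] = str(list(latent.flatten()))

# Consumer side (the patched pipeline): parse the literal and restore the shape.
# (ast.literal_eval would be a safer parser than eval for untrusted strings.)
restored = np.array(eval(os.environ["forcedlatent"])).reshape(latent.shape)

assert np.allclose(restored, latent)
```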
+ for i in pygame.event.get(): + if i.type == pygame.MOUSEBUTTONUP: + pos = pygame.mouse.get_pos() + index = 3 * (pos[0] // 300) + (pos[1] // 300) + if index <= k: + print(to_native("You clicked for requesting an early stopping.")) + early_stop = [pos] + break # Stop the forcing from disk! os.environ["enforcedlatent"] = "" @@ -226,14 +245,14 @@ def to_native(stri): five_best = [] for i in pygame.event.get(): if i.type == pygame.MOUSEBUTTONUP: - print(to_native("too early for clicking !!!!")) + print(to_native(".... too early for clicking !!!!")) while (status): # iterate over the list of Event objects # that was returned by pygame.event.get() method. - for i in pygame.event.get(): - if i.type == pygame.MOUSEBUTTONUP: - pos = pygame.mouse.get_pos() + for i in early_stop + pygame.event.get(): + if hasattr(i, "type") and i.type == pygame.MOUSEBUTTONUP or len(early_stop) > 0: + pos = early_stop[0] if len(early_stop) > 0 else pygame.mouse.get_pos() print(f"Click at {pos}") if pos[1] > Y: text4 = font.render(to_native(f"ok, go to text window!"), True, green, blue) From 4a1789d38e04ae7833b9d11fc3bedd78f6ca2ecc Mon Sep 17 00:00:00 2001 From: Olivier Teytaud Date: Tue, 27 Sep 2022 16:51:27 +0200 Subject: [PATCH 42/76] fix --- minisd.py | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/minisd.py b/minisd.py index 8b266f61a..be64eb72c 100644 --- a/minisd.py +++ b/minisd.py @@ -83,8 +83,7 @@ prompt = "Yann LeCun as the grim reaper: bizarre art." prompt = "A star with flashy colors." prompt = "Un chat en sang et en armure joue de la batterie." -prompt = "Judith beheading Holofernes." -prompt = "Woman beheading a man: cyberpunk style." +prompt = "Cyberpunk photographic version of Judith beheading Holofernes." print(f"The prompt is {prompt}") print(f"Francais: Proposez un nouveau texte si vous ne voulez pas dessiner << {prompt} >>.\n") @@ -231,11 +230,12 @@ def to_native(stri): scrn.blit(text4, (300, Y + 30)) pygame.display.flip() - for idx in range(llambda): - # set the pygame window name - pygame.display.set_caption(prompt) - imp = pygame.transform.scale(pygame.image.load(onlyfiles[idx]).convert(), (300, 300)) - scrn.blit(imp, (300 * (idx // 3), 300 * (idx % 3))) + if len(early_stop) == 0: + for idx in range(llambda): + # set the pygame window name + pygame.display.set_caption(prompt) + imp = pygame.transform.scale(pygame.image.load(onlyfiles[idx]).convert(), (300, 300)) + scrn.blit(imp, (300 * (idx // 3), 300 * (idx % 3))) # paint screen one time pygame.display.flip() From 0ec465aa91a0710e6a2db67623c43e750d0ef569 Mon Sep 17 00:00:00 2001 From: Olivier Teytaud Date: Tue, 27 Sep 2022 17:15:59 +0200 Subject: [PATCH 43/76] fix --- minisd.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/minisd.py b/minisd.py index be64eb72c..4f5a4412d 100644 --- a/minisd.py +++ b/minisd.py @@ -304,7 +304,7 @@ def to_native(stri): # if event object type is QUIT # then quitting the pygame # and program both. 
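The early-stop mechanism added here polls the pygame event queue between two expensive generations and only remembers the click; the selection loop later consumes it exactly like a fresh event. A stripped-down sketch of the pattern, with a sleep standing in for the diffusion call:

```python
import time
import pygame

pygame.init()
scrn = pygame.display.set_mode((600, 400))
early_stop = []                      # remembered click, if any

for k in range(8):                   # stands in for the llambda generations
    if early_stop:
        break                        # user already picked a favourite: stop generating
    time.sleep(1.0)                  # placeholder for the expensive pipe(...) call
    for event in pygame.event.get():
        if event.type == pygame.MOUSEBUTTONUP:
            # remember the click instead of acting on it right away
            early_stop = [pygame.mouse.get_pos()]
            break

# The selection loop later iterates over `early_stop + pygame.event.get()`,
# telling the two kinds of items apart with hasattr(item, "type").
pygame.quit()
```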
- if i.type == pygame.QUIT: + if len(early_stop) > 0 or i.type == pygame.QUIT: status = False # Using draw.rect module of From 05f6740c93d87a070354dc3ee976e17992ad64c6 Mon Sep 17 00:00:00 2001 From: Olivier Teytaud Date: Tue, 27 Sep 2022 18:53:14 +0200 Subject: [PATCH 44/76] fix --- README.md | 1 + minisd.py | 26 +++++++++++++++++++++++++- pipeline_stable_diffusion.py | 3 +++ 3 files changed, 29 insertions(+), 1 deletion(-) diff --git a/README.md b/README.md index 4270b74e8..3cf23c988 100644 --- a/README.md +++ b/README.md @@ -17,6 +17,7 @@ conda activate ldm # you can change that name in the environment.yaml file... conda install pytorch torchvision -c pytorch pip install transformers diffusers invisible-watermark pip install pygame +pip install einops pip install webbrowser pip install nevergrad pip install langdetect diff --git a/minisd.py b/minisd.py index 4f5a4412d..70ba3bcea 100644 --- a/minisd.py +++ b/minisd.py @@ -3,12 +3,15 @@ import torch import numpy as np import shutil +import PIL +from PIL import Image +from einops import rearrange, repeat from torch import autocast from diffusers import StableDiffusionPipeline import webbrowser from deep_translator import GoogleTranslator from langdetect import detect - +os.environ["PYTORCH_ENABLE_MPS_FALLBACK"] = "1" model_id = "CompVis/stable-diffusion-v1-4" #device = "cuda" device = "mps" #torch.device("mps") @@ -84,6 +87,7 @@ prompt = "A star with flashy colors." prompt = "Un chat en sang et en armure joue de la batterie." prompt = "Cyberpunk photographic version of Judith beheading Holofernes." +prompt = "Photo of a cyberpunk Mark Zuckerberg killing Cthulhu with a light saber." print(f"The prompt is {prompt}") print(f"Francais: Proposez un nouveau texte si vous ne voulez pas dessiner << {prompt} >>.\n") @@ -127,6 +131,26 @@ def to_native(stri): scrn = pygame.display.set_mode((1700, Y + 100)) font = pygame.font.Font('freesansbold.ttf', 22) +image_name = input(to_native("Name of image for starting ? (enter if no start image)")) +def load_img(path): + image = Image.open(path).convert("RGB") + w, h = image.size + print(f"loaded input image of size ({w}, {h}) from {path}") + w, h = map(lambda x: x - x % 32, (w, h)) # resize to integer multiple of 32 + image = image.resize((w, h), resample=PIL.Image.LANCZOS) + image = np.array(image).astype(np.float32) / 255.0 + image = image[None].transpose(0, 3, 1, 2) + image = torch.from_numpy(image) + return 2.*image - 1. + +if len(image_name) > 0: + import torchvision + #forced_latent = pipe.get_latent(torchvision.io.read_image(image_name).float()) + model = pipe.vae + init_image = load_img(image_name).to(device) + init_image = repeat(init_image, '1 ... 
-> b ...', b=1) + #forced_latent = model.get_first_stage_encoding(model.encode_first_stage(init_image)) + forced_latent = model.encode(init_image) for iteration in range(30): #scrn.fill(black) diff --git a/pipeline_stable_diffusion.py b/pipeline_stable_diffusion.py index 4ba09aa39..0bef76aec 100644 --- a/pipeline_stable_diffusion.py +++ b/pipeline_stable_diffusion.py @@ -95,6 +95,9 @@ def disable_attention_slicing(self): # set slice_size = `None` to disable `attention slicing` self.enable_attention_slicing(None) +# def get_latent(self, image): +# return self.vae.encode(image) + @torch.no_grad() def __call__( self, From 411200cae9d1e77da7b52db1aeaad4c308c4af29 Mon Sep 17 00:00:00 2001 From: Olivier Teytaud Date: Tue, 27 Sep 2022 19:18:03 +0200 Subject: [PATCH 45/76] fix --- minisd.py | 1 + 1 file changed, 1 insertion(+) diff --git a/minisd.py b/minisd.py index 70ba3bcea..b7e70c60e 100644 --- a/minisd.py +++ b/minisd.py @@ -151,6 +151,7 @@ def load_img(path): init_image = repeat(init_image, '1 ... -> b ...', b=1) #forced_latent = model.get_first_stage_encoding(model.encode_first_stage(init_image)) forced_latent = model.encode(init_image) + os.environ["forcedlatent"] = str(list(forced_latent.flatten())) for iteration in range(30): #scrn.fill(black) From baca5b4e19b4952a3229b782e75bd929b686821a Mon Sep 17 00:00:00 2001 From: Olivier Teytaud Date: Tue, 27 Sep 2022 21:00:47 +0200 Subject: [PATCH 46/76] fix --- minisd.py | 13 +++++++++---- pipeline_stable_diffusion.py | 7 ++++++- 2 files changed, 15 insertions(+), 5 deletions(-) diff --git a/minisd.py b/minisd.py index b7e70c60e..6a7fd471f 100644 --- a/minisd.py +++ b/minisd.py @@ -88,6 +88,7 @@ prompt = "Un chat en sang et en armure joue de la batterie." prompt = "Cyberpunk photographic version of Judith beheading Holofernes." prompt = "Photo of a cyberpunk Mark Zuckerberg killing Cthulhu with a light saber." +prompt = "Photo of Mark Zuckerberg killing Cthulhu with a light saber." print(f"The prompt is {prompt}") print(f"Francais: Proposez un nouveau texte si vous ne voulez pas dessiner << {prompt} >>.\n") @@ -137,7 +138,8 @@ def load_img(path): w, h = image.size print(f"loaded input image of size ({w}, {h}) from {path}") w, h = map(lambda x: x - x % 32, (w, h)) # resize to integer multiple of 32 - image = image.resize((w, h), resample=PIL.Image.LANCZOS) + image = image.resize((512, 512), resample=PIL.Image.LANCZOS) + #image = image.resize((w, h), resample=PIL.Image.LANCZOS) image = np.array(image).astype(np.float32) / 255.0 image = image[None].transpose(0, 3, 1, 2) image = torch.from_numpy(image) @@ -150,8 +152,11 @@ def load_img(path): init_image = load_img(image_name).to(device) init_image = repeat(init_image, '1 ... -> b ...', b=1) #forced_latent = model.get_first_stage_encoding(model.encode_first_stage(init_image)) - forced_latent = model.encode(init_image) - os.environ["forcedlatent"] = str(list(forced_latent.flatten())) + #forced_latent = model.encode(init_image) + forced_latent = model.encode(init_image.to(device)).latent_dist.sample() + print(forced_latent.shape) + os.environ["forcedlatent"] = str(list(forced_latent.flatten().cpu().detach().numpy())) + forcedlatents = [forced_latent for _ in range(llambda)] for iteration in range(30): #scrn.fill(black) @@ -171,7 +176,7 @@ def load_img(path): break text0 = font.render(to_native(f'Please wait !!! 
{k} / {llambda}'), True, green, blue) scrn.blit(text0, ((X*3/4)/2 - X/32, Y/2-Y/8)) - text0 = font.render(to_native(f'Or click on an image (then don''t move the mouse until click received!),'), True, green, blue) + text0 = font.render(to_native(f'Or (EMERGENCY STOP BECAUSE BORED!) click on an image (THEN DON''T MOVE the mouse until click received!),'), True, green, blue) scrn.blit(text0, ((X*3/4)/3 - X/32, Y/2)) text0 = font.render(to_native(f'for rerunning on a specific image.'), True, green, blue) scrn.blit(text0, ((X*3/4)/2 - X/32, Y/2+Y/8)) diff --git a/pipeline_stable_diffusion.py b/pipeline_stable_diffusion.py index 0bef76aec..481aa9f30 100644 --- a/pipeline_stable_diffusion.py +++ b/pipeline_stable_diffusion.py @@ -220,6 +220,11 @@ def __call__( generator=generator, device=latents_device, ) + if len(os.environ["forcedlatent"]) > 0: + print("we get a forcing for the latent z.") + latents = np.array(eval(os.environ["forcedlatent"])) + latents = np.sqrt(len(latents)) * latents / np.sum(latents ** 2) + latents = torch.from_numpy(np.array(eval(os.environ["forcedlatent"])).reshape((1,4,64,64))) good = eval(os.environ["good"]) bad = eval(os.environ["bad"]) print(f"{len(good)} good and {len(bad)} bad") @@ -315,7 +320,7 @@ def loss(x): os.environ["latent_sd"] = str(list(latents.flatten().cpu().numpy())) for i in [2, 3]: latents = torch.repeat_interleave(latents, repeats=latents_shape[i] // latents_intermediate_shape[i], dim=i) #/ np.sqrt(np.sqrt(latents_shape[i] // latents_intermediate_shape[i])) - latents = latents.to(self.device) + latents = latents.float().to(self.device) # set timesteps accepts_offset = "offset" in set(inspect.signature(self.scheduler.set_timesteps).parameters.keys()) From f19b1b3d2ae4da27acb60e65d77daf09e9aeecc6 Mon Sep 17 00:00:00 2001 From: Olivier Teytaud Date: Tue, 27 Sep 2022 21:05:49 +0200 Subject: [PATCH 47/76] fix --- pipeline_stable_diffusion.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pipeline_stable_diffusion.py b/pipeline_stable_diffusion.py index 481aa9f30..f62ea037e 100644 --- a/pipeline_stable_diffusion.py +++ b/pipeline_stable_diffusion.py @@ -223,7 +223,7 @@ def __call__( if len(os.environ["forcedlatent"]) > 0: print("we get a forcing for the latent z.") latents = np.array(eval(os.environ["forcedlatent"])) - latents = np.sqrt(len(latents)) * latents / np.sum(latents ** 2) + latents = np.sqrt(len(latents)) * latents / np.sqrt(np.sum(latents ** 2)) latents = torch.from_numpy(np.array(eval(os.environ["forcedlatent"])).reshape((1,4,64,64))) good = eval(os.environ["good"]) bad = eval(os.environ["bad"]) From 0356ceb17f4951b57a2c9452d4127e335afcadf6 Mon Sep 17 00:00:00 2001 From: Olivier Teytaud Date: Wed, 28 Sep 2022 10:11:44 +0200 Subject: [PATCH 48/76] this_version_ok --- README.md | 1 + minisd.py | 65 ++++++++++++++++++++++++++---------- pipeline_stable_diffusion.py | 7 ++-- 3 files changed, 53 insertions(+), 20 deletions(-) diff --git a/README.md b/README.md index 3cf23c988..d9fcdc7c3 100644 --- a/README.md +++ b/README.md @@ -19,6 +19,7 @@ pip install transformers diffusers invisible-watermark pip install pygame pip install einops pip install webbrowser +pip install pyfiglet pip install nevergrad pip install langdetect pip install deep-translator diff --git a/minisd.py b/minisd.py index 6a7fd471f..171fd00a6 100644 --- a/minisd.py +++ b/minisd.py @@ -27,7 +27,7 @@ os.environ["decay"] = "0." 
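The image-to-latent path added in these patches loads a picture, maps it to [-1, 1], pushes it through the pipeline's VAE encoder, and rescales the result so that its average squared norm matches a standard Gaussian draw, which is what the forced-latent code in the pipeline expects. A condensed sketch, assuming `pipe` is the loaded StableDiffusionPipeline and the model's 512x512 / 4x64x64 shapes apply (the extra scale factor applied in the patches is left out):

```python
import numpy as np
import torch
import PIL
from PIL import Image

def image_to_latent(pipe, path, device):
    image = Image.open(path).convert("RGB").resize((512, 512),
                                                   resample=PIL.Image.LANCZOS)
    x = torch.from_numpy(np.array(image).astype(np.float32) / 255.0)
    x = x.permute(2, 0, 1).unsqueeze(0) * 2.0 - 1.0             # (1, 3, 512, 512) in [-1, 1]
    with torch.no_grad():
        z = pipe.vae.encode(x.to(device)).latent_dist.sample()  # (1, 4, 64, 64)
    z = z.flatten().cpu().numpy()
    # Rescale so the forced latent looks like a unit Gaussian draw: E||z||^2 = dim.
    return np.sqrt(len(z) / np.sum(z ** 2)) * z
```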
os.environ["ngoptim"] = "DiscreteLenglerOnePlusOne" os.environ["forcedlatent"] = "" -os.environ["enforcedlatent"] = "" +#os.environ["enforcedlatent"] = "" os.environ["good"] = "[]" os.environ["bad"] = "[]" num_iterations = 50 @@ -88,9 +88,17 @@ prompt = "Un chat en sang et en armure joue de la batterie." prompt = "Cyberpunk photographic version of Judith beheading Holofernes." prompt = "Photo of a cyberpunk Mark Zuckerberg killing Cthulhu with a light saber." +prompt = "A ferocious cyborg bear." prompt = "Photo of Mark Zuckerberg killing Cthulhu with a light saber." print(f"The prompt is {prompt}") + +import pyfiglet +print(pyfiglet.figlet_format("Welcome in Genetic Stable Diffusion !")) +print(pyfiglet.figlet_format("First, let us choose the text :-)!")) + + + print(f"Francais: Proposez un nouveau texte si vous ne voulez pas dessiner << {prompt} >>.\n") user_prompt = input(f"English: Enter a new prompt if you prefer something else than << {prompt} >>.\n") if len(user_prompt) > 2: @@ -103,7 +111,10 @@ def to_native(stri): return GoogleTranslator(source='en', target=language).translate(stri) -print(f"Working on {english_prompt}, a.k.a {prompt}.") +def pretty_print(stri): + print(pyfiglet.figlet_format(to_native(stri))) + +print(f"{to_native('Working on')} {english_prompt}, a.k.a {prompt}.") import os @@ -125,6 +136,9 @@ def to_native(stri): images = [] onlyfiles = [] +pretty_print("Now let us choose (if you want) an image as a start.") +image_name = input(to_native("Name of image for starting ? (enter if no start image)")) + # activate the pygame library . pygame.init() X = 2000 # > 1500 = buttons @@ -132,11 +146,10 @@ def to_native(stri): scrn = pygame.display.set_mode((1700, Y + 100)) font = pygame.font.Font('freesansbold.ttf', 22) -image_name = input(to_native("Name of image for starting ? (enter if no start image)")) def load_img(path): image = Image.open(path).convert("RGB") w, h = image.size - print(f"loaded input image of size ({w}, {h}) from {path}") + print(to_native(f"loaded input image of size ({w}, {h}) from {path}")) w, h = map(lambda x: x - x % 32, (w, h)) # resize to integer multiple of 32 image = image.resize((512, 512), resample=PIL.Image.LANCZOS) #image = image.resize((w, h), resample=PIL.Image.LANCZOS) @@ -155,8 +168,22 @@ def load_img(path): #forced_latent = model.encode(init_image) forced_latent = model.encode(init_image.to(device)).latent_dist.sample() print(forced_latent.shape) - os.environ["forcedlatent"] = str(list(forced_latent.flatten().cpu().detach().numpy())) - forcedlatents = [forced_latent for _ in range(llambda)] + #os.environ["forcedlatent"] = str(list(forced_latent.flatten().cpu().detach().numpy())) + # Forced latent is for after os.environ["forcedlatent"]... + forcedlatents = [] + new_fl = forced_latent.cpu().detach().numpy().flatten() + basic_new_fl = np.sqrt(len(new_fl)) * new_fl / np.sqrt(np.sum(new_fl ** 2)) + for i in range(llambda): + #new_fl = forced_latent + (1. 
/ 1.1**(llambda-i)) * torch.from_numpy(np.random.randn(1*4*64*64).reshape(1,4,64,64)).float().to(device) + #forcedlatents += [new_fl.cpu().detach().numpy()] + if i > 0: + epsilon = 0.3 / 1.1**i + new_fl = epsilon * basic_new_fl + (1 - epsilon) * np.random.randn(1*4*64*64) + else: + new_fl = basic_new_fl + new_fl = np.sqrt(len(new_fl)) * new_fl / np.sqrt(np.sum(new_fl ** 2)) + forcedlatents += [new_fl] + #print(f"{i} --> {forcedlatents[i][:10]}") for iteration in range(30): #scrn.fill(black) @@ -165,6 +192,8 @@ def load_img(path): onlyfiles = [onlyfiles[f] for f in five_best] early_stop = [] for k in range(llambda): + if len(forcedlatents) > 0 and k < len(forcedlatents): + os.environ["forcedlatent"] = str(list(forcedlatents[k].flatten())) if k < len(five_best): imp = pygame.transform.scale(pygame.image.load(onlyfiles[k]).convert(), (300, 300)) shutil.copyfile(onlyfiles[k], to_native("Selected") + onlyfiles[k]) @@ -174,6 +203,8 @@ def load_img(path): continue if len(early_stop) > 0: break + pygame.draw.rect(scrn, black, pygame.Rect(0, Y, 1700, Y+100)) + pygame.draw.rect(scrn, black, pygame.Rect(X, 0, 1700, Y+100)) text0 = font.render(to_native(f'Please wait !!! {k} / {llambda}'), True, green, blue) scrn.blit(text0, ((X*3/4)/2 - X/32, Y/2-Y/8)) text0 = font.render(to_native(f'Or (EMERGENCY STOP BECAUSE BORED!) click on an image (THEN DON''T MOVE the mouse until click received!),'), True, green, blue) @@ -185,11 +216,10 @@ def load_img(path): os.environ["epsilon"] = str(0. if k == len(five_best) else (k - len(five_best)) / llambda) os.environ["budget"] = str(300 if k > len(five_best) else 2) os.environ["skl"] = {0: "nn", 1: "tree", 2: "logit"}[k % 3] - if iteration > 0 and len(forcedlatents) > 0: - os.environ["forcedlatent"] = str(list(forcedlatents[k].flatten())) - enforcedlatent = os.environ.get("enforcedlatent", "") - if len(enforcedlatent) > 2: - os.environ["forcedlatent"] = enforcedlatent + #enforcedlatent = os.environ.get("enforcedlatent", "") + #if len(enforcedlatent) > 2: + # os.environ["forcedlatent"] = enforcedlatent + # os.environ["enforcedlatent"] = "" with autocast("cuda"): image = pipe(english_prompt, guidance_scale=gs, num_inference_steps=num_iterations)["sample"][0] images += [image] @@ -212,12 +242,12 @@ def load_img(path): pos = pygame.mouse.get_pos() index = 3 * (pos[0] // 300) + (pos[1] // 300) if index <= k: - print(to_native("You clicked for requesting an early stopping.")) + pretty_print(("You clicked for requesting an early stopping.")) early_stop = [pos] break # Stop the forcing from disk! - os.environ["enforcedlatent"] = "" + #os.environ["enforcedlatent"] = "" # importing required library #mypath = "./" @@ -283,8 +313,9 @@ def load_img(path): for i in early_stop + pygame.event.get(): if hasattr(i, "type") and i.type == pygame.MOUSEBUTTONUP or len(early_stop) > 0: pos = early_stop[0] if len(early_stop) > 0 else pygame.mouse.get_pos() - print(f"Click at {pos}") + pretty_print(f"Detected! Click at {pos}") if pos[1] > Y: + pretty_print("Let us update parameters!") text4 = font.render(to_native(f"ok, go to text window!"), True, green, blue) scrn.blit(text4, (300, Y + 30)) pygame.display.flip() @@ -295,6 +326,7 @@ def load_img(path): prompt = new_prompt language = detect(prompt) english_prompt = GoogleTranslator(source='auto', target='en').translate(prompt) + pretty_print("Ok! Parameters updated.") text4 = font.render(to_native(f"Ok! 
parameters changed!"), True, green, blue) scrn.blit(text4, (300, Y + 30)) pygame.display.flip() @@ -337,15 +369,12 @@ def load_img(path): if len(early_stop) > 0 or i.type == pygame.QUIT: status = False - # Using draw.rect module of - # pygame to draw the solid circle + # Covering old images with full circles. for _ in range(123): x = np.random.randint(1500) y = np.random.randint(900) pygame.draw.circle(scrn, (0, 255, 0), [x, y], 17, 0) - - # Draws the surface object to the screen. pygame.display.update() if len(indices) == 0: print("The user did not like anything! Rerun :-(") diff --git a/pipeline_stable_diffusion.py b/pipeline_stable_diffusion.py index f62ea037e..6891c599c 100644 --- a/pipeline_stable_diffusion.py +++ b/pipeline_stable_diffusion.py @@ -222,9 +222,12 @@ def __call__( ) if len(os.environ["forcedlatent"]) > 0: print("we get a forcing for the latent z.") - latents = np.array(eval(os.environ["forcedlatent"])) + latents = np.array(eval(os.environ["forcedlatent"])).flatten() latents = np.sqrt(len(latents)) * latents / np.sqrt(np.sum(latents ** 2)) - latents = torch.from_numpy(np.array(eval(os.environ["forcedlatent"])).reshape((1,4,64,64))) + print(latents[:10]) + print(f"immediately after loading latent ==> {sum(latents.flatten()**2) / len(latents.flatten())}") + latents = torch.from_numpy(latents.reshape((1,4,64,64))).float().to(latents_device) + os.environ["forcedlatent"] = "" good = eval(os.environ["good"]) bad = eval(os.environ["bad"]) print(f"{len(good)} good and {len(bad)} bad") From 55a08311ca9b6116fdce9a5cbb58ff0d125eca03 Mon Sep 17 00:00:00 2001 From: Olivier Teytaud Date: Wed, 28 Sep 2022 10:16:07 +0200 Subject: [PATCH 49/76] fix --- README.md | 11 +++++++++-- 1 file changed, 9 insertions(+), 2 deletions(-) diff --git a/README.md b/README.md index d9fcdc7c3..a35ea5c4b 100644 --- a/README.md +++ b/README.md @@ -7,7 +7,13 @@ It should work directly on Mac M1. It should be easy to adap to a machine with GPU. Without GPU it will be more complicated. Ping us at the Nevergrad user group if you need help, I'll do my best. + [**Nevergrad Users**](https://www.facebook.com/groups/nevergradusers/)
+[**Doc**](https://docs.google.com/document/d/12Bz095QNuo_ojxSlGENXKL5Law75IUx5_Nm5L5guKgo/edit?usp=sharing)
+ + + + ## Install StableDiffusion as usual, plus a few more stuff. Basically: @@ -34,8 +40,9 @@ How to do this ? ``` python -c "import diffusers ; print(diffusers.__file__)" ``` -and pipeline_stable_diffusion should at this location + "/pipelines/stable_diffusion/pipeline_stable_diffusion.py" -or inside python +and pipeline_stable_diffusion should be copied at this location + "/pipelines/stable_diffusion/pipeline_stable_diffusion.py" (overwrite that file). + +Or inside python ``` import diffusers print(diffusers.__file__) From c60729dc4b0cfa6456a0dcdec678e0ae6e94385e Mon Sep 17 00:00:00 2001 From: Olivier Teytaud Date: Wed, 28 Sep 2022 15:01:31 +0200 Subject: [PATCH 50/76] fix --- minisd.py | 18 ++++++++++++++---- pipeline_stable_diffusion.py | 7 +++++-- 2 files changed, 19 insertions(+), 6 deletions(-) diff --git a/minisd.py b/minisd.py index 171fd00a6..97a62574b 100644 --- a/minisd.py +++ b/minisd.py @@ -86,10 +86,14 @@ prompt = "Yann LeCun as the grim reaper: bizarre art." prompt = "A star with flashy colors." prompt = "Un chat en sang et en armure joue de la batterie." -prompt = "Cyberpunk photographic version of Judith beheading Holofernes." prompt = "Photo of a cyberpunk Mark Zuckerberg killing Cthulhu with a light saber." prompt = "A ferocious cyborg bear." prompt = "Photo of Mark Zuckerberg killing Cthulhu with a light saber." +prompt = "A bear with horns and blood and big teeth." +prompt = "A photo of a bear and Yoda, good friends." +prompt = "A photo of Yoda on the left, a blue octopus on the right, an explosion in the center." +prompt = "A bird is on a hippo. They fight a black and red octopus. Jungle in the background." +prompt = "A flying white bird behind 4 colored pots with fire." print(f"The prompt is {prompt}") @@ -204,10 +208,12 @@ def load_img(path): if len(early_stop) > 0: break pygame.draw.rect(scrn, black, pygame.Rect(0, Y, 1700, Y+100)) - pygame.draw.rect(scrn, black, pygame.Rect(X, 0, 1700, Y+100)) + pygame.draw.rect(scrn, black, pygame.Rect(1500, 0, 2000, Y+100)) text0 = font.render(to_native(f'Please wait !!! {k} / {llambda}'), True, green, blue) - scrn.blit(text0, ((X*3/4)/2 - X/32, Y/2-Y/8)) - text0 = font.render(to_native(f'Or (EMERGENCY STOP BECAUSE BORED!) click on an image (THEN DON''T MOVE the mouse until click received!),'), True, green, blue) + scrn.blit(text0, ((X*3/4)/2 - X/32, Y/2-Y/4)) + text0 = font.render(to_native(f'Or, if you find one image very cool and want to focus on it only,'), True, green, blue) + scrn.blit(text0, ((X*3/4)/3 - X/32, Y/2-Y/8)) + text0 = font.render(to_native(f'then click on it AND KEEP THE MOUSE AT THE SAME POINT until I get the click.),'), True, green, blue) scrn.blit(text0, ((X*3/4)/3 - X/32, Y/2)) text0 = font.render(to_native(f'for rerunning on a specific image.'), True, green, blue) scrn.blit(text0, ((X*3/4)/2 - X/32, Y/2+Y/8)) @@ -321,6 +327,7 @@ def load_img(path): pygame.display.flip() num_iterations = int(input(to_native(f"Number of iterations ? (current = {num_iterations})\n"))) gs = float(input(to_native(f"Guidance scale ? 
(current = {gs})\n"))) + print(f"The current text is << {prompt} >>.") new_prompt = str(input(to_native(f"Enter a text if you want to change from ") + prompt)) if len(new_prompt) > 2: prompt = new_prompt @@ -338,6 +345,7 @@ def load_img(path): latent = f.read() break if pos[1] < 2*Y/3: + pretty_print("Let us create a meme!") url = 'https://imgflip.com/memegenerator' onlyfiles = [f for f in listdir(".") if isfile(join(mypath, f))] onlyfiles = [str(f) for f in onlyfiles if "SD_" in str(f) and ".png" in str(f) and str(f) not in all_files and sentinel in str(f)] @@ -352,6 +360,8 @@ def load_img(path): five_best += [index] indices += [[index, (pos[0] - (pos[0] // 300) * 300) / 300, (pos[1] - (pos[1] // 300) * 300) / 300]] # Update the button for new generation. + pygame.draw.rect(scrn, black, pygame.Rect(X*3/4, 2*Y/3, X*3/4+X/16+X/32, Y)) + pygame.draw.rect(scrn, red, pygame.Rect(X*3/4, 2*Y/3, X*3/4+X/16+X/32, Y), 2) text3 = font.render(to_native(f" I have chosen {len(indices)} images:"), True, green, blue) text3 = pygame.transform.rotate(text3, 90) scrn.blit(text3, (X*3/4+X/16 - X/32, Y*2/3)) diff --git a/pipeline_stable_diffusion.py b/pipeline_stable_diffusion.py index 6891c599c..d8166748f 100644 --- a/pipeline_stable_diffusion.py +++ b/pipeline_stable_diffusion.py @@ -214,6 +214,7 @@ def __call__( latents_device = "cpu" if self.device.type == "mps" else self.device latents_shape = (batch_size, self.unet.in_channels, height // 8, width // 8) latents_intermediate_shape = (batch_size, self.unet.in_channels, height // 8, width // 8) + speedup = 1 if latents is None: latents = torch.randn( latents_intermediate_shape, @@ -222,6 +223,7 @@ def __call__( ) if len(os.environ["forcedlatent"]) > 0: print("we get a forcing for the latent z.") + speedup = 1 latents = np.array(eval(os.environ["forcedlatent"])).flatten() latents = np.sqrt(len(latents)) * latents / np.sqrt(np.sum(latents ** 2)) print(latents[:10]) @@ -232,7 +234,7 @@ def __call__( bad = eval(os.environ["bad"]) print(f"{len(good)} good and {len(bad)} bad") i_believe_in_evolution = len(good) > 0 and len(bad) > 0 - #i_believe_in_evolution = False + i_believe_in_evolution = False print(f"I believe in evolution = {i_believe_in_evolution}") if i_believe_in_evolution: from sklearn import tree @@ -331,7 +333,7 @@ def loss(x): if accepts_offset: extra_set_kwargs["offset"] = 1 - self.scheduler.set_timesteps(num_inference_steps, **extra_set_kwargs) + self.scheduler.set_timesteps(num_inference_steps // speedup, **extra_set_kwargs) # if we use LMSDiscreteScheduler, let's make sure latents are mulitplied by sigmas if isinstance(self.scheduler, LMSDiscreteScheduler): @@ -369,6 +371,7 @@ def loss(x): latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample # scale and decode the image latents with vae + os.environ["latent_sd"] = str(list(latents.flatten().cpu().detach().numpy())) latents = 1 / 0.18215 * latents image = self.vae.decode(latents).sample From cd71c84b4447aca27b0e5ee89dc65c8ad93702fe Mon Sep 17 00:00:00 2001 From: Olivier Teytaud Date: Wed, 28 Sep 2022 17:20:29 +0200 Subject: [PATCH 51/76] fix --- minisd.py | 38 +++++++++++++++++++++++------------- pipeline_stable_diffusion.py | 5 ++--- 2 files changed, 26 insertions(+), 17 deletions(-) diff --git a/minisd.py b/minisd.py index 97a62574b..878aa3407 100644 --- a/minisd.py +++ b/minisd.py @@ -93,7 +93,7 @@ prompt = "A photo of a bear and Yoda, good friends." prompt = "A photo of Yoda on the left, a blue octopus on the right, an explosion in the center." 
prompt = "A bird is on a hippo. They fight a black and red octopus. Jungle in the background." -prompt = "A flying white bird behind 4 colored pots with fire." +prompt = "A flying white owl above 4 colored pots with fire. The owl has a hat." print(f"The prompt is {prompt}") @@ -271,12 +271,12 @@ def load_img(path): pygame.draw.rect(scrn, red, pygame.Rect(0, Y, X/2, Y+100), 2) # Button for loading a starting point - text1 = font.render('Load image ', True, green, blue) + text1 = font.render('Manually edit an image.', True, green, blue) text1 = pygame.transform.rotate(text1, 90) - scrn.blit(text1, (X*3/4+X/16 - X/32, 0)) - text1 = font.render('& latent ', True, green, blue) - text1 = pygame.transform.rotate(text1, 90) - scrn.blit(text1, (X*3/4+X/16+X/32 - X/32, 0)) + #scrn.blit(text1, (X*3/4+X/16 - X/32, 0)) + #text1 = font.render('& latent ', True, green, blue) + #text1 = pygame.transform.rotate(text1, 90) + #scrn.blit(text1, (X*3/4+X/16+X/32 - X/32, 0)) # Button for creating a meme text2 = font.render(to_native('Create'), True, green, blue) @@ -325,12 +325,19 @@ def load_img(path): text4 = font.render(to_native(f"ok, go to text window!"), True, green, blue) scrn.blit(text4, (300, Y + 30)) pygame.display.flip() - num_iterations = int(input(to_native(f"Number of iterations ? (current = {num_iterations})\n"))) + try: + num_iterations = int(input(to_native(f"Number of iterations ? (current = {num_iterations})\n"))) + except: + num_iterations = int(input(to_native(f"Number of iterations ? (current = {num_iterations})\n"))) gs = float(input(to_native(f"Guidance scale ? (current = {gs})\n"))) - print(f"The current text is << {prompt} >>.") + print(to_native(f"The current text is << {prompt} >>.")) + print(to_native("Start your answer with a symbol << + >> if this is an edit and not a new text.")) new_prompt = str(input(to_native(f"Enter a text if you want to change from ") + prompt)) if len(new_prompt) > 2: - prompt = new_prompt + if new_prompt[0] == "+": + prompt += new_prompt[1:] + else: + prompt = new_prompt language = detect(prompt) english_prompt = GoogleTranslator(source='auto', target='en').translate(prompt) pretty_print("Ok! Parameters updated.") @@ -339,11 +346,14 @@ def load_img(path): pygame.display.flip() elif pos[0] > 1500: # Not in the images. if pos[1] < Y/3: - filename = input(to_native("Filename (please provide the latent file, of the format SD*latent*.txt) ?\n")) - status = False - with open(filename, 'r') as f: - latent = f.read() - break + #filename = input(to_native("Filename (please provide the latent file, of the format SD*latent*.txt) ?\n")) + #status = False + #with open(filename, 'r') as f: + # latent = f.read() + #break + pretty_print("Easy! 
I exit now, you edit the file and you save it.") + pretty_print("Then just relaunch me and provide the text and the image.") + exit() if pos[1] < 2*Y/3: pretty_print("Let us create a meme!") url = 'https://imgflip.com/memegenerator' diff --git a/pipeline_stable_diffusion.py b/pipeline_stable_diffusion.py index d8166748f..844401ba2 100644 --- a/pipeline_stable_diffusion.py +++ b/pipeline_stable_diffusion.py @@ -233,8 +233,7 @@ def __call__( good = eval(os.environ["good"]) bad = eval(os.environ["bad"]) print(f"{len(good)} good and {len(bad)} bad") - i_believe_in_evolution = len(good) > 0 and len(bad) > 0 - i_believe_in_evolution = False + i_believe_in_evolution = len(good) > 0 and len(bad) > 200 print(f"I believe in evolution = {i_believe_in_evolution}") if i_believe_in_evolution: from sklearn import tree @@ -371,7 +370,7 @@ def loss(x): latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample # scale and decode the image latents with vae - os.environ["latent_sd"] = str(list(latents.flatten().cpu().detach().numpy())) + #os.environ["latent_sd"] = str(list(latents.flatten().cpu().detach().numpy())) latents = 1 / 0.18215 * latents image = self.vae.decode(latents).sample From b84f706fac151280b7f927369c6aefd915a35e88 Mon Sep 17 00:00:00 2001 From: Olivier Teytaud Date: Thu, 29 Sep 2022 16:40:44 +0200 Subject: [PATCH 52/76] fix --- README.md | 5 +- minisd.py | 110 +++++++++++++++++++++++++++-------- pipeline_stable_diffusion.py | 10 ++-- 3 files changed, 95 insertions(+), 30 deletions(-) diff --git a/README.md b/README.md index a35ea5c4b..ff557b945 100644 --- a/README.md +++ b/README.md @@ -18,10 +18,12 @@ Ping us at the Nevergrad user group if you need help, I'll do my best. ## Install StableDiffusion as usual, plus a few more stuff. Basically: ``` +brew install wget conda env create -f environment.yaml conda activate ldm # you can change that name in the environment.yaml file... conda install pytorch torchvision -c pytorch pip install transformers diffusers invisible-watermark +pip install -e . pip install pygame pip install einops pip install webbrowser @@ -29,7 +31,8 @@ pip install pyfiglet pip install nevergrad pip install langdetect pip install deep-translator -pip install -e . +pip install git+https://github.com/sberbank-ai/Real-ESRGAN.git +wget https://github.com/xinntao/Real-ESRGAN/releases/download/v0.1.0/RealESRGAN_x4plus.pth -P weights ``` ## Hack diffusers (yes I should do that differently... only solution for now). diff --git a/minisd.py b/minisd.py index 878aa3407..64431b16c 100644 --- a/minisd.py +++ b/minisd.py @@ -18,6 +18,7 @@ white = (255, 255, 255) green = (0, 255, 0) +darkgreen = (0, 128, 0) red = (255, 0, 0) blue = (0, 0, 128) black = (0, 0, 0) @@ -33,8 +34,17 @@ num_iterations = 50 gs = 7.5 + + + + + + +all_selected = [] forcedlatents = [] + + pipe = StableDiffusionPipeline.from_pretrained(model_id, use_auth_token="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") pipe = pipe.to(device) @@ -94,6 +104,8 @@ prompt = "A photo of Yoda on the left, a blue octopus on the right, an explosion in the center." prompt = "A bird is on a hippo. They fight a black and red octopus. Jungle in the background." prompt = "A flying white owl above 4 colored pots with fire. The owl has a hat." +prompt = "A flying white owl above 4 colored pots with fire." +prompt = "An armored Mark Zuckerberg fighting off a monster with bloody tentacles in the jungle with a light saber." 
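The `skl`, `ngoptim` and `budget` environment variables configure the surrogate step inside the patched pipeline: the liked and disliked latents become a binary training set, a scikit-learn classifier scores candidate latents, and a nevergrad optimizer spends the budget looking for a point the classifier likes. A self-contained sketch of that loop on a toy dimension (the 4*64*64 latent is shrunk to 64 here; the classifier and optimizer names are the defaults shown in the diffs, and the real pipeline restricts the search to a neighbourhood of the current latent via the `epsilon` variable, which is omitted):

```python
import numpy as np
import nevergrad as ng
from sklearn.neural_network import MLPClassifier

dim = 64                                   # stand-in for 4 * 64 * 64
good = np.random.randn(5, dim)             # latents the user selected
bad = np.random.randn(40, dim)             # latents the user skipped

clf = MLPClassifier(max_iter=500)
clf.fit(np.vstack([good, bad]), [1] * len(good) + [0] * len(bad))

def loss(x):
    # Probability of the "bad" class: lower is better.
    return clf.predict_proba([x])[0][0]

optim_class = ng.optimizers.registry["DiscreteLenglerOnePlusOne"]
optimizer = optim_class(parametrization=ng.p.Array(shape=(dim,)), budget=300)
recommendation = optimizer.minimize(loss)
candidate_latent = recommendation.value    # would be reshaped to (1, 4, 64, 64)
```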
print(f"The prompt is {prompt}") @@ -120,6 +132,26 @@ def pretty_print(stri): print(f"{to_native('Working on')} {english_prompt}, a.k.a {prompt}.") +def eg(list_of_files): + pretty_print("Should I convert images below to high resolution ?") + print(list_of_files) + answer = input(" [y]es / [n]o ?") + if "y" in answer or "Y" in answer: + model = RealESRGAN(device, scale=4) + model.load_weights('weights/RealESRGAN_x4.pth', download=True) + for f in list_of_files: + import torch + from PIL import Image + import numpy as np + from RealESRGAN import RealESRGAN + + #device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') + path_to_image = f + image = Image.open(path_to_image).convert('RGB') + sr_image = model.predict(image) + output_filename = "SR" + f + sr_image.save(output_filename) + print(to_native(f"Created the super-resolution file {output_filename}")) import os import pygame @@ -149,6 +181,7 @@ def pretty_print(stri): Y = 900 scrn = pygame.display.set_mode((1700, Y + 100)) font = pygame.font.Font('freesansbold.ttf', 22) +bigfont = pygame.font.Font('freesansbold.ttf', 44) def load_img(path): image = Image.open(path).convert("RGB") @@ -163,30 +196,42 @@ def load_img(path): return 2.*image - 1. if len(image_name) > 0: + pretty_print("Importing an image !") import torchvision #forced_latent = pipe.get_latent(torchvision.io.read_image(image_name).float()) model = pipe.vae - init_image = load_img(image_name).to(device) - init_image = repeat(init_image, '1 ... -> b ...', b=1) - #forced_latent = model.get_first_stage_encoding(model.encode_first_stage(init_image)) - #forced_latent = model.encode(init_image) - forced_latent = model.encode(init_image.to(device)).latent_dist.sample() - print(forced_latent.shape) - #os.environ["forcedlatent"] = str(list(forced_latent.flatten().cpu().detach().numpy())) - # Forced latent is for after os.environ["forcedlatent"]... + try: + init_image = load_img(image_name).to(device) + except: + pretty_print("Try again!") + image_name = input(to_native("Name of image for starting ? (enter if no start image)")) + pretty_print("Loading failed!!") + + base_init_image = load_img(image_name).to(device) forcedlatents = [] - new_fl = forced_latent.cpu().detach().numpy().flatten() - basic_new_fl = np.sqrt(len(new_fl)) * new_fl / np.sqrt(np.sum(new_fl ** 2)) for i in range(llambda): + c = np.exp(np.random.randn() - 2) + init_image_shape = base_init_image.cpu().numpy().shape + if i > 0: + init_image = base_init_image + torch.from_numpy(c * np.random.randn(np.prod(init_image_shape))).reshape(init_image_shape).float().to(device) + else: + init_image = base_init_image + init_image = repeat(init_image, '1 ... -> b ...', b=1) + forced_latent = 6. * model.encode(init_image.to(device)).latent_dist.sample() + new_fl = forced_latent.cpu().detach().numpy().flatten() + basic_new_fl = new_fl #np.sqrt(len(new_fl) / sum(new_fl ** 2)) * new_fl #new_fl = forced_latent + (1. 
/ 1.1**(llambda-i)) * torch.from_numpy(np.random.randn(1*4*64*64).reshape(1,4,64,64)).float().to(device) #forcedlatents += [new_fl.cpu().detach().numpy()] if i > 0: - epsilon = 0.3 / 1.1**i + #epsilon = 0.3 / 1.1**i + basic_new_fl = np.sqrt(len(new_fl) / np.sum(new_fl**2)) * basic_new_fl + epsilon = 1.0 / 2**(2 + i / 6) new_fl = epsilon * basic_new_fl + (1 - epsilon) * np.random.randn(1*4*64*64) else: new_fl = basic_new_fl - new_fl = np.sqrt(len(new_fl)) * new_fl / np.sqrt(np.sum(new_fl ** 2)) - forcedlatents += [new_fl] + #new_fl = np.sqrt(len(new_fl)) * new_fl / np.sqrt(np.sum(new_fl ** 2)) + #forcedlatents += [new_fl] #np.sqrt(len(new_fl) / sum(new_fl ** 2)) * new_fl] + forcedlatents += [np.sqrt(len(new_fl) / sum(new_fl ** 2)) * new_fl] #print(f"{i} --> {forcedlatents[i][:10]}") for iteration in range(30): @@ -200,7 +245,9 @@ def load_img(path): os.environ["forcedlatent"] = str(list(forcedlatents[k].flatten())) if k < len(five_best): imp = pygame.transform.scale(pygame.image.load(onlyfiles[k]).convert(), (300, 300)) - shutil.copyfile(onlyfiles[k], to_native("Selected") + onlyfiles[k]) + selected_filename = to_native("Selected") + onlyfiles[k] + shutil.copyfile(onlyfiles[k], selected_filename) + all_selected += [selected_filename] # Using blit to copy content from one surface to other scrn.blit(imp, (300 * (k // 3), 300 * (k % 3))) pygame.display.flip() @@ -209,18 +256,18 @@ def load_img(path): break pygame.draw.rect(scrn, black, pygame.Rect(0, Y, 1700, Y+100)) pygame.draw.rect(scrn, black, pygame.Rect(1500, 0, 2000, Y+100)) - text0 = font.render(to_native(f'Please wait !!! {k} / {llambda}'), True, green, blue) + text0 = bigfont.render(to_native(f'Please wait !!! {k} / {llambda}'), True, green, blue) scrn.blit(text0, ((X*3/4)/2 - X/32, Y/2-Y/4)) text0 = font.render(to_native(f'Or, if you find one image very cool and want to focus on it only,'), True, green, blue) scrn.blit(text0, ((X*3/4)/3 - X/32, Y/2-Y/8)) - text0 = font.render(to_native(f'then click on it AND KEEP THE MOUSE AT THE SAME POINT until I get the click.),'), True, green, blue) + text0 = font.render(to_native(f'then click on it AND KEEP THE MOUSE AT THE SAME POINT until I get the click.'), True, green, blue) scrn.blit(text0, ((X*3/4)/3 - X/32, Y/2)) - text0 = font.render(to_native(f'for rerunning on a specific image.'), True, green, blue) + text0 = font.render(to_native(f'Then I''ll work on variants of that specific image.'), True, green, blue) scrn.blit(text0, ((X*3/4)/2 - X/32, Y/2+Y/8)) pygame.display.flip() os.environ["earlystop"] = "False" if k > len(five_best) else "True" os.environ["epsilon"] = str(0. if k == len(five_best) else (k - len(five_best)) / llambda) - os.environ["budget"] = str(300 if k > len(five_best) else 2) + os.environ["budget"] = str(np.random.randint(400) if k > len(five_best) else 2) os.environ["skl"] = {0: "nn", 1: "tree", 2: "logit"}[k % 3] #enforcedlatent = os.environ.get("enforcedlatent", "") #if len(enforcedlatent) > 2: @@ -279,7 +326,7 @@ def load_img(path): #scrn.blit(text1, (X*3/4+X/16+X/32 - X/32, 0)) # Button for creating a meme - text2 = font.render(to_native('Create'), True, green, blue) + text2 = font.render(to_native('Stop / High-Resolution / Create '), True, green, blue) text2 = pygame.transform.rotate(text2, 90) scrn.blit(text2, (X*3/4+X/16 - X/32, Y/3)) text2 = font.render(to_native('a meme'), True, green, blue) @@ -312,6 +359,12 @@ def load_img(path): for i in pygame.event.get(): if i.type == pygame.MOUSEBUTTONUP: print(to_native(".... 
too early for clicking !!!!")) + + + pretty_print("Please click on your favorite elements!") + print(to_native("You might just click on one image and we will provide variations.")) + print(to_native("Or you can click on the top of an image and the bottom of another one.")) + print(to_native("Click on the << new generation >> when you're done.")) while (status): # iterate over the list of Event objects @@ -355,13 +408,17 @@ def load_img(path): pretty_print("Then just relaunch me and provide the text and the image.") exit() if pos[1] < 2*Y/3: - pretty_print("Let us create a meme!") - url = 'https://imgflip.com/memegenerator' onlyfiles = [f for f in listdir(".") if isfile(join(mypath, f))] onlyfiles = [str(f) for f in onlyfiles if "SD_" in str(f) and ".png" in str(f) and str(f) not in all_files and sentinel in str(f)] print(to_native("Your generated images:")) print(onlyfiles) - webbrowser.open(url) + eg(all_selected + onlyfiles) + pretty_print("Should we create a meme ?") + answer = input(" [y]es or [n]o ?") + if "y" in answer or "Y" in answer: + url = 'https://imgflip.com/memegenerator' + webbrowser.open(url) + pretty_print("Good bye!") exit() status = False break @@ -393,7 +450,7 @@ def load_img(path): for _ in range(123): x = np.random.randint(1500) y = np.random.randint(900) - pygame.draw.circle(scrn, (0, 255, 0), + pygame.draw.circle(scrn, darkgreen, [x, y], 17, 0) pygame.display.update() if len(indices) == 0: @@ -407,8 +464,8 @@ def load_img(path): for u in [u for u in range(len(onlyfiles)) if u not in [i[0] for i in indices]]: sauron += latent[u] sauron = (1 / len([u for u in range(len(onlyfiles)) if u not in [i[0] for i in indices]])) * sauron - if len(bad) > 200: - bad = bad[(len(bad) - 200):] + if len(bad) > 300: + bad = bad[(len(bad) - 300):] for a in range(llambda): forcedlatent = np.zeros((4, 64, 64)) os.environ["good"] = str(good) @@ -437,6 +494,9 @@ def load_img(path): forcedlatent[k][i][j] = latent[uu][k][i][j] if a % 2 == 0: forcedlatent -= np.random.rand() * sauron + basic_new_fl = np.sqrt(len(forcedlatent) / np.sum(forcedlatent**2)) * forcedlatent + epsilon = 1.0 / 2**(2 + a / 6) + forcedlatent = epsilon * basic_new_fl + (1 - epsilon) * np.random.randn(1*4*64*64) forcedlatents += [forcedlatent] #for uu in range(len(latent)): # print(f"--> latent[{uu}] sum of sq / variable = {np.sum(latent[uu].flatten()**2) / len(latent[uu].flatten())}") diff --git a/pipeline_stable_diffusion.py b/pipeline_stable_diffusion.py index 844401ba2..b4caa2a5a 100644 --- a/pipeline_stable_diffusion.py +++ b/pipeline_stable_diffusion.py @@ -225,7 +225,9 @@ def __call__( print("we get a forcing for the latent z.") speedup = 1 latents = np.array(eval(os.environ["forcedlatent"])).flatten() - latents = np.sqrt(len(latents)) * latents / np.sqrt(np.sum(latents ** 2)) + #latents = latents + np.exp(0.1 * np.random.randn()) * np.random.rand(len(latents)) + #latents = np.sqrt(len(latents) / np.sum(latents ** 2)) * latents + #latents = np.sqrt(len(latents)) * latents / np.sqrt(np.sum(latents ** 2)) print(latents[:10]) print(f"immediately after loading latent ==> {sum(latents.flatten()**2) / len(latents.flatten())}") latents = torch.from_numpy(latents.reshape((1,4,64,64))).float().to(latents_device) @@ -233,7 +235,7 @@ def __call__( good = eval(os.environ["good"]) bad = eval(os.environ["bad"]) print(f"{len(good)} good and {len(bad)} bad") - i_believe_in_evolution = len(good) > 0 and len(bad) > 200 + i_believe_in_evolution = len(good) > 0 and len(bad) > 10 print(f"I believe in evolution = 
{i_believe_in_evolution}") if i_believe_in_evolution: from sklearn import tree @@ -264,9 +266,9 @@ def loss(x): #return clf.predict_proba([z+epsilon*x])[0][0] - if i_believe_in_evolution: + budget = int(os.environ.get("budget", "300")) + if i_believe_in_evolution and budget > 20: import nevergrad as ng - budget = int(os.environ.get("budget", "300")) #nevergrad_optimizer = ng.optimizers.RandomSearch(len(z), budget) #nevergrad_optimizer = ng.optimizers.RandomSearch(len(z), budget) optim_class = ng.optimizers.registry[os.environ.get("ngoptim", "DiscreteLenglerOnePlusOne")] From eb779290b658df0b3844dfcc52a12f576fb3d289 Mon Sep 17 00:00:00 2001 From: Olivier Teytaud Date: Thu, 29 Sep 2022 20:38:48 +0200 Subject: [PATCH 53/76] fix --- README.md | 1 + minisd.py | 146 ++++++++++++++++++++++++++--------- pipeline_stable_diffusion.py | 2 + 3 files changed, 114 insertions(+), 35 deletions(-) diff --git a/README.md b/README.md index ff557b945..ab5bcb4f0 100644 --- a/README.md +++ b/README.md @@ -25,6 +25,7 @@ conda install pytorch torchvision -c pytorch pip install transformers diffusers invisible-watermark pip install -e . pip install pygame +pip install pyttsx3 pip install einops pip install webbrowser pip install pyfiglet diff --git a/minisd.py b/minisd.py index 64431b16c..8ca8d677e 100644 --- a/minisd.py +++ b/minisd.py @@ -36,8 +36,16 @@ +import pyttsx3 +noise = pyttsx3.init() +noise.setProperty("rate", 178) +noise.setProperty('voice', 'mb-us1') +#voice = noise.getProperty('voices') +#for v in voice: +# if v.name == "Kyoko": +# noise.setProperty('voice', v.id) all_selected = [] @@ -116,6 +124,8 @@ print(f"Francais: Proposez un nouveau texte si vous ne voulez pas dessiner << {prompt} >>.\n") +noise.say("Hey!") +noise.runAndWait() user_prompt = input(f"English: Enter a new prompt if you prefer something else than << {prompt} >>.\n") if len(user_prompt) > 2: prompt = user_prompt @@ -135,6 +145,8 @@ def pretty_print(stri): def eg(list_of_files): pretty_print("Should I convert images below to high resolution ?") print(list_of_files) + noise.say("Hey!") + noise.runAndWait() answer = input(" [y]es / [n]o ?") if "y" in answer or "Y" in answer: model = RealESRGAN(device, scale=4) @@ -153,6 +165,19 @@ def eg(list_of_files): sr_image.save(output_filename) print(to_native(f"Created the super-resolution file {output_filename}")) +def stop_all(list_of_files): + print(to_native("Your selected images and the last generation:")) + print(list_of_files) + eg(all_selected + onlyfiles) + pretty_print("Should we create a meme ?") + answer = input(" [y]es or [n]o ?") + if "y" in answer or "Y" in answer: + url = 'https://imgflip.com/memegenerator' + webbrowser.open(url) + pretty_print("Good bye!") + exit() + + import os import pygame from os import listdir @@ -208,14 +233,37 @@ def load_img(path): pretty_print("Loading failed!!") base_init_image = load_img(image_name).to(device) + noise.say("Image loaded") + noise.runAndWait() + print(base_init_image.shape) + print(np.max(base_init_image.cpu().detach().numpy().flatten())) + print(np.min(base_init_image.cpu().detach().numpy().flatten())) forcedlatents = [] + divider = 1.5 for i in range(llambda): + new_base_init_image = base_init_image + if (i % 7) == 1: + new_base_init_image[0,0,:,:] /= divider + if (i % 7) == 2: + new_base_init_image[0,1,:,:] /= divider + if (i % 7) == 3: + new_base_init_image[0,2,:,:] /= divider + if (i % 7) == 4: + new_base_init_image[0,0,:,:] /= divider + new_base_init_image[0,1,:,:] /= divider + if (i % 7) == 5: + new_base_init_image[0,1,:,:] /= 
divider + new_base_init_image[0,2,:,:] /= divider + if (i % 7) == 6: + new_base_init_image[0,0,:,:] /= divider + new_base_init_image[0,2,:,:] /= divider + c = np.exp(np.random.randn() - 2) init_image_shape = base_init_image.cpu().numpy().shape if i > 0: - init_image = base_init_image + torch.from_numpy(c * np.random.randn(np.prod(init_image_shape))).reshape(init_image_shape).float().to(device) + init_image = new_base_init_image + torch.from_numpy(c * np.random.randn(np.prod(init_image_shape))).reshape(init_image_shape).float().to(device) else: - init_image = base_init_image + init_image = new_base_init_image init_image = repeat(init_image, '1 ... -> b ...', b=1) forced_latent = 6. * model.encode(init_image.to(device)).latent_dist.sample() new_fl = forced_latent.cpu().detach().numpy().flatten() @@ -224,30 +272,29 @@ def load_img(path): #forcedlatents += [new_fl.cpu().detach().numpy()] if i > 0: #epsilon = 0.3 / 1.1**i - basic_new_fl = np.sqrt(len(new_fl) / np.sum(new_fl**2)) * basic_new_fl + #basic_new_fl = np.sqrt(len(new_fl) / np.sum(new_fl**2)) * basic_new_fl epsilon = 1.0 / 2**(2 + i / 6) new_fl = epsilon * basic_new_fl + (1 - epsilon) * np.random.randn(1*4*64*64) else: new_fl = basic_new_fl #new_fl = np.sqrt(len(new_fl)) * new_fl / np.sqrt(np.sum(new_fl ** 2)) - #forcedlatents += [new_fl] #np.sqrt(len(new_fl) / sum(new_fl ** 2)) * new_fl] - forcedlatents += [np.sqrt(len(new_fl) / sum(new_fl ** 2)) * new_fl] + forcedlatents += [new_fl] #np.clip(new_fl, -3., 3.)] #np.sqrt(len(new_fl) / sum(new_fl ** 2)) * new_fl] + #forcedlatents += [np.sqrt(len(new_fl) / sum(new_fl ** 2)) * new_fl] #print(f"{i} --> {forcedlatents[i][:10]}") +# We start the big loop! for iteration in range(30): - #scrn.fill(black) latent = [latent[f] for f in five_best] images = [images[f] for f in five_best] onlyfiles = [onlyfiles[f] for f in five_best] early_stop = [] + noise.say("WAIT!") + noise.runAndWait() for k in range(llambda): if len(forcedlatents) > 0 and k < len(forcedlatents): os.environ["forcedlatent"] = str(list(forcedlatents[k].flatten())) if k < len(five_best): imp = pygame.transform.scale(pygame.image.load(onlyfiles[k]).convert(), (300, 300)) - selected_filename = to_native("Selected") + onlyfiles[k] - shutil.copyfile(onlyfiles[k], selected_filename) - all_selected += [selected_filename] # Using blit to copy content from one surface to other scrn.blit(imp, (300 * (k // 3), 300 * (k % 3))) pygame.display.flip() @@ -264,6 +311,18 @@ def load_img(path): scrn.blit(text0, ((X*3/4)/3 - X/32, Y/2)) text0 = font.render(to_native(f'Then I''ll work on variants of that specific image.'), True, green, blue) scrn.blit(text0, ((X*3/4)/2 - X/32, Y/2+Y/8)) + + # Button for early stopping + text2 = font.render(to_native('Click here for stopping, '), True, green, blue) + text2 = pygame.transform.rotate(text2, 90) + scrn.blit(text2, (X*3/4+X/16 - X/32, Y/3)) + text2 = font.render(to_native('and get the effects,'), True, green, blue) + text2 = pygame.transform.rotate(text2, 90) + scrn.blit(text2, (X*3/4+X/16+X/64 - X/32, Y/3)) + text2 = font.render(to_native('or for creating a meme.'), True, green, blue) + text2 = pygame.transform.rotate(text2, 90) + scrn.blit(text2, (X*3/4+X/16+X/32 - X/32, Y/3)) + pygame.display.flip() os.environ["earlystop"] = "False" if k > len(five_best) else "True" os.environ["epsilon"] = str(0. 
if k == len(five_best) else (k - len(five_best)) / llambda) @@ -283,6 +342,9 @@ def load_img(path): # Using blit to copy content from one surface to other scrn.blit(imp, (300 * (k // 3), 300 * (k % 3))) pygame.display.flip() + #noise.say("Dong") + #noise.runAndWait() + print('\a') str_latent = eval((os.environ["latent_sd"])) array_latent = eval(f"np.array(str_latent).reshape(4, 64, 64)") print(f"Debug info: array_latent sumsq/var {sum(array_latent.flatten() ** 2) / len(array_latent.flatten())}") @@ -292,12 +354,20 @@ def load_img(path): # In case of early stopping. for i in pygame.event.get(): if i.type == pygame.MOUSEBUTTONUP: - pos = pygame.mouse.get_pos() + noise.say("Ok I stop") + noise.runAndWait() + pos = pygame.mouse.get_pos() index = 3 * (pos[0] // 300) + (pos[1] // 300) + if pos[0] > X and pos[1] > Y /3 and pos[1] < 2*Y/3: + stop_all(all_selected) + exit() if index <= k: pretty_print(("You clicked for requesting an early stopping.")) early_stop = [pos] break + #early_stop = [(k - 1, .5,.5)] + pretty_print("I do not understand your click.") + pretty_print("So I assume you want the last image...") # Stop the forcing from disk! #os.environ["enforcedlatent"] = "" @@ -310,7 +380,9 @@ def load_img(path): # create the display surface object # of specific dimension..e(X, Y). - + if len(early_stop) == 0: + noise.say("Ok I'm ready!") + noise.runAndWait() # Add rectangles pygame.draw.rect(scrn, red, pygame.Rect(X*3/4, 0, X*3/4+X/16+X/32, Y/3), 2) pygame.draw.rect(scrn, red, pygame.Rect(X*3/4, Y/3, X*3/4+X/16+X/32, 2*Y/3), 2) @@ -326,19 +398,19 @@ def load_img(path): #scrn.blit(text1, (X*3/4+X/16+X/32 - X/32, 0)) # Button for creating a meme - text2 = font.render(to_native('Stop / High-Resolution / Create '), True, green, blue) + text2 = font.render(to_native('Click ,'), True, green, blue) text2 = pygame.transform.rotate(text2, 90) - scrn.blit(text2, (X*3/4+X/16 - X/32, Y/3)) - text2 = font.render(to_native('a meme'), True, green, blue) + scrn.blit(text2, (X*3/4+X/16 - X/32, Y/3+10)) + text2 = font.render(to_native('for finishing with effects.'), True, green, blue) text2 = pygame.transform.rotate(text2, 90) - scrn.blit(text2, (X*3/4+X/16+X/32 - X/32, Y/3)) + scrn.blit(text2, (X*3/4+X/16+X/32 - X/32, Y/3+10)) # Button for new generation text3 = font.render(to_native(f"I don't want to select images"), True, green, blue) text3 = pygame.transform.rotate(text3, 90) - scrn.blit(text3, (X*3/4+X/16 - X/32, Y*2/3)) + scrn.blit(text3, (X*3/4+X/16 - X/32, Y*2/3+10)) text3 = font.render(to_native(f"Just rerun."), True, green, blue) text3 = pygame.transform.rotate(text3, 90) - scrn.blit(text3, (X*3/4+X/16+X/32 - X/32, Y*2/3)) + scrn.blit(text3, (X*3/4+X/16+X/32 - X/32, Y*2/3+10)) text4 = font.render(to_native(f"Modify parameters or text!"), True, green, blue) scrn.blit(text4, (300, Y + 30)) pygame.display.flip() @@ -394,6 +466,7 @@ def load_img(path): language = detect(prompt) english_prompt = GoogleTranslator(source='auto', target='en').translate(prompt) pretty_print("Ok! Parameters updated.") + pretty_print("==> go back to the window!") text4 = font.render(to_native(f"Ok! 
parameters changed!"), True, green, blue) scrn.blit(text4, (300, Y + 30)) pygame.display.flip() @@ -408,31 +481,27 @@ def load_img(path): pretty_print("Then just relaunch me and provide the text and the image.") exit() if pos[1] < 2*Y/3: - onlyfiles = [f for f in listdir(".") if isfile(join(mypath, f))] - onlyfiles = [str(f) for f in onlyfiles if "SD_" in str(f) and ".png" in str(f) and str(f) not in all_files and sentinel in str(f)] - print(to_native("Your generated images:")) - print(onlyfiles) - eg(all_selected + onlyfiles) - pretty_print("Should we create a meme ?") - answer = input(" [y]es or [n]o ?") - if "y" in answer or "Y" in answer: - url = 'https://imgflip.com/memegenerator' - webbrowser.open(url) - pretty_print("Good bye!") + #onlyfiles = [f for f in listdir(".") if isfile(join(mypath, f))] + #onlyfiles = [str(f) for f in onlyfiles if "SD_" in str(f) and ".png" in str(f) and str(f) not in all_files and sentinel in str(f)] + stop_all(all_selected + onlyfiles) exit() status = False break index = 3 * (pos[0] // 300) + (pos[1] // 300) + pygame.draw.circle(scrn, red, [pos[0], pos[1]], 3, 0) + selected_filename = to_native("Selected") + onlyfiles[index] + shutil.copyfile(onlyfiles[index], selected_filename) + all_selected += [selected_filename] if index not in five_best and len(five_best) < 5: five_best += [index] indices += [[index, (pos[0] - (pos[0] // 300) * 300) / 300, (pos[1] - (pos[1] // 300) * 300) / 300]] # Update the button for new generation. pygame.draw.rect(scrn, black, pygame.Rect(X*3/4, 2*Y/3, X*3/4+X/16+X/32, Y)) pygame.draw.rect(scrn, red, pygame.Rect(X*3/4, 2*Y/3, X*3/4+X/16+X/32, Y), 2) - text3 = font.render(to_native(f" I have chosen {len(indices)} images:"), True, green, blue) + text3 = font.render(to_native(f" You have chosen {len(indices)} images:"), True, green, blue) text3 = pygame.transform.rotate(text3, 90) scrn.blit(text3, (X*3/4+X/16 - X/32, Y*2/3)) - text3 = font.render(to_native(f" New generation!"), True, green, blue) + text3 = font.render(to_native(f" Click for new generation!"), True, green, blue) text3 = pygame.transform.rotate(text3, 90) scrn.blit(text3, (X*3/4+X/16+X/32 - X/32, Y*2/3)) pygame.display.flip() @@ -466,6 +535,8 @@ def load_img(path): sauron = (1 / len([u for u in range(len(onlyfiles)) if u not in [i[0] for i in indices]])) * sauron if len(bad) > 300: bad = bad[(len(bad) - 300):] + print(to_native(f"{len(indices)} indices are selected.")) + #print(f"indices = {indices}") for a in range(llambda): forcedlatent = np.zeros((4, 64, 64)) os.environ["good"] = str(good) @@ -479,6 +550,10 @@ def load_img(path): y = j / 63 mindistances = 10000000000. 
for u in range(len(indices)): + #print(a, i, x, j, y, u) + #print(indices[u][1]) + #print(indices[u][2]) + #print(f" {coefficients[u]}* np.linalg.norm({np.array((x, y))}-{np.array((indices[u][1], indices[u][2]))}") distance = coefficients[u] * np.linalg.norm( np.array((x, y)) - np.array((indices[u][1], indices[u][2])) ) if distance < mindistances: mindistances = distance @@ -491,12 +566,13 @@ def load_img(path): assert k < len(latent[uu]), k assert i < len(latent[uu][k]), i assert j < len(latent[uu][k][i]), j - forcedlatent[k][i][j] = latent[uu][k][i][j] - if a % 2 == 0: - forcedlatent -= np.random.rand() * sauron + forcedlatent[k][i][j] = float(latent[uu][k][i][j]) + #if a % 2 == 0: + # forcedlatent -= np.random.rand() * sauron basic_new_fl = np.sqrt(len(forcedlatent) / np.sum(forcedlatent**2)) * forcedlatent - epsilon = 1.0 / 2**(2 + a / 6) - forcedlatent = epsilon * basic_new_fl + (1 - epsilon) * np.random.randn(1*4*64*64) + epsilon = (a / (llambda - 1)) ** 3 + forcedlatent = (1. - epsilon) * basic_new_fl.flatten() + epsilon * np.random.randn(4*64*64) + forcedlatent = np.sqrt(len(forcedlatent) / np.sum(forcedlatent**2)) * forcedlatent forcedlatents += [forcedlatent] #for uu in range(len(latent)): # print(f"--> latent[{uu}] sum of sq / variable = {np.sum(latent[uu].flatten()**2) / len(latent[uu].flatten())}") diff --git a/pipeline_stable_diffusion.py b/pipeline_stable_diffusion.py index b4caa2a5a..c274f4f02 100644 --- a/pipeline_stable_diffusion.py +++ b/pipeline_stable_diffusion.py @@ -323,6 +323,8 @@ def loss(x): if latents.shape != latents_intermediate_shape: raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {latents_intermediate_shape}") print(f"latent ==> {sum(latents.flatten()**2) / len(latents.flatten())}") + print(f"latent ==> {torch.max(latents)}") + print(f"latent ==> {torch.min(latents)}") os.environ["latent_sd"] = str(list(latents.flatten().cpu().numpy())) for i in [2, 3]: latents = torch.repeat_interleave(latents, repeats=latents_shape[i] // latents_intermediate_shape[i], dim=i) #/ np.sqrt(np.sqrt(latents_shape[i] // latents_intermediate_shape[i])) From 998140011a679e4854a15584394ad1634fbaa452 Mon Sep 17 00:00:00 2001 From: Olivier Teytaud Date: Thu, 29 Sep 2022 21:52:31 +0200 Subject: [PATCH 54/76] fix --- minisd.py | 27 +++++++++++++++++++++------ 1 file changed, 21 insertions(+), 6 deletions(-) diff --git a/minisd.py b/minisd.py index 8ca8d677e..e3e832da1 100644 --- a/minisd.py +++ b/minisd.py @@ -229,8 +229,8 @@ def load_img(path): init_image = load_img(image_name).to(device) except: pretty_print("Try again!") - image_name = input(to_native("Name of image for starting ? (enter if no start image)")) pretty_print("Loading failed!!") + image_name = input(to_native("Name of image for starting ? 
(enter if no start image)")) base_init_image = load_img(image_name).to(device) noise.say("Image loaded") @@ -238,8 +238,20 @@ def load_img(path): print(base_init_image.shape) print(np.max(base_init_image.cpu().detach().numpy().flatten())) print(np.min(base_init_image.cpu().detach().numpy().flatten())) + forcedlatents = [] divider = 1.5 + latent_found = False + try: + latent_file = image_name + ".latent.txt" + print(to_native(f"Trying to load latent variables in {latent_file}.")) + f = open(latent_file, "r") + print(to_native("File opened.")) + latent_str = f.read() + print("Latent string read.") + latent_found = True + except: + print(to_native("No latent file: guessing.")) for i in range(llambda): new_base_init_image = base_init_image if (i % 7) == 1: @@ -265,8 +277,11 @@ def load_img(path): else: init_image = new_base_init_image init_image = repeat(init_image, '1 ... -> b ...', b=1) - forced_latent = 6. * model.encode(init_image.to(device)).latent_dist.sample() - new_fl = forced_latent.cpu().detach().numpy().flatten() + if latent_found: + new_fl = np.asarray(eval(latent_str)) + else: + forced_latent = 6. * model.encode(init_image.to(device)).latent_dist.sample() + new_fl = forced_latent.cpu().detach().numpy().flatten() basic_new_fl = new_fl #np.sqrt(len(new_fl) / sum(new_fl ** 2)) * new_fl #new_fl = forced_latent + (1. / 1.1**(llambda-i)) * torch.from_numpy(np.random.randn(1*4*64*64).reshape(1,4,64,64)).float().to(device) #forcedlatents += [new_fl.cpu().detach().numpy()] @@ -349,8 +364,8 @@ def load_img(path): array_latent = eval(f"np.array(str_latent).reshape(4, 64, 64)") print(f"Debug info: array_latent sumsq/var {sum(array_latent.flatten() ** 2) / len(array_latent.flatten())}") latent += [array_latent] - with open(f"SD_{prompt.replace(' ','_')}_latent_{sentinel}_{k}.txt", 'w') as f: - f.write(f"{latent}") + with open(filename + ".latent.txt", 'w') as f: + f.write(f"{str_latent}") # In case of early stopping. for i in pygame.event.get(): if i.type == pygame.MOUSEBUTTONUP: @@ -570,7 +585,7 @@ def load_img(path): #if a % 2 == 0: # forcedlatent -= np.random.rand() * sauron basic_new_fl = np.sqrt(len(forcedlatent) / np.sum(forcedlatent**2)) * forcedlatent - epsilon = (a / (llambda - 1)) ** 3 + epsilon = (a / (llambda - 1)) ** 6 forcedlatent = (1. - epsilon) * basic_new_fl.flatten() + epsilon * np.random.randn(4*64*64) forcedlatent = np.sqrt(len(forcedlatent) / np.sum(forcedlatent**2)) * forcedlatent forcedlatents += [forcedlatent] From f2fa901289984fef1370a85bab82fa55522a9d0f Mon Sep 17 00:00:00 2001 From: Olivier Teytaud Date: Sat, 1 Oct 2022 17:23:03 +0200 Subject: [PATCH 55/76] fix --- README.md | 1 + minisd.py | 295 ++++++++++++++++++++++++----------- pipeline_stable_diffusion.py | 13 +- 3 files changed, 212 insertions(+), 97 deletions(-) diff --git a/README.md b/README.md index ab5bcb4f0..7525218ff 100644 --- a/README.md +++ b/README.md @@ -25,6 +25,7 @@ conda install pytorch torchvision -c pytorch pip install transformers diffusers invisible-watermark pip install -e . pip install pygame +pip install joblib pip install pyttsx3 pip install einops pip install webbrowser diff --git a/minisd.py b/minisd.py index e3e832da1..c618dd913 100644 --- a/minisd.py +++ b/minisd.py @@ -11,6 +11,8 @@ import webbrowser from deep_translator import GoogleTranslator from langdetect import detect +from joblib import Parallel, delayed + os.environ["PYTORCH_ENABLE_MPS_FALLBACK"] = "1" model_id = "CompVis/stable-diffusion-v1-4" #device = "cuda" @@ -28,6 +30,7 @@ os.environ["decay"] = "0." 
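 # These environment variables form the side channel to the patched
 # pipeline_stable_diffusion.py: "forcedlatent" carries a latent to decode,
 # "good"/"bad" the user feedback lists, "ngoptim"/"budget" the Nevergrad
 # optimizer settings, and "latent_sd" is written back by the pipeline afterwards.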
os.environ["ngoptim"] = "DiscreteLenglerOnePlusOne" os.environ["forcedlatent"] = "" +latent_forcing = "" #os.environ["enforcedlatent"] = "" os.environ["good"] = "[]" os.environ["bad"] = "[]" @@ -39,7 +42,7 @@ import pyttsx3 noise = pyttsx3.init() -noise.setProperty("rate", 178) +noise.setProperty("rate", 240) noise.setProperty('voice', 'mb-us1') #voice = noise.getProperty('voices') @@ -49,6 +52,8 @@ all_selected = [] +all_selected_latent = [] +final_selection = [] forcedlatents = [] @@ -113,7 +118,9 @@ prompt = "A bird is on a hippo. They fight a black and red octopus. Jungle in the background." prompt = "A flying white owl above 4 colored pots with fire. The owl has a hat." prompt = "A flying white owl above 4 colored pots with fire." +prompt = "Yann LeCun rides a dragon which spits fire on a cherry on a cake." prompt = "An armored Mark Zuckerberg fighting off a monster with bloody tentacles in the jungle with a light saber." +prompt = "Cute woman, portrait, photo, red hair, green eyes, smiling." print(f"The prompt is {prompt}") @@ -142,33 +149,115 @@ def pretty_print(stri): print(f"{to_native('Working on')} {english_prompt}, a.k.a {prompt}.") +def latent_to_image(latent): + os.environ["forcedlatent"] = str(list(latent.flatten())) #str(list(forcedlatents[k].flatten())) + with autocast("cuda"): + image = pipe(english_prompt, guidance_scale=gs, num_inference_steps=num_iterations)["sample"][0] + return image + +import torch +from PIL import Image +from RealESRGAN import RealESRGAN + +sr_device = torch.device('cpu') #device #('mps') #torch.device('cuda' if torch.cuda.is_available() else 'cpu') +esrmodel = RealESRGAN(sr_device, scale=4) +esrmodel.load_weights('weights/RealESRGAN_x4.pth', download=True) +esrmodel2 = RealESRGAN(sr_device, scale=2) +esrmodel2.load_weights('weights/RealESRGAN_x2.pth', download=True) + +def singleeg(path_to_image): + image = Image.open(path_to_image).convert('RGB') + sr_device = device #('mps') #torch.device('cuda' if torch.cuda.is_available() else 'cpu') + print(f"Type before SR = {type(image)}") + sr_image = esrmodel.predict(image) + print(f"Type after SR = {type(sr_image)}") + output_filename = path_to_image + ".SR.png" + sr_image.save(output_filename) + return output_filename +def singleeg2(path_to_image): + image = Image.open(path_to_image).convert('RGB') + sr_device = device #('mps') #torch.device('cuda' if torch.cuda.is_available() else 'cpu') + print(f"Type before SR = {type(image)}") + sr_image = esrmodel2.predict(image) + print(f"Type after SR = {type(sr_image)}") + output_filename = path_to_image + ".SR.png" + sr_image.save(output_filename) + return output_filename + + def eg(list_of_files): pretty_print("Should I convert images below to high resolution ?") print(list_of_files) - noise.say("Hey!") + noise.say("Go to the text window!") noise.runAndWait() answer = input(" [y]es / [n]o ?") if "y" in answer or "Y" in answer: - model = RealESRGAN(device, scale=4) - model.load_weights('weights/RealESRGAN_x4.pth', download=True) - for f in list_of_files: - import torch - from PIL import Image - import numpy as np - from RealESRGAN import RealESRGAN - - #device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') - path_to_image = f - image = Image.open(path_to_image).convert('RGB') - sr_image = model.predict(image) - output_filename = "SR" + f - sr_image.save(output_filename) + #images = Parallel(n_jobs=12)(delayed(singleeg)(image) for image in list_of_files) + #print(to_native(f"Created the super-resolution files {images}")) + for path_to_image in 
list_of_files: + output_filename = singleeg(path_to_image) print(to_native(f"Created the super-resolution file {output_filename}")) -def stop_all(list_of_files): +def stop_all(list_of_files, list_of_latent, last_list_of_latent): print(to_native("Your selected images and the last generation:")) print(list_of_files) - eg(all_selected + onlyfiles) + eg(list_of_files) + pretty_print("Should we create animations ?") + answer = input(" [y]es or [n]o or [j]ust the selection on the last panel ?") + if "y" in answer or "Y" in answer or "j" in answer or "J" in answer: + if "j" in answer or "J" in answer: + list_of_latent = last_list_of_latent + pretty_print("Let us create animations!") + assert len(list_of_files) == len(list_of_latent) + for c in [0.5, 0.25, 0.125, 0.0625, 0.05, 0.04,0.03125]: + for idx in range(len(list_of_files)): + images = [] + l = list_of_latent[idx].reshape(1,4,64,64) + l = np.sqrt(len(l.flatten()) / np.sum(l**2)) * l + l1 = l + c * np.random.randn(len(l.flatten())).reshape(1,4,64,64) + l1 = np.sqrt(len(l1.flatten()) / np.sum(l1**2)) * l1 + l2 = l + c * np.random.randn(len(l.flatten())).reshape(1,4,64,64) + l2 = np.sqrt(len(l2.flatten()) / np.sum(l2**2)) * l2 + num_animation_steps = 13 + index = 0 + for u in np.linspace(0., 2*3.14159 * (1-1/30), 30): + cc = np.cos(u) + ss = np.sin(u*2) + index += 1 + image = latent_to_image(l + cc * (l1 - l) + ss * (l2 - l)) + image_name = f"imgA{index}.png" + image.save(image_name) + images += [image_name] + +# for u in np.linspace(0., 1., num_animation_steps): +# index += 1 +# image = latent_to_image(u*l1 + (1-u)*l) +# image_name = f"imgA{index}.png" +# image.save(image_name) +# images += [image_name] +# for u in np.linspace(0., 1., num_animation_steps): +# index += 1 +# image = latent_to_image(u*l2 + (1-u)*l1) +# image_name = f"imgB{index}.png" +# image.save(image_name) +# images += [image_name] +# for u in np.linspace(0., 1.,num_animation_steps): +# index += 1 +# image = latent_to_image(u*l + (1-u)*l2) +# image_name = f"imgC{index}.png" +# image.save(image_name) +# images += [image_name] + print(to_native(f"Base images created for perturbation={c} and file {list_of_files[idx]}")) + #images = Parallel(n_jobs=8)(delayed(process)(i) for i in range(10)) + images = Parallel(n_jobs=16)(delayed(singleeg2)(image) for image in images) + + frames = [Image.open(image) for image in images] + frame_one = frames[0] + gif_name = list_of_files[idx] + "_" + str(c) + ".gif" + frame_one.save(gif_name, format="GIF", append_images=frames, + save_all=True, duration=100, loop=0) + webbrowser.open(os.environ["PWD"] + "/" + gif_name) + pretty_print("Should we create a meme ?") answer = input(" [y]es or [n]o ?") if "y" in answer or "Y" in answer: @@ -254,31 +343,33 @@ def load_img(path): print(to_native("No latent file: guessing.")) for i in range(llambda): new_base_init_image = base_init_image - if (i % 7) == 1: - new_base_init_image[0,0,:,:] /= divider - if (i % 7) == 2: - new_base_init_image[0,1,:,:] /= divider - if (i % 7) == 3: - new_base_init_image[0,2,:,:] /= divider - if (i % 7) == 4: - new_base_init_image[0,0,:,:] /= divider - new_base_init_image[0,1,:,:] /= divider - if (i % 7) == 5: - new_base_init_image[0,1,:,:] /= divider - new_base_init_image[0,2,:,:] /= divider - if (i % 7) == 6: - new_base_init_image[0,0,:,:] /= divider - new_base_init_image[0,2,:,:] /= divider + if not latent_found: # In case of latent vars we need less exploration. 
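+            # No saved latent available: dim different colour-channel combinations
+            # of the start image (chosen by i mod 7) so that the encoded initial
+            # population is more diverse.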
+ if (i % 7) == 1: + new_base_init_image[0,0,:,:] /= divider + if (i % 7) == 2: + new_base_init_image[0,1,:,:] /= divider + if (i % 7) == 3: + new_base_init_image[0,2,:,:] /= divider + if (i % 7) == 4: + new_base_init_image[0,0,:,:] /= divider + new_base_init_image[0,1,:,:] /= divider + if (i % 7) == 5: + new_base_init_image[0,1,:,:] /= divider + new_base_init_image[0,2,:,:] /= divider + if (i % 7) == 6: + new_base_init_image[0,0,:,:] /= divider + new_base_init_image[0,2,:,:] /= divider c = np.exp(np.random.randn() - 2) init_image_shape = base_init_image.cpu().numpy().shape - if i > 0: + if i > 0 and not latent_found: init_image = new_base_init_image + torch.from_numpy(c * np.random.randn(np.prod(init_image_shape))).reshape(init_image_shape).float().to(device) else: init_image = new_base_init_image init_image = repeat(init_image, '1 ... -> b ...', b=1) if latent_found: new_fl = np.asarray(eval(latent_str)) + assert len(new_fl) > 1 else: forced_latent = 6. * model.encode(init_image.to(device)).latent_dist.sample() new_fl = forced_latent.cpu().detach().numpy().flatten() @@ -287,9 +378,9 @@ def load_img(path): #forcedlatents += [new_fl.cpu().detach().numpy()] if i > 0: #epsilon = 0.3 / 1.1**i - #basic_new_fl = np.sqrt(len(new_fl) / np.sum(new_fl**2)) * basic_new_fl - epsilon = 1.0 / 2**(2 + i / 6) - new_fl = epsilon * basic_new_fl + (1 - epsilon) * np.random.randn(1*4*64*64) + #basic_new_fl = np.sqrt(len(new_fl) / np.sum(basic_new_fl**2)) * basic_new_fl + epsilon = (i-1)/(llambda-1) #1.0 / 2**(2 + (llambda - i) / 6) + new_fl = (1. - epsilon) * basic_new_fl + epsilon * np.random.randn(1*4*64*64) else: new_fl = basic_new_fl #new_fl = np.sqrt(len(new_fl)) * new_fl / np.sqrt(np.sum(new_fl ** 2)) @@ -297,7 +388,7 @@ def load_img(path): #forcedlatents += [np.sqrt(len(new_fl) / sum(new_fl ** 2)) * new_fl] #print(f"{i} --> {forcedlatents[i][:10]}") -# We start the big loop! +# We start the big time consuming loop! for iteration in range(30): latent = [latent[f] for f in five_best] images = [images[f] for f in five_best] @@ -305,36 +396,40 @@ def load_img(path): early_stop = [] noise.say("WAIT!") noise.runAndWait() + final_selection = [] for k in range(llambda): + if len(early_stop) > 0: + break + max_created_index = k if len(forcedlatents) > 0 and k < len(forcedlatents): - os.environ["forcedlatent"] = str(list(forcedlatents[k].flatten())) + #os.environ["forcedlatent"] = str(list(forcedlatents[k].flatten())) + latent_forcing = str(list(forcedlatents[k].flatten())) + print(f"We play with {latent_forcing[:20]}") if k < len(five_best): imp = pygame.transform.scale(pygame.image.load(onlyfiles[k]).convert(), (300, 300)) # Using blit to copy content from one surface to other scrn.blit(imp, (300 * (k // 3), 300 * (k % 3))) pygame.display.flip() continue - if len(early_stop) > 0: - break pygame.draw.rect(scrn, black, pygame.Rect(0, Y, 1700, Y+100)) pygame.draw.rect(scrn, black, pygame.Rect(1500, 0, 2000, Y+100)) text0 = bigfont.render(to_native(f'Please wait !!! 
{k} / {llambda}'), True, green, blue) scrn.blit(text0, ((X*3/4)/2 - X/32, Y/2-Y/4)) - text0 = font.render(to_native(f'Or, if you find one image very cool and want to focus on it only,'), True, green, blue) + text0 = font.render(to_native(f'Or, for an early stopping,'), True, green, blue) scrn.blit(text0, ((X*3/4)/3 - X/32, Y/2-Y/8)) - text0 = font.render(to_native(f'then click on it AND KEEP THE MOUSE AT THE SAME POINT until I get the click.'), True, green, blue) + text0 = font.render(to_native(f'click and WAIT a bit'), True, green, blue) scrn.blit(text0, ((X*3/4)/3 - X/32, Y/2)) - text0 = font.render(to_native(f'Then I''ll work on variants of that specific image.'), True, green, blue) + text0 = font.render(to_native(f'... ... ... '), True, green, blue) scrn.blit(text0, ((X*3/4)/2 - X/32, Y/2+Y/8)) # Button for early stopping - text2 = font.render(to_native('Click here for stopping, '), True, green, blue) + text2 = font.render(to_native(f'{len(all_selected)} chosen images! '), True, green, blue) text2 = pygame.transform.rotate(text2, 90) scrn.blit(text2, (X*3/4+X/16 - X/32, Y/3)) - text2 = font.render(to_native('and get the effects,'), True, green, blue) + text2 = font.render(to_native('Click for stopping,'), True, green, blue) text2 = pygame.transform.rotate(text2, 90) scrn.blit(text2, (X*3/4+X/16+X/64 - X/32, Y/3)) - text2 = font.render(to_native('or for creating a meme.'), True, green, blue) + text2 = font.render(to_native('and get the effects.'), True, green, blue) text2 = pygame.transform.rotate(text2, 90) scrn.blit(text2, (X*3/4+X/16+X/32 - X/32, Y/3)) @@ -347,9 +442,10 @@ def load_img(path): #if len(enforcedlatent) > 2: # os.environ["forcedlatent"] = enforcedlatent # os.environ["enforcedlatent"] = "" - with autocast("cuda"): - image = pipe(english_prompt, guidance_scale=gs, num_inference_steps=num_iterations)["sample"][0] - images += [image] + #with autocast("cuda"): + # image = pipe(english_prompt, guidance_scale=gs, num_inference_steps=num_iterations)["sample"][0] + image = latent_to_image(np.asarray(latent_forcing)) #eval(os.environ["forcedlatent"]))) + images += [image] filename = f"SD_{prompt.replace(' ','_')}_image_{sentinel}_{iteration:05d}_{k:05d}.png" image.save(filename) onlyfiles += [filename] @@ -374,15 +470,14 @@ def load_img(path): pos = pygame.mouse.get_pos() index = 3 * (pos[0] // 300) + (pos[1] // 300) if pos[0] > X and pos[1] > Y /3 and pos[1] < 2*Y/3: - stop_all(all_selected) + stop_all(all_selected, all_selected_latent, final_selection) exit() if index <= k: pretty_print(("You clicked for requesting an early stopping.")) early_stop = [pos] break - #early_stop = [(k - 1, .5,.5)] - pretty_print("I do not understand your click.") - pretty_print("So I assume you want the last image...") + early_stop = [(1,1)] + satus = False # Stop the forcing from disk! #os.environ["enforcedlatent"] = "" @@ -395,9 +490,15 @@ def load_img(path): # create the display surface object # of specific dimension..e(X, Y). - if len(early_stop) == 0: - noise.say("Ok I'm ready!") - noise.runAndWait() + noise.say("Ok I'm ready! 
Choose") + noise.runAndWait() + pretty_print("Please choose your images.") + text0 = bigfont.render(to_native(f'Choose your favorite images !!!========='), True, green, blue) + scrn.blit(text0, ((X*3/4)/2 - X/32, Y/2-Y/4)) + text0 = font.render(to_native(f'=================================='), True, green, blue) + scrn.blit(text0, ((X*3/4)/3 - X/32, Y/2-Y/8)) + text0 = font.render(to_native(f'=================================='), True, green, blue) + scrn.blit(text0, ((X*3/4)/3 - X/32, Y/2)) # Add rectangles pygame.draw.rect(scrn, red, pygame.Rect(X*3/4, 0, X*3/4+X/16+X/32, Y/3), 2) pygame.draw.rect(scrn, red, pygame.Rect(X*3/4, Y/3, X*3/4+X/16+X/32, 2*Y/3), 2) @@ -430,12 +531,12 @@ def load_img(path): scrn.blit(text4, (300, Y + 30)) pygame.display.flip() - if len(early_stop) == 0: - for idx in range(llambda): - # set the pygame window name - pygame.display.set_caption(prompt) - imp = pygame.transform.scale(pygame.image.load(onlyfiles[idx]).convert(), (300, 300)) - scrn.blit(imp, (300 * (idx // 3), 300 * (idx % 3))) + for idx in range(max_created_index + 1): + # set the pygame window name + pygame.display.set_caption(prompt) + print(to_native(f"Pasting image {onlyfiles[idx]}...")) + imp = pygame.transform.scale(pygame.image.load(onlyfiles[idx]).convert(), (300, 300)) + scrn.blit(imp, (300 * (idx // 3), 300 * (idx % 3))) # paint screen one time pygame.display.flip() @@ -456,9 +557,9 @@ def load_img(path): # iterate over the list of Event objects # that was returned by pygame.event.get() method. - for i in early_stop + pygame.event.get(): - if hasattr(i, "type") and i.type == pygame.MOUSEBUTTONUP or len(early_stop) > 0: - pos = early_stop[0] if len(early_stop) > 0 else pygame.mouse.get_pos() + for i in pygame.event.get(): + if hasattr(i, "type") and i.type == pygame.MOUSEBUTTONUP: + pos = pygame.mouse.get_pos() pretty_print(f"Detected! Click at {pos}") if pos[1] > Y: pretty_print("Let us update parameters!") @@ -498,36 +599,46 @@ def load_img(path): if pos[1] < 2*Y/3: #onlyfiles = [f for f in listdir(".") if isfile(join(mypath, f))] #onlyfiles = [str(f) for f in onlyfiles if "SD_" in str(f) and ".png" in str(f) and str(f) not in all_files and sentinel in str(f)] - stop_all(all_selected + onlyfiles) + assert len(onlyfiles) == len(latent) + assert len(all_selected) == len(all_selected_latent) + stop_all(all_selected, all_selected_latent, final_selection) # + onlyfiles, all_selected_latent + latent) exit() status = False break index = 3 * (pos[0] // 300) + (pos[1] // 300) - pygame.draw.circle(scrn, red, [pos[0], pos[1]], 3, 0) - selected_filename = to_native("Selected") + onlyfiles[index] - shutil.copyfile(onlyfiles[index], selected_filename) - all_selected += [selected_filename] - if index not in five_best and len(five_best) < 5: - five_best += [index] - indices += [[index, (pos[0] - (pos[0] // 300) * 300) / 300, (pos[1] - (pos[1] // 300) * 300) / 300]] - # Update the button for new generation. 
- pygame.draw.rect(scrn, black, pygame.Rect(X*3/4, 2*Y/3, X*3/4+X/16+X/32, Y)) - pygame.draw.rect(scrn, red, pygame.Rect(X*3/4, 2*Y/3, X*3/4+X/16+X/32, Y), 2) - text3 = font.render(to_native(f" You have chosen {len(indices)} images:"), True, green, blue) - text3 = pygame.transform.rotate(text3, 90) - scrn.blit(text3, (X*3/4+X/16 - X/32, Y*2/3)) - text3 = font.render(to_native(f" Click for new generation!"), True, green, blue) - text3 = pygame.transform.rotate(text3, 90) - scrn.blit(text3, (X*3/4+X/16+X/32 - X/32, Y*2/3)) - pygame.display.flip() - #text3Rect = text3.get_rect() - #text3Rect.center = (750+750*3/4, 1000) - good += [list(latent[index].flatten())] + pygame.draw.circle(scrn, red, [pos[0], pos[1]], 13, 0) + if index <= max_created_index: + selected_filename = to_native("Selected") + onlyfiles[index] + shutil.copyfile(onlyfiles[index], selected_filename) + assert len(onlyfiles) == len(latent), f"{len(onlyfiles)} != {len(latent)}" + all_selected += [selected_filename] + all_selected_latent += [latent[index]] + final_selection += [latent[index]] + text2 = font.render(to_native(f'{len(all_selected)} chosen images! '), True, green, blue) + text2 = pygame.transform.rotate(text2, 90) + scrn.blit(text2, (X*3/4+X/16 - X/32, Y/3)) + if index not in five_best and len(five_best) < 5: + five_best += [index] + indices += [[index, (pos[0] - (pos[0] // 300) * 300) / 300, (pos[1] - (pos[1] // 300) * 300) / 300]] + # Update the button for new generation. + pygame.draw.rect(scrn, black, pygame.Rect(X*3/4, 2*Y/3, X*3/4+X/16+X/32, Y)) + pygame.draw.rect(scrn, red, pygame.Rect(X*3/4, 2*Y/3, X*3/4+X/16+X/32, Y), 2) + text3 = font.render(to_native(f" You have chosen {len(indices)} images:"), True, green, blue) + text3 = pygame.transform.rotate(text3, 90) + scrn.blit(text3, (X*3/4+X/16 - X/32, Y*2/3)) + text3 = font.render(to_native(f" Click for new generation!"), True, green, blue) + text3 = pygame.transform.rotate(text3, 90) + scrn.blit(text3, (X*3/4+X/16+X/32 - X/32, Y*2/3)) + pygame.display.flip() + #text3Rect = text3.get_rect() + #text3Rect.center = (750+750*3/4, 1000) + good += [list(latent[index].flatten())] + else: + noise.say("Bad click! Click on image.") + noise.runAndWait() + pretty_print("Bad click! Click on image.") - # if event object type is QUIT - # then quitting the pygame - # and program both. - if len(early_stop) > 0 or i.type == pygame.QUIT: + if i.type == pygame.QUIT: status = False # Covering old images with full circles. @@ -585,7 +696,7 @@ def load_img(path): #if a % 2 == 0: # forcedlatent -= np.random.rand() * sauron basic_new_fl = np.sqrt(len(forcedlatent) / np.sum(forcedlatent**2)) * forcedlatent - epsilon = (a / (llambda - 1)) ** 6 + epsilon = 0.3 * (((a - len(good)) / (llambda - len(good) - 1)) ** 6) forcedlatent = (1. 
- epsilon) * basic_new_fl.flatten() + epsilon * np.random.randn(4*64*64) forcedlatent = np.sqrt(len(forcedlatent) / np.sum(forcedlatent**2)) * forcedlatent forcedlatents += [forcedlatent] diff --git a/pipeline_stable_diffusion.py b/pipeline_stable_diffusion.py index c274f4f02..8e3199b44 100644 --- a/pipeline_stable_diffusion.py +++ b/pipeline_stable_diffusion.py @@ -221,17 +221,20 @@ def __call__( generator=generator, device=latents_device, ) - if len(os.environ["forcedlatent"]) > 0: - print("we get a forcing for the latent z.") + if len(os.environ["forcedlatent"]) > 10: + stri = os.environ["forcedlatent"] + print(f"we get a forcing for the latent z: {stri[:20]}.") + if len(eval(stri)) == 1: + stri = str(eval(stri)[0]) speedup = 1 - latents = np.array(eval(os.environ["forcedlatent"])).flatten() + latents = np.array(list(eval(stri))).flatten() #latents = latents + np.exp(0.1 * np.random.randn()) * np.random.rand(len(latents)) #latents = np.sqrt(len(latents) / np.sum(latents ** 2)) * latents #latents = np.sqrt(len(latents)) * latents / np.sqrt(np.sum(latents ** 2)) - print(latents[:10]) + print(f"As an array, this is {latents[:10]}") print(f"immediately after loading latent ==> {sum(latents.flatten()**2) / len(latents.flatten())}") latents = torch.from_numpy(latents.reshape((1,4,64,64))).float().to(latents_device) - os.environ["forcedlatent"] = "" + os.environ["forcedlatent"] = "" good = eval(os.environ["good"]) bad = eval(os.environ["bad"]) print(f"{len(good)} good and {len(bad)} bad") From ce78634c1ea812f267776aa289a6df937f8a05e7 Mon Sep 17 00:00:00 2001 From: Olivier Teytaud Date: Sun, 2 Oct 2022 08:51:42 +0200 Subject: [PATCH 56/76] fix --- minisd.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/minisd.py b/minisd.py index c618dd913..5010a160c 100644 --- a/minisd.py +++ b/minisd.py @@ -7,7 +7,7 @@ from PIL import Image from einops import rearrange, repeat from torch import autocast -from diffusers import StableDiffusionPipeline +from local_diffusers import StableDiffusionPipeline import webbrowser from deep_translator import GoogleTranslator from langdetect import detect From 347e19e25fe5b7c5e7a3e22d5ef592955f7b6e8f Mon Sep 17 00:00:00 2001 From: Olivier Teytaud Date: Sun, 2 Oct 2022 08:55:36 +0200 Subject: [PATCH 57/76] fix --- local_diffusers/__init__.py | 60 + local_diffusers/commands/__init__.py | 27 + local_diffusers/commands/diffusers_cli.py | 41 + local_diffusers/commands/env.py | 70 + local_diffusers/configuration_utils.py | 403 +++++ local_diffusers/dependency_versions_check.py | 47 + local_diffusers/dependency_versions_table.py | 26 + local_diffusers/dynamic_modules_utils.py | 335 ++++ local_diffusers/hub_utils.py | 197 +++ local_diffusers/modeling_utils.py | 542 ++++++ local_diffusers/models/__init__.py | 17 + local_diffusers/models/attention.py | 333 ++++ local_diffusers/models/embeddings.py | 115 ++ local_diffusers/models/resnet.py | 483 ++++++ local_diffusers/models/unet_2d.py | 246 +++ local_diffusers/models/unet_2d_condition.py | 270 +++ local_diffusers/models/unet_blocks.py | 1481 +++++++++++++++++ local_diffusers/models/vae.py | 581 +++++++ local_diffusers/onnx_utils.py | 189 +++ local_diffusers/optimization.py | 275 +++ local_diffusers/pipeline_utils.py | 417 +++++ local_diffusers/pipelines/__init__.py | 19 + local_diffusers/pipelines/ddim/__init__.py | 2 + .../pipelines/ddim/pipeline_ddim.py | 117 ++ local_diffusers/pipelines/ddpm/__init__.py | 2 + .../pipelines/ddpm/pipeline_ddpm.py | 106 ++ .../pipelines/latent_diffusion/__init__.py | 6 + 
.../pipeline_latent_diffusion.py | 705 ++++++++ .../latent_diffusion_uncond/__init__.py | 2 + .../pipeline_latent_diffusion_uncond.py | 108 ++ local_diffusers/pipelines/pndm/__init__.py | 2 + .../pipelines/pndm/pipeline_pndm.py | 111 ++ .../pipelines/score_sde_ve/__init__.py | 2 + .../score_sde_ve/pipeline_score_sde_ve.py | 101 ++ .../pipelines/stable_diffusion/__init__.py | 37 + .../pipeline_stable_diffusion.py | 397 +++++ .../pipeline_stable_diffusion_img2img.py | 291 ++++ .../pipeline_stable_diffusion_inpaint.py | 309 ++++ .../pipeline_stable_diffusion_onnx.py | 165 ++ .../stable_diffusion/safety_checker.py | 106 ++ .../stochastic_karras_ve/__init__.py | 2 + .../pipeline_stochastic_karras_ve.py | 129 ++ local_diffusers/schedulers/__init__.py | 28 + local_diffusers/schedulers/scheduling_ddim.py | 261 +++ local_diffusers/schedulers/scheduling_ddpm.py | 264 +++ .../schedulers/scheduling_karras_ve.py | 208 +++ .../schedulers/scheduling_lms_discrete.py | 193 +++ local_diffusers/schedulers/scheduling_pndm.py | 378 +++++ .../schedulers/scheduling_sde_ve.py | 283 ++++ .../schedulers/scheduling_sde_vp.py | 81 + .../schedulers/scheduling_utils.py | 125 ++ local_diffusers/testing_utils.py | 61 + local_diffusers/training_utils.py | 125 ++ local_diffusers/utils/__init__.py | 53 + local_diffusers/utils/dummy_scipy_objects.py | 11 + ...rmers_and_inflect_and_unidecode_objects.py | 10 + .../dummy_transformers_and_onnx_objects.py | 11 + .../utils/dummy_transformers_objects.py | 32 + local_diffusers/utils/import_utils.py | 274 +++ local_diffusers/utils/logging.py | 344 ++++ local_diffusers/utils/outputs.py | 109 ++ 61 files changed, 11725 insertions(+) create mode 100644 local_diffusers/__init__.py create mode 100644 local_diffusers/commands/__init__.py create mode 100644 local_diffusers/commands/diffusers_cli.py create mode 100644 local_diffusers/commands/env.py create mode 100644 local_diffusers/configuration_utils.py create mode 100644 local_diffusers/dependency_versions_check.py create mode 100644 local_diffusers/dependency_versions_table.py create mode 100644 local_diffusers/dynamic_modules_utils.py create mode 100644 local_diffusers/hub_utils.py create mode 100644 local_diffusers/modeling_utils.py create mode 100644 local_diffusers/models/__init__.py create mode 100644 local_diffusers/models/attention.py create mode 100644 local_diffusers/models/embeddings.py create mode 100644 local_diffusers/models/resnet.py create mode 100644 local_diffusers/models/unet_2d.py create mode 100644 local_diffusers/models/unet_2d_condition.py create mode 100644 local_diffusers/models/unet_blocks.py create mode 100644 local_diffusers/models/vae.py create mode 100644 local_diffusers/onnx_utils.py create mode 100644 local_diffusers/optimization.py create mode 100644 local_diffusers/pipeline_utils.py create mode 100644 local_diffusers/pipelines/__init__.py create mode 100644 local_diffusers/pipelines/ddim/__init__.py create mode 100644 local_diffusers/pipelines/ddim/pipeline_ddim.py create mode 100644 local_diffusers/pipelines/ddpm/__init__.py create mode 100644 local_diffusers/pipelines/ddpm/pipeline_ddpm.py create mode 100644 local_diffusers/pipelines/latent_diffusion/__init__.py create mode 100644 local_diffusers/pipelines/latent_diffusion/pipeline_latent_diffusion.py create mode 100644 local_diffusers/pipelines/latent_diffusion_uncond/__init__.py create mode 100644 local_diffusers/pipelines/latent_diffusion_uncond/pipeline_latent_diffusion_uncond.py create mode 100644 local_diffusers/pipelines/pndm/__init__.py create 
mode 100644 local_diffusers/pipelines/pndm/pipeline_pndm.py create mode 100644 local_diffusers/pipelines/score_sde_ve/__init__.py create mode 100644 local_diffusers/pipelines/score_sde_ve/pipeline_score_sde_ve.py create mode 100644 local_diffusers/pipelines/stable_diffusion/__init__.py create mode 100644 local_diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py create mode 100644 local_diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py create mode 100644 local_diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py create mode 100644 local_diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_onnx.py create mode 100644 local_diffusers/pipelines/stable_diffusion/safety_checker.py create mode 100644 local_diffusers/pipelines/stochastic_karras_ve/__init__.py create mode 100644 local_diffusers/pipelines/stochastic_karras_ve/pipeline_stochastic_karras_ve.py create mode 100644 local_diffusers/schedulers/__init__.py create mode 100644 local_diffusers/schedulers/scheduling_ddim.py create mode 100644 local_diffusers/schedulers/scheduling_ddpm.py create mode 100644 local_diffusers/schedulers/scheduling_karras_ve.py create mode 100644 local_diffusers/schedulers/scheduling_lms_discrete.py create mode 100644 local_diffusers/schedulers/scheduling_pndm.py create mode 100644 local_diffusers/schedulers/scheduling_sde_ve.py create mode 100644 local_diffusers/schedulers/scheduling_sde_vp.py create mode 100644 local_diffusers/schedulers/scheduling_utils.py create mode 100644 local_diffusers/testing_utils.py create mode 100644 local_diffusers/training_utils.py create mode 100644 local_diffusers/utils/__init__.py create mode 100644 local_diffusers/utils/dummy_scipy_objects.py create mode 100644 local_diffusers/utils/dummy_transformers_and_inflect_and_unidecode_objects.py create mode 100644 local_diffusers/utils/dummy_transformers_and_onnx_objects.py create mode 100644 local_diffusers/utils/dummy_transformers_objects.py create mode 100644 local_diffusers/utils/import_utils.py create mode 100644 local_diffusers/utils/logging.py create mode 100644 local_diffusers/utils/outputs.py diff --git a/local_diffusers/__init__.py b/local_diffusers/__init__.py new file mode 100644 index 000000000..bf2f183c9 --- /dev/null +++ b/local_diffusers/__init__.py @@ -0,0 +1,60 @@ +from .utils import ( + is_inflect_available, + is_onnx_available, + is_scipy_available, + is_transformers_available, + is_unidecode_available, +) + + +__version__ = "0.3.0" + +from .configuration_utils import ConfigMixin +from .modeling_utils import ModelMixin +from .models import AutoencoderKL, UNet2DConditionModel, UNet2DModel, VQModel +from .onnx_utils import OnnxRuntimeModel +from .optimization import ( + get_constant_schedule, + get_constant_schedule_with_warmup, + get_cosine_schedule_with_warmup, + get_cosine_with_hard_restarts_schedule_with_warmup, + get_linear_schedule_with_warmup, + get_polynomial_decay_schedule_with_warmup, + get_scheduler, +) +from .pipeline_utils import DiffusionPipeline +from .pipelines import DDIMPipeline, DDPMPipeline, KarrasVePipeline, LDMPipeline, PNDMPipeline, ScoreSdeVePipeline +from .schedulers import ( + DDIMScheduler, + DDPMScheduler, + KarrasVeScheduler, + PNDMScheduler, + SchedulerMixin, + ScoreSdeVeScheduler, +) +from .utils import logging + + +if is_scipy_available(): + from .schedulers import LMSDiscreteScheduler +else: + from .utils.dummy_scipy_objects import * # noqa F403 + +from .training_utils import EMAModel + + +if is_transformers_available(): + 
from .pipelines import ( + LDMTextToImagePipeline, + StableDiffusionImg2ImgPipeline, + StableDiffusionInpaintPipeline, + StableDiffusionPipeline, + ) +else: + from .utils.dummy_transformers_objects import * # noqa F403 + + +if is_transformers_available() and is_onnx_available(): + from .pipelines import StableDiffusionOnnxPipeline +else: + from .utils.dummy_transformers_and_onnx_objects import * # noqa F403 diff --git a/local_diffusers/commands/__init__.py b/local_diffusers/commands/__init__.py new file mode 100644 index 000000000..902bd46ce --- /dev/null +++ b/local_diffusers/commands/__init__.py @@ -0,0 +1,27 @@ +# Copyright 2022 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from abc import ABC, abstractmethod +from argparse import ArgumentParser + + +class BaseDiffusersCLICommand(ABC): + @staticmethod + @abstractmethod + def register_subcommand(parser: ArgumentParser): + raise NotImplementedError() + + @abstractmethod + def run(self): + raise NotImplementedError() diff --git a/local_diffusers/commands/diffusers_cli.py b/local_diffusers/commands/diffusers_cli.py new file mode 100644 index 000000000..30084e55b --- /dev/null +++ b/local_diffusers/commands/diffusers_cli.py @@ -0,0 +1,41 @@ +#!/usr/bin/env python +# Copyright 2022 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from argparse import ArgumentParser + +from .env import EnvironmentCommand + + +def main(): + parser = ArgumentParser("Diffusers CLI tool", usage="diffusers-cli []") + commands_parser = parser.add_subparsers(help="diffusers-cli command helpers") + + # Register commands + EnvironmentCommand.register_subcommand(commands_parser) + + # Let's go + args = parser.parse_args() + + if not hasattr(args, "func"): + parser.print_help() + exit(1) + + # Run + service = args.func(args) + service.run() + + +if __name__ == "__main__": + main() diff --git a/local_diffusers/commands/env.py b/local_diffusers/commands/env.py new file mode 100644 index 000000000..81a878bff --- /dev/null +++ b/local_diffusers/commands/env.py @@ -0,0 +1,70 @@ +# Copyright 2022 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import platform +from argparse import ArgumentParser + +import huggingface_hub + +from .. import __version__ as version +from ..utils import is_torch_available, is_transformers_available +from . import BaseDiffusersCLICommand + + +def info_command_factory(_): + return EnvironmentCommand() + + +class EnvironmentCommand(BaseDiffusersCLICommand): + @staticmethod + def register_subcommand(parser: ArgumentParser): + download_parser = parser.add_parser("env") + download_parser.set_defaults(func=info_command_factory) + + def run(self): + hub_version = huggingface_hub.__version__ + + pt_version = "not installed" + pt_cuda_available = "NA" + if is_torch_available(): + import torch + + pt_version = torch.__version__ + pt_cuda_available = torch.cuda.is_available() + + transformers_version = "not installed" + if is_transformers_available: + import transformers + + transformers_version = transformers.__version__ + + info = { + "`diffusers` version": version, + "Platform": platform.platform(), + "Python version": platform.python_version(), + "PyTorch version (GPU?)": f"{pt_version} ({pt_cuda_available})", + "Huggingface_hub version": hub_version, + "Transformers version": transformers_version, + "Using GPU in script?": "", + "Using distributed or parallel set-up in script?": "", + } + + print("\nCopy-and-paste the text below in your GitHub issue and FILL OUT the two last points.\n") + print(self.format_dict(info)) + + return info + + @staticmethod + def format_dict(d): + return "\n".join([f"- {prop}: {val}" for prop, val in d.items()]) + "\n" diff --git a/local_diffusers/configuration_utils.py b/local_diffusers/configuration_utils.py new file mode 100644 index 000000000..fbe75f3f1 --- /dev/null +++ b/local_diffusers/configuration_utils.py @@ -0,0 +1,403 @@ +# coding=utf-8 +# Copyright 2022 The HuggingFace Inc. team. +# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" ConfigMixinuration base class and utilities.""" +import functools +import inspect +import json +import os +import re +from collections import OrderedDict +from typing import Any, Dict, Tuple, Union + +from huggingface_hub import hf_hub_download +from huggingface_hub.utils import EntryNotFoundError, RepositoryNotFoundError, RevisionNotFoundError +from requests import HTTPError + +from . import __version__ +from .utils import DIFFUSERS_CACHE, HUGGINGFACE_CO_RESOLVE_ENDPOINT, logging + + +logger = logging.get_logger(__name__) + +_re_configuration_file = re.compile(r"config\.(.*)\.json") + + +class ConfigMixin: + r""" + Base class for all configuration classes. 
Stores all configuration parameters under `self.config` Also handles all + methods for loading/downloading/saving classes inheriting from [`ConfigMixin`] with + - [`~ConfigMixin.from_config`] + - [`~ConfigMixin.save_config`] + + Class attributes: + - **config_name** (`str`) -- A filename under which the config should stored when calling + [`~ConfigMixin.save_config`] (should be overriden by parent class). + - **ignore_for_config** (`List[str]`) -- A list of attributes that should not be saved in the config (should be + overriden by parent class). + """ + config_name = None + ignore_for_config = [] + + def register_to_config(self, **kwargs): + if self.config_name is None: + raise NotImplementedError(f"Make sure that {self.__class__} has defined a class name `config_name`") + kwargs["_class_name"] = self.__class__.__name__ + kwargs["_diffusers_version"] = __version__ + + for key, value in kwargs.items(): + try: + setattr(self, key, value) + except AttributeError as err: + logger.error(f"Can't set {key} with value {value} for {self}") + raise err + + if not hasattr(self, "_internal_dict"): + internal_dict = kwargs + else: + previous_dict = dict(self._internal_dict) + internal_dict = {**self._internal_dict, **kwargs} + logger.debug(f"Updating config from {previous_dict} to {internal_dict}") + + self._internal_dict = FrozenDict(internal_dict) + + def save_config(self, save_directory: Union[str, os.PathLike], push_to_hub: bool = False, **kwargs): + """ + Save a configuration object to the directory `save_directory`, so that it can be re-loaded using the + [`~ConfigMixin.from_config`] class method. + + Args: + save_directory (`str` or `os.PathLike`): + Directory where the configuration JSON file will be saved (will be created if it does not exist). + """ + if os.path.isfile(save_directory): + raise AssertionError(f"Provided path ({save_directory}) should be a directory, not a file") + + os.makedirs(save_directory, exist_ok=True) + + # If we save using the predefined names, we can load using `from_config` + output_config_file = os.path.join(save_directory, self.config_name) + + self.to_json_file(output_config_file) + logger.info(f"ConfigMixinuration saved in {output_config_file}") + + @classmethod + def from_config(cls, pretrained_model_name_or_path: Union[str, os.PathLike], return_unused_kwargs=False, **kwargs): + r""" + Instantiate a Python class from a pre-defined JSON-file. + + Parameters: + pretrained_model_name_or_path (`str` or `os.PathLike`, *optional*): + Can be either: + + - A string, the *model id* of a model repo on huggingface.co. Valid model ids should have an + organization name, like `google/ddpm-celebahq-256`. + - A path to a *directory* containing model weights saved using [`~ConfigMixin.save_config`], e.g., + `./my_model_directory/`. + + cache_dir (`Union[str, os.PathLike]`, *optional*): + Path to a directory in which a downloaded pretrained model configuration should be cached if the + standard cache should not be used. + ignore_mismatched_sizes (`bool`, *optional*, defaults to `False`): + Whether or not to raise an error if some of the weights from the checkpoint do not have the same size + as the weights of the model (if for instance, you are instantiating a model with 10 labels from a + checkpoint with 3 labels). + force_download (`bool`, *optional*, defaults to `False`): + Whether or not to force the (re-)download of the model weights and configuration files, overriding the + cached versions if they exist. 
+ resume_download (`bool`, *optional*, defaults to `False`): + Whether or not to delete incompletely received files. Will attempt to resume the download if such a + file exists. + proxies (`Dict[str, str]`, *optional*): + A dictionary of proxy servers to use by protocol or endpoint, e.g., `{'http': 'foo.bar:3128', + 'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request. + output_loading_info(`bool`, *optional*, defaults to `False`): + Whether ot not to also return a dictionary containing missing keys, unexpected keys and error messages. + local_files_only(`bool`, *optional*, defaults to `False`): + Whether or not to only look at local files (i.e., do not try to download the model). + use_auth_token (`str` or *bool*, *optional*): + The token to use as HTTP bearer authorization for remote files. If `True`, will use the token generated + when running `transformers-cli login` (stored in `~/.huggingface`). + revision (`str`, *optional*, defaults to `"main"`): + The specific model version to use. It can be a branch name, a tag name, or a commit id, since we use a + git-based system for storing models and other artifacts on huggingface.co, so `revision` can be any + identifier allowed by git. + mirror (`str`, *optional*): + Mirror source to accelerate downloads in China. If you are from China and have an accessibility + problem, you can set this option to resolve it. Note that we do not guarantee the timeliness or safety. + Please refer to the mirror site for more information. + + + + Passing `use_auth_token=True`` is required when you want to use a private model. + + + + + + Activate the special ["offline-mode"](https://huggingface.co/transformers/installation.html#offline-mode) to + use this method in a firewalled environment. + + + + """ + config_dict = cls.get_config_dict(pretrained_model_name_or_path=pretrained_model_name_or_path, **kwargs) + + init_dict, unused_kwargs = cls.extract_init_dict(config_dict, **kwargs) + + model = cls(**init_dict) + + if return_unused_kwargs: + return model, unused_kwargs + else: + return model + + @classmethod + def get_config_dict( + cls, pretrained_model_name_or_path: Union[str, os.PathLike], **kwargs + ) -> Tuple[Dict[str, Any], Dict[str, Any]]: + cache_dir = kwargs.pop("cache_dir", DIFFUSERS_CACHE) + force_download = kwargs.pop("force_download", False) + resume_download = kwargs.pop("resume_download", False) + proxies = kwargs.pop("proxies", None) + use_auth_token = kwargs.pop("use_auth_token", None) + local_files_only = kwargs.pop("local_files_only", False) + revision = kwargs.pop("revision", None) + subfolder = kwargs.pop("subfolder", None) + + user_agent = {"file_type": "config"} + + pretrained_model_name_or_path = str(pretrained_model_name_or_path) + + if cls.config_name is None: + raise ValueError( + "`self.config_name` is not defined. Note that one should not load a config from " + "`ConfigMixin`. 
Please make sure to define `config_name` in a class inheriting from `ConfigMixin`" + ) + + if os.path.isfile(pretrained_model_name_or_path): + config_file = pretrained_model_name_or_path + elif os.path.isdir(pretrained_model_name_or_path): + if os.path.isfile(os.path.join(pretrained_model_name_or_path, cls.config_name)): + # Load from a PyTorch checkpoint + config_file = os.path.join(pretrained_model_name_or_path, cls.config_name) + elif subfolder is not None and os.path.isfile( + os.path.join(pretrained_model_name_or_path, subfolder, cls.config_name) + ): + config_file = os.path.join(pretrained_model_name_or_path, subfolder, cls.config_name) + else: + raise EnvironmentError( + f"Error no file named {cls.config_name} found in directory {pretrained_model_name_or_path}." + ) + else: + try: + # Load from URL or cache if already cached + config_file = hf_hub_download( + pretrained_model_name_or_path, + filename=cls.config_name, + cache_dir=cache_dir, + force_download=force_download, + proxies=proxies, + resume_download=resume_download, + local_files_only=local_files_only, + use_auth_token=use_auth_token, + user_agent=user_agent, + subfolder=subfolder, + revision=revision, + ) + + except RepositoryNotFoundError: + raise EnvironmentError( + f"{pretrained_model_name_or_path} is not a local folder and is not a valid model identifier" + " listed on 'https://huggingface.co/models'\nIf this is a private repository, make sure to pass a" + " token having permission to this repo with `use_auth_token` or log in with `huggingface-cli" + " login` and pass `use_auth_token=True`." + ) + except RevisionNotFoundError: + raise EnvironmentError( + f"{revision} is not a valid git identifier (branch name, tag name or commit id) that exists for" + " this model name. Check the model page at" + f" 'https://huggingface.co/{pretrained_model_name_or_path}' for available revisions." + ) + except EntryNotFoundError: + raise EnvironmentError( + f"{pretrained_model_name_or_path} does not appear to have a file named {cls.config_name}." + ) + except HTTPError as err: + raise EnvironmentError( + "There was a specific connection error when trying to load" + f" {pretrained_model_name_or_path}:\n{err}" + ) + except ValueError: + raise EnvironmentError( + f"We couldn't connect to '{HUGGINGFACE_CO_RESOLVE_ENDPOINT}' to load this model, couldn't find it" + f" in the cached files and it looks like {pretrained_model_name_or_path} is not the path to a" + f" directory containing a {cls.config_name} file.\nCheckout your internet connection or see how to" + " run the library in offline mode at" + " 'https://huggingface.co/docs/diffusers/installation#offline-mode'." + ) + except EnvironmentError: + raise EnvironmentError( + f"Can't load config for '{pretrained_model_name_or_path}'. If you were trying to load it from " + "'https://huggingface.co/models', make sure you don't have a local directory with the same name. 
" + f"Otherwise, make sure '{pretrained_model_name_or_path}' is the correct path to a directory " + f"containing a {cls.config_name} file" + ) + + try: + # Load config dict + config_dict = cls._dict_from_json_file(config_file) + except (json.JSONDecodeError, UnicodeDecodeError): + raise EnvironmentError(f"It looks like the config file at '{config_file}' is not a valid JSON file.") + + return config_dict + + @classmethod + def extract_init_dict(cls, config_dict, **kwargs): + expected_keys = set(dict(inspect.signature(cls.__init__).parameters).keys()) + expected_keys.remove("self") + # remove general kwargs if present in dict + if "kwargs" in expected_keys: + expected_keys.remove("kwargs") + # remove keys to be ignored + if len(cls.ignore_for_config) > 0: + expected_keys = expected_keys - set(cls.ignore_for_config) + init_dict = {} + for key in expected_keys: + if key in kwargs: + # overwrite key + init_dict[key] = kwargs.pop(key) + elif key in config_dict: + # use value from config dict + init_dict[key] = config_dict.pop(key) + + unused_kwargs = config_dict.update(kwargs) + + passed_keys = set(init_dict.keys()) + if len(expected_keys - passed_keys) > 0: + logger.warning( + f"{expected_keys - passed_keys} was not found in config. Values will be initialized to default values." + ) + + return init_dict, unused_kwargs + + @classmethod + def _dict_from_json_file(cls, json_file: Union[str, os.PathLike]): + with open(json_file, "r", encoding="utf-8") as reader: + text = reader.read() + return json.loads(text) + + def __repr__(self): + return f"{self.__class__.__name__} {self.to_json_string()}" + + @property + def config(self) -> Dict[str, Any]: + return self._internal_dict + + def to_json_string(self) -> str: + """ + Serializes this instance to a JSON string. + + Returns: + `str`: String containing all the attributes that make up this configuration instance in JSON format. + """ + config_dict = self._internal_dict if hasattr(self, "_internal_dict") else {} + return json.dumps(config_dict, indent=2, sort_keys=True) + "\n" + + def to_json_file(self, json_file_path: Union[str, os.PathLike]): + """ + Save this instance to a JSON file. + + Args: + json_file_path (`str` or `os.PathLike`): + Path to the JSON file in which this configuration instance's parameters will be saved. 
+ """ + with open(json_file_path, "w", encoding="utf-8") as writer: + writer.write(self.to_json_string()) + + +class FrozenDict(OrderedDict): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + for key, value in self.items(): + setattr(self, key, value) + + self.__frozen = True + + def __delitem__(self, *args, **kwargs): + raise Exception(f"You cannot use ``__delitem__`` on a {self.__class__.__name__} instance.") + + def setdefault(self, *args, **kwargs): + raise Exception(f"You cannot use ``setdefault`` on a {self.__class__.__name__} instance.") + + def pop(self, *args, **kwargs): + raise Exception(f"You cannot use ``pop`` on a {self.__class__.__name__} instance.") + + def update(self, *args, **kwargs): + raise Exception(f"You cannot use ``update`` on a {self.__class__.__name__} instance.") + + def __setattr__(self, name, value): + if hasattr(self, "__frozen") and self.__frozen: + raise Exception(f"You cannot use ``__setattr__`` on a {self.__class__.__name__} instance.") + super().__setattr__(name, value) + + def __setitem__(self, name, value): + if hasattr(self, "__frozen") and self.__frozen: + raise Exception(f"You cannot use ``__setattr__`` on a {self.__class__.__name__} instance.") + super().__setitem__(name, value) + + +def register_to_config(init): + r""" + Decorator to apply on the init of classes inheriting from [`ConfigMixin`] so that all the arguments are + automatically sent to `self.register_for_config`. To ignore a specific argument accepted by the init but that + shouldn't be registered in the config, use the `ignore_for_config` class variable + + Warning: Once decorated, all private arguments (beginning with an underscore) are trashed and not sent to the init! + """ + + @functools.wraps(init) + def inner_init(self, *args, **kwargs): + # Ignore private kwargs in the init. + init_kwargs = {k: v for k, v in kwargs.items() if not k.startswith("_")} + init(self, *args, **init_kwargs) + if not isinstance(self, ConfigMixin): + raise RuntimeError( + f"`@register_for_config` was applied to {self.__class__.__name__} init method, but this class does " + "not inherit from `ConfigMixin`." + ) + + ignore = getattr(self, "ignore_for_config", []) + # Get positional arguments aligned with kwargs + new_kwargs = {} + signature = inspect.signature(init) + parameters = { + name: p.default for i, (name, p) in enumerate(signature.parameters.items()) if i > 0 and name not in ignore + } + for arg, name in zip(args, parameters.keys()): + new_kwargs[name] = arg + + # Then add all kwargs + new_kwargs.update( + { + k: init_kwargs.get(k, default) + for k, default in parameters.items() + if k not in ignore and k not in new_kwargs + } + ) + getattr(self, "register_to_config")(**new_kwargs) + + return inner_init diff --git a/local_diffusers/dependency_versions_check.py b/local_diffusers/dependency_versions_check.py new file mode 100644 index 000000000..bbf863222 --- /dev/null +++ b/local_diffusers/dependency_versions_check.py @@ -0,0 +1,47 @@ +# Copyright 2020 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. +import sys + +from .dependency_versions_table import deps +from .utils.versions import require_version, require_version_core + + +# define which module versions we always want to check at run time +# (usually the ones defined in `install_requires` in setup.py) +# +# order specific notes: +# - tqdm must be checked before tokenizers + +pkgs_to_check_at_runtime = "python tqdm regex requests packaging filelock numpy tokenizers".split() +if sys.version_info < (3, 7): + pkgs_to_check_at_runtime.append("dataclasses") +if sys.version_info < (3, 8): + pkgs_to_check_at_runtime.append("importlib_metadata") + +for pkg in pkgs_to_check_at_runtime: + if pkg in deps: + if pkg == "tokenizers": + # must be loaded here, or else tqdm check may fail + from .utils import is_tokenizers_available + + if not is_tokenizers_available(): + continue # not required, check version only if installed + + require_version_core(deps[pkg]) + else: + raise ValueError(f"can't find {pkg} in {deps.keys()}, check dependency_versions_table.py") + + +def dep_version_check(pkg, hint=None): + require_version(deps[pkg], hint) diff --git a/local_diffusers/dependency_versions_table.py b/local_diffusers/dependency_versions_table.py new file mode 100644 index 000000000..74c5331e5 --- /dev/null +++ b/local_diffusers/dependency_versions_table.py @@ -0,0 +1,26 @@ +# THIS FILE HAS BEEN AUTOGENERATED. To update: +# 1. modify the `_deps` dict in setup.py +# 2. run `make deps_table_update`` +deps = { + "Pillow": "Pillow", + "accelerate": "accelerate>=0.11.0", + "black": "black==22.3", + "datasets": "datasets", + "filelock": "filelock", + "flake8": "flake8>=3.8.3", + "hf-doc-builder": "hf-doc-builder>=0.3.0", + "huggingface-hub": "huggingface-hub>=0.8.1", + "importlib_metadata": "importlib_metadata", + "isort": "isort>=5.5.4", + "modelcards": "modelcards==0.1.4", + "numpy": "numpy", + "pytest": "pytest", + "pytest-timeout": "pytest-timeout", + "pytest-xdist": "pytest-xdist", + "scipy": "scipy", + "regex": "regex!=2019.12.17", + "requests": "requests", + "tensorboard": "tensorboard", + "torch": "torch>=1.4", + "transformers": "transformers>=4.21.0", +} diff --git a/local_diffusers/dynamic_modules_utils.py b/local_diffusers/dynamic_modules_utils.py new file mode 100644 index 000000000..0ebf916e7 --- /dev/null +++ b/local_diffusers/dynamic_modules_utils.py @@ -0,0 +1,335 @@ +# coding=utf-8 +# Copyright 2021 The HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
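Putting the `configuration_utils.py` pieces above together: `@register_to_config` records the init arguments, the resulting `config` is a read-only `FrozenDict`, and `save_config`/`from_config` round-trip everything through a JSON file. A minimal sketch with an illustrative `TinyScheduler` class, assuming the vendored `local_diffusers` package is importable:

```python
# Illustrative: `TinyScheduler` is not part of the patch.
import tempfile

from local_diffusers.configuration_utils import ConfigMixin, register_to_config


class TinyScheduler(ConfigMixin):
    config_name = "scheduler_config.json"

    @register_to_config
    def __init__(self, num_train_timesteps: int = 1000, beta_start: float = 1e-4):
        self.num_train_timesteps = num_train_timesteps
        self.beta_start = beta_start


scheduler = TinyScheduler(num_train_timesteps=50)
print(scheduler.config.num_train_timesteps)   # 50, read from the FrozenDict

try:
    scheduler.config["beta_start"] = 0.1      # FrozenDict refuses in-place mutation
except Exception as err:
    print(err)

with tempfile.TemporaryDirectory() as tmp:
    scheduler.save_config(tmp)                 # writes scheduler_config.json
    reloaded = TinyScheduler.from_config(tmp)  # rebuilds the object from the JSON file
    print(reloaded.config.beta_start)          # 0.0001
```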
+"""Utilities to dynamically load objects from the Hub.""" + +import importlib +import os +import re +import shutil +import sys +from pathlib import Path +from typing import Dict, Optional, Union + +from huggingface_hub import cached_download + +from .utils import DIFFUSERS_DYNAMIC_MODULE_NAME, HF_MODULES_CACHE, logging + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + +def init_hf_modules(): + """ + Creates the cache directory for modules with an init, and adds it to the Python path. + """ + # This function has already been executed if HF_MODULES_CACHE already is in the Python path. + if HF_MODULES_CACHE in sys.path: + return + + sys.path.append(HF_MODULES_CACHE) + os.makedirs(HF_MODULES_CACHE, exist_ok=True) + init_path = Path(HF_MODULES_CACHE) / "__init__.py" + if not init_path.exists(): + init_path.touch() + + +def create_dynamic_module(name: Union[str, os.PathLike]): + """ + Creates a dynamic module in the cache directory for modules. + """ + init_hf_modules() + dynamic_module_path = Path(HF_MODULES_CACHE) / name + # If the parent module does not exist yet, recursively create it. + if not dynamic_module_path.parent.exists(): + create_dynamic_module(dynamic_module_path.parent) + os.makedirs(dynamic_module_path, exist_ok=True) + init_path = dynamic_module_path / "__init__.py" + if not init_path.exists(): + init_path.touch() + + +def get_relative_imports(module_file): + """ + Get the list of modules that are relatively imported in a module file. + + Args: + module_file (`str` or `os.PathLike`): The module file to inspect. + """ + with open(module_file, "r", encoding="utf-8") as f: + content = f.read() + + # Imports of the form `import .xxx` + relative_imports = re.findall("^\s*import\s+\.(\S+)\s*$", content, flags=re.MULTILINE) + # Imports of the form `from .xxx import yyy` + relative_imports += re.findall("^\s*from\s+\.(\S+)\s+import", content, flags=re.MULTILINE) + # Unique-ify + return list(set(relative_imports)) + + +def get_relative_import_files(module_file): + """ + Get the list of all files that are needed for a given module. Note that this function recurses through the relative + imports (if a imports b and b imports c, it will return module files for b and c). + + Args: + module_file (`str` or `os.PathLike`): The module file to inspect. + """ + no_change = False + files_to_check = [module_file] + all_relative_imports = [] + + # Let's recurse through all relative imports + while not no_change: + new_imports = [] + for f in files_to_check: + new_imports.extend(get_relative_imports(f)) + + module_path = Path(module_file).parent + new_import_files = [str(module_path / m) for m in new_imports] + new_import_files = [f for f in new_import_files if f not in all_relative_imports] + files_to_check = [f"{f}.py" for f in new_import_files] + + no_change = len(new_import_files) == 0 + all_relative_imports.extend(files_to_check) + + return all_relative_imports + + +def check_imports(filename): + """ + Check if the current Python environment contains all the libraries that are imported in a file. 
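The relative-import scanner above works purely on the file text with two regular expressions, so it can be exercised on a throwaway file without importing anything. A minimal sketch, with the same import assumption as the earlier sketches:

```python
# Illustrative: the file below is only scanned as text, never imported.
import tempfile

from local_diffusers.dynamic_modules_utils import get_relative_imports

source = (
    "from .scheduling_utils import SchedulerMixin\n"
    "from .configuration_utils import ConfigMixin\n"
    "import torch\n"
)
with tempfile.NamedTemporaryFile("w", suffix=".py", delete=False) as f:
    f.write(source)

print(sorted(get_relative_imports(f.name)))
# ['configuration_utils', 'scheduling_utils']; absolute imports such as torch are ignored
```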
+ """ + with open(filename, "r", encoding="utf-8") as f: + content = f.read() + + # Imports of the form `import xxx` + imports = re.findall("^\s*import\s+(\S+)\s*$", content, flags=re.MULTILINE) + # Imports of the form `from xxx import yyy` + imports += re.findall("^\s*from\s+(\S+)\s+import", content, flags=re.MULTILINE) + # Only keep the top-level module + imports = [imp.split(".")[0] for imp in imports if not imp.startswith(".")] + + # Unique-ify and test we got them all + imports = list(set(imports)) + missing_packages = [] + for imp in imports: + try: + importlib.import_module(imp) + except ImportError: + missing_packages.append(imp) + + if len(missing_packages) > 0: + raise ImportError( + "This modeling file requires the following packages that were not found in your environment: " + f"{', '.join(missing_packages)}. Run `pip install {' '.join(missing_packages)}`" + ) + + return get_relative_imports(filename) + + +def get_class_in_module(class_name, module_path): + """ + Import a module on the cache directory for modules and extract a class from it. + """ + module_path = module_path.replace(os.path.sep, ".") + module = importlib.import_module(module_path) + return getattr(module, class_name) + + +def get_cached_module_file( + pretrained_model_name_or_path: Union[str, os.PathLike], + module_file: str, + cache_dir: Optional[Union[str, os.PathLike]] = None, + force_download: bool = False, + resume_download: bool = False, + proxies: Optional[Dict[str, str]] = None, + use_auth_token: Optional[Union[bool, str]] = None, + revision: Optional[str] = None, + local_files_only: bool = False, +): + """ + Prepares Downloads a module from a local folder or a distant repo and returns its path inside the cached + Transformers module. + + Args: + pretrained_model_name_or_path (`str` or `os.PathLike`): + This can be either: + + - a string, the *model id* of a pretrained model configuration hosted inside a model repo on + huggingface.co. Valid model ids can be located at the root-level, like `bert-base-uncased`, or namespaced + under a user or organization name, like `dbmdz/bert-base-german-cased`. + - a path to a *directory* containing a configuration file saved using the + [`~PreTrainedTokenizer.save_pretrained`] method, e.g., `./my_model_directory/`. + + module_file (`str`): + The name of the module file containing the class to look for. + cache_dir (`str` or `os.PathLike`, *optional*): + Path to a directory in which a downloaded pretrained model configuration should be cached if the standard + cache should not be used. + force_download (`bool`, *optional*, defaults to `False`): + Whether or not to force to (re-)download the configuration files and override the cached versions if they + exist. + resume_download (`bool`, *optional*, defaults to `False`): + Whether or not to delete incompletely received file. Attempts to resume the download if such a file exists. + proxies (`Dict[str, str]`, *optional*): + A dictionary of proxy servers to use by protocol or endpoint, e.g., `{'http': 'foo.bar:3128', + 'http://hostname': 'foo.bar:4012'}.` The proxies are used on each request. + use_auth_token (`str` or *bool*, *optional*): + The token to use as HTTP bearer authorization for remote files. If `True`, will use the token generated + when running `transformers-cli login` (stored in `~/.huggingface`). + revision (`str`, *optional*, defaults to `"main"`): + The specific model version to use. 
It can be a branch name, a tag name, or a commit id, since we use a + git-based system for storing models and other artifacts on huggingface.co, so `revision` can be any + identifier allowed by git. + local_files_only (`bool`, *optional*, defaults to `False`): + If `True`, will only try to load the tokenizer configuration from local files. + + + + Passing `use_auth_token=True` is required when you want to use a private model. + + + + Returns: + `str`: The path to the module inside the cache. + """ + # Download and cache module_file from the repo `pretrained_model_name_or_path` of grab it if it's a local file. + pretrained_model_name_or_path = str(pretrained_model_name_or_path) + module_file_or_url = os.path.join(pretrained_model_name_or_path, module_file) + submodule = "local" + + if os.path.isfile(module_file_or_url): + resolved_module_file = module_file_or_url + else: + try: + # Load from URL or cache if already cached + resolved_module_file = cached_download( + module_file_or_url, + cache_dir=cache_dir, + force_download=force_download, + proxies=proxies, + resume_download=resume_download, + local_files_only=local_files_only, + use_auth_token=use_auth_token, + ) + + except EnvironmentError: + logger.error(f"Could not locate the {module_file} inside {pretrained_model_name_or_path}.") + raise + + # Check we have all the requirements in our environment + modules_needed = check_imports(resolved_module_file) + + # Now we move the module inside our cached dynamic modules. + full_submodule = DIFFUSERS_DYNAMIC_MODULE_NAME + os.path.sep + submodule + create_dynamic_module(full_submodule) + submodule_path = Path(HF_MODULES_CACHE) / full_submodule + # We always copy local files (we could hash the file to see if there was a change, and give them the name of + # that hash, to only copy when there is a modification but it seems overkill for now). + # The only reason we do the copy is to avoid putting too many folders in sys.path. + shutil.copy(resolved_module_file, submodule_path / module_file) + for module_needed in modules_needed: + module_needed = f"{module_needed}.py" + shutil.copy(os.path.join(pretrained_model_name_or_path, module_needed), submodule_path / module_needed) + return os.path.join(full_submodule, module_file) + + +def get_class_from_dynamic_module( + pretrained_model_name_or_path: Union[str, os.PathLike], + module_file: str, + class_name: str, + cache_dir: Optional[Union[str, os.PathLike]] = None, + force_download: bool = False, + resume_download: bool = False, + proxies: Optional[Dict[str, str]] = None, + use_auth_token: Optional[Union[bool, str]] = None, + revision: Optional[str] = None, + local_files_only: bool = False, + **kwargs, +): + """ + Extracts a class from a module file, present in the local folder or repository of a model. + + + + Calling this function will execute the code in the module file found locally or downloaded from the Hub. It should + therefore only be called on trusted repos. + + + + Args: + pretrained_model_name_or_path (`str` or `os.PathLike`): + This can be either: + + - a string, the *model id* of a pretrained model configuration hosted inside a model repo on + huggingface.co. Valid model ids can be located at the root-level, like `bert-base-uncased`, or namespaced + under a user or organization name, like `dbmdz/bert-base-german-cased`. + - a path to a *directory* containing a configuration file saved using the + [`~PreTrainedTokenizer.save_pretrained`] method, e.g., `./my_model_directory/`. 
+ + module_file (`str`): + The name of the module file containing the class to look for. + class_name (`str`): + The name of the class to import in the module. + cache_dir (`str` or `os.PathLike`, *optional*): + Path to a directory in which a downloaded pretrained model configuration should be cached if the standard + cache should not be used. + force_download (`bool`, *optional*, defaults to `False`): + Whether or not to force to (re-)download the configuration files and override the cached versions if they + exist. + resume_download (`bool`, *optional*, defaults to `False`): + Whether or not to delete incompletely received file. Attempts to resume the download if such a file exists. + proxies (`Dict[str, str]`, *optional*): + A dictionary of proxy servers to use by protocol or endpoint, e.g., `{'http': 'foo.bar:3128', + 'http://hostname': 'foo.bar:4012'}.` The proxies are used on each request. + use_auth_token (`str` or `bool`, *optional*): + The token to use as HTTP bearer authorization for remote files. If `True`, will use the token generated + when running `transformers-cli login` (stored in `~/.huggingface`). + revision (`str`, *optional*, defaults to `"main"`): + The specific model version to use. It can be a branch name, a tag name, or a commit id, since we use a + git-based system for storing models and other artifacts on huggingface.co, so `revision` can be any + identifier allowed by git. + local_files_only (`bool`, *optional*, defaults to `False`): + If `True`, will only try to load the tokenizer configuration from local files. + + + + Passing `use_auth_token=True` is required when you want to use a private model. + + + + Returns: + `type`: The class, dynamically imported from the module. + + Examples: + + ```python + # Download module `modeling.py` from huggingface.co and cache then extract the class `MyBertModel` from this + # module. + cls = get_class_from_dynamic_module("sgugger/my-bert-model", "modeling.py", "MyBertModel") + ```""" + # And lastly we get the class inside our newly created module + final_module = get_cached_module_file( + pretrained_model_name_or_path, + module_file, + cache_dir=cache_dir, + force_download=force_download, + resume_download=resume_download, + proxies=proxies, + use_auth_token=use_auth_token, + revision=revision, + local_files_only=local_files_only, + ) + return get_class_in_module(class_name, final_module.replace(".py", "")) diff --git a/local_diffusers/hub_utils.py b/local_diffusers/hub_utils.py new file mode 100644 index 000000000..c07329e36 --- /dev/null +++ b/local_diffusers/hub_utils.py @@ -0,0 +1,197 @@ +# coding=utf-8 +# Copyright 2022 The HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
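`get_class_from_dynamic_module` also accepts a plain local directory: the module file is checked for missing imports, copied into the dynamic-module cache under `HF_MODULES_CACHE`, and imported from there. A minimal sketch with hypothetical file and class names, assuming the vendored `local_diffusers` package is importable:

```python
# Illustrative only: "./my_pipeline_dir/my_pipeline.py" and `MyPipeline` are
# hypothetical names, not part of the patch.
from local_diffusers.dynamic_modules_utils import get_class_from_dynamic_module

MyPipeline = get_class_from_dynamic_module(
    "./my_pipeline_dir",            # local folder containing the module file
    module_file="my_pipeline.py",   # copied into the dynamic-module cache
    class_name="MyPipeline",        # attribute looked up on the imported module
)
pipeline = MyPipeline()
```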
+ + +import os +import shutil +from pathlib import Path +from typing import Optional + +from huggingface_hub import HfFolder, Repository, whoami + +from .pipeline_utils import DiffusionPipeline +from .utils import is_modelcards_available, logging + + +if is_modelcards_available(): + from modelcards import CardData, ModelCard + + +logger = logging.get_logger(__name__) + + +MODEL_CARD_TEMPLATE_PATH = Path(__file__).parent / "utils" / "model_card_template.md" + + +def get_full_repo_name(model_id: str, organization: Optional[str] = None, token: Optional[str] = None): + if token is None: + token = HfFolder.get_token() + if organization is None: + username = whoami(token)["name"] + return f"{username}/{model_id}" + else: + return f"{organization}/{model_id}" + + +def init_git_repo(args, at_init: bool = False): + """ + Args: + Initializes a git repo in `args.hub_model_id`. + at_init (`bool`, *optional*, defaults to `False`): + Whether this function is called before any training or not. If `self.args.overwrite_output_dir` is `True` + and `at_init` is `True`, the path to the repo (which is `self.args.output_dir`) might be wiped out. + """ + if hasattr(args, "local_rank") and args.local_rank not in [-1, 0]: + return + hub_token = args.hub_token if hasattr(args, "hub_token") else None + use_auth_token = True if hub_token is None else hub_token + if not hasattr(args, "hub_model_id") or args.hub_model_id is None: + repo_name = Path(args.output_dir).absolute().name + else: + repo_name = args.hub_model_id + if "/" not in repo_name: + repo_name = get_full_repo_name(repo_name, token=hub_token) + + try: + repo = Repository( + args.output_dir, + clone_from=repo_name, + use_auth_token=use_auth_token, + private=args.hub_private_repo, + ) + except EnvironmentError: + if args.overwrite_output_dir and at_init: + # Try again after wiping output_dir + shutil.rmtree(args.output_dir) + repo = Repository( + args.output_dir, + clone_from=repo_name, + use_auth_token=use_auth_token, + ) + else: + raise + + repo.git_pull() + + # By default, ignore the checkpoint folders + if not os.path.exists(os.path.join(args.output_dir, ".gitignore")): + with open(os.path.join(args.output_dir, ".gitignore"), "w", encoding="utf-8") as writer: + writer.writelines(["checkpoint-*/"]) + + return repo + + +def push_to_hub( + args, + pipeline: DiffusionPipeline, + repo: Repository, + commit_message: Optional[str] = "End of training", + blocking: bool = True, + **kwargs, +) -> str: + """ + Parameters: + Upload *self.model* and *self.tokenizer* to the 🤗 model hub on the repo *self.args.hub_model_id*. + commit_message (`str`, *optional*, defaults to `"End of training"`): + Message to commit while pushing. + blocking (`bool`, *optional*, defaults to `True`): + Whether the function should return only when the `git push` has finished. + kwargs: + Additional keyword arguments passed along to [`create_model_card`]. + Returns: + The url of the commit of your model in the given repository if `blocking=False`, a tuple with the url of the + commit and an object to track the progress of the commit if `blocking=True` + """ + + if not hasattr(args, "hub_model_id") or args.hub_model_id is None: + model_name = Path(args.output_dir).name + else: + model_name = args.hub_model_id.split("/")[-1] + + output_dir = args.output_dir + os.makedirs(output_dir, exist_ok=True) + logger.info(f"Saving pipeline checkpoint to {output_dir}") + pipeline.save_pretrained(output_dir) + + # Only push from one node. 
+ if hasattr(args, "local_rank") and args.local_rank not in [-1, 0]: + return + + # Cancel any async push in progress if blocking=True. The commits will all be pushed together. + if ( + blocking + and len(repo.command_queue) > 0 + and repo.command_queue[-1] is not None + and not repo.command_queue[-1].is_done + ): + repo.command_queue[-1]._process.kill() + + git_head_commit_url = repo.push_to_hub(commit_message=commit_message, blocking=blocking, auto_lfs_prune=True) + # push separately the model card to be independent from the rest of the model + create_model_card(args, model_name=model_name) + try: + repo.push_to_hub(commit_message="update model card README.md", blocking=blocking, auto_lfs_prune=True) + except EnvironmentError as exc: + logger.error(f"Error pushing update to the model card. Please read logs and retry.\n${exc}") + + return git_head_commit_url + + +def create_model_card(args, model_name): + if not is_modelcards_available: + raise ValueError( + "Please make sure to have `modelcards` installed when using the `create_model_card` function. You can" + " install the package with `pip install modelcards`." + ) + + if hasattr(args, "local_rank") and args.local_rank not in [-1, 0]: + return + + hub_token = args.hub_token if hasattr(args, "hub_token") else None + repo_name = get_full_repo_name(model_name, token=hub_token) + + model_card = ModelCard.from_template( + card_data=CardData( # Card metadata object that will be converted to YAML block + language="en", + license="apache-2.0", + library_name="diffusers", + tags=[], + datasets=args.dataset_name, + metrics=[], + ), + template_path=MODEL_CARD_TEMPLATE_PATH, + model_name=model_name, + repo_name=repo_name, + dataset_name=args.dataset_name if hasattr(args, "dataset_name") else None, + learning_rate=args.learning_rate, + train_batch_size=args.train_batch_size, + eval_batch_size=args.eval_batch_size, + gradient_accumulation_steps=args.gradient_accumulation_steps + if hasattr(args, "gradient_accumulation_steps") + else None, + adam_beta1=args.adam_beta1 if hasattr(args, "adam_beta1") else None, + adam_beta2=args.adam_beta2 if hasattr(args, "adam_beta2") else None, + adam_weight_decay=args.adam_weight_decay if hasattr(args, "adam_weight_decay") else None, + adam_epsilon=args.adam_epsilon if hasattr(args, "adam_epsilon") else None, + lr_scheduler=args.lr_scheduler if hasattr(args, "lr_scheduler") else None, + lr_warmup_steps=args.lr_warmup_steps if hasattr(args, "lr_warmup_steps") else None, + ema_inv_gamma=args.ema_inv_gamma if hasattr(args, "ema_inv_gamma") else None, + ema_power=args.ema_power if hasattr(args, "ema_power") else None, + ema_max_decay=args.ema_max_decay if hasattr(args, "ema_max_decay") else None, + mixed_precision=args.mixed_precision, + ) + + card_path = os.path.join(args.output_dir, "README.md") + model_card.save(card_path) diff --git a/local_diffusers/modeling_utils.py b/local_diffusers/modeling_utils.py new file mode 100644 index 000000000..fb613614a --- /dev/null +++ b/local_diffusers/modeling_utils.py @@ -0,0 +1,542 @@ +# coding=utf-8 +# Copyright 2022 The HuggingFace Inc. team. +# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
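Of the hub helpers above, `get_full_repo_name` is the piece that is easy to check in isolation; with an explicit organization it does not need a token. A minimal sketch with placeholder names:

```python
# Illustrative: "my-diffusion-model" and "my-org" are placeholder names.
from local_diffusers.hub_utils import get_full_repo_name

print(get_full_repo_name("my-diffusion-model", organization="my-org"))
# -> my-org/my-diffusion-model

# Without `organization`, the namespace is resolved through `whoami(token)`,
# so a Hugging Face token (e.g. from `huggingface-cli login`) must be available.
```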
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os +from typing import Callable, List, Optional, Tuple, Union + +import torch +from torch import Tensor, device + +from huggingface_hub import hf_hub_download +from huggingface_hub.utils import EntryNotFoundError, RepositoryNotFoundError, RevisionNotFoundError +from requests import HTTPError + +from .utils import CONFIG_NAME, DIFFUSERS_CACHE, HUGGINGFACE_CO_RESOLVE_ENDPOINT, logging + + +WEIGHTS_NAME = "diffusion_pytorch_model.bin" + + +logger = logging.get_logger(__name__) + + +def get_parameter_device(parameter: torch.nn.Module): + try: + return next(parameter.parameters()).device + except StopIteration: + # For torch.nn.DataParallel compatibility in PyTorch 1.5 + + def find_tensor_attributes(module: torch.nn.Module) -> List[Tuple[str, Tensor]]: + tuples = [(k, v) for k, v in module.__dict__.items() if torch.is_tensor(v)] + return tuples + + gen = parameter._named_members(get_members_fn=find_tensor_attributes) + first_tuple = next(gen) + return first_tuple[1].device + + +def get_parameter_dtype(parameter: torch.nn.Module): + try: + return next(parameter.parameters()).dtype + except StopIteration: + # For torch.nn.DataParallel compatibility in PyTorch 1.5 + + def find_tensor_attributes(module: torch.nn.Module) -> List[Tuple[str, Tensor]]: + tuples = [(k, v) for k, v in module.__dict__.items() if torch.is_tensor(v)] + return tuples + + gen = parameter._named_members(get_members_fn=find_tensor_attributes) + first_tuple = next(gen) + return first_tuple[1].dtype + + +def load_state_dict(checkpoint_file: Union[str, os.PathLike]): + """ + Reads a PyTorch checkpoint file, returning properly formatted errors if they arise. + """ + try: + return torch.load(checkpoint_file, map_location="cpu") + except Exception as e: + try: + with open(checkpoint_file) as f: + if f.read().startswith("version"): + raise OSError( + "You seem to have cloned a repository without having git-lfs installed. Please install " + "git-lfs and run `git lfs install` followed by `git lfs pull` in the folder " + "you cloned." + ) + else: + raise ValueError( + f"Unable to locate the file {checkpoint_file} which is necessary to load this pretrained " + "model. Make sure you have saved the model properly." + ) from e + except (UnicodeDecodeError, ValueError): + raise OSError( + f"Unable to load weights from pytorch checkpoint file for '{checkpoint_file}' " + f"at '{checkpoint_file}'. " + "If you tried to load a PyTorch model from a TF 2.0 checkpoint, please set from_tf=True." + ) + + +def _load_state_dict_into_model(model_to_load, state_dict): + # Convert old format to new format if needed from a PyTorch state_dict + # copy state_dict so _load_from_state_dict can modify it + state_dict = state_dict.copy() + error_msgs = [] + + # PyTorch's `_load_from_state_dict` does not copy parameters in a module's descendants + # so we need to apply the function recursively. 
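The helpers above operate on plain PyTorch objects, so they can be smoke-tested without any model class from this package. A minimal sketch, assuming the vendored `local_diffusers` package is importable:

```python
# Illustrative: exercise the checkpoint helpers on a tiny torch module.
import os
import tempfile

import torch

from local_diffusers.modeling_utils import (
    WEIGHTS_NAME,
    get_parameter_device,
    get_parameter_dtype,
    load_state_dict,
)

net = torch.nn.Linear(4, 4)
print(get_parameter_device(net), get_parameter_dtype(net))   # cpu torch.float32

with tempfile.TemporaryDirectory() as tmp:
    path = os.path.join(tmp, WEIGHTS_NAME)      # "diffusion_pytorch_model.bin"
    torch.save(net.state_dict(), path)
    state_dict = load_state_dict(path)          # loads with map_location="cpu"
    print(sorted(state_dict.keys()))            # ['bias', 'weight']
```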
+ def load(module: torch.nn.Module, prefix=""): + args = (state_dict, prefix, {}, True, [], [], error_msgs) + module._load_from_state_dict(*args) + + for name, child in module._modules.items(): + if child is not None: + load(child, prefix + name + ".") + + load(model_to_load) + + return error_msgs + + +class ModelMixin(torch.nn.Module): + r""" + Base class for all models. + + [`ModelMixin`] takes care of storing the configuration of the models and handles methods for loading, downloading + and saving models. + + - **config_name** ([`str`]) -- A filename under which the model should be stored when calling + [`~modeling_utils.ModelMixin.save_pretrained`]. + """ + config_name = CONFIG_NAME + _automatically_saved_args = ["_diffusers_version", "_class_name", "_name_or_path"] + + def __init__(self): + super().__init__() + + def save_pretrained( + self, + save_directory: Union[str, os.PathLike], + is_main_process: bool = True, + save_function: Callable = torch.save, + ): + """ + Save a model and its configuration file to a directory, so that it can be re-loaded using the + `[`~modeling_utils.ModelMixin.from_pretrained`]` class method. + + Arguments: + save_directory (`str` or `os.PathLike`): + Directory to which to save. Will be created if it doesn't exist. + is_main_process (`bool`, *optional*, defaults to `True`): + Whether the process calling this is the main process or not. Useful when in distributed training like + TPUs and need to call this function on all processes. In this case, set `is_main_process=True` only on + the main process to avoid race conditions. + save_function (`Callable`): + The function to use to save the state dictionary. Useful on distributed training like TPUs when one + need to replace `torch.save` by another method. + """ + if os.path.isfile(save_directory): + logger.error(f"Provided path ({save_directory}) should be a directory, not a file") + return + + os.makedirs(save_directory, exist_ok=True) + + model_to_save = self + + # Attach architecture to the config + # Save the config + if is_main_process: + model_to_save.save_config(save_directory) + + # Save the model + state_dict = model_to_save.state_dict() + + # Clean the folder from a previous save + for filename in os.listdir(save_directory): + full_filename = os.path.join(save_directory, filename) + # If we have a shard file that is not going to be replaced, we delete it, but only from the main process + # in distributed settings to avoid race conditions. + if filename.startswith(WEIGHTS_NAME[:-4]) and os.path.isfile(full_filename) and is_main_process: + os.remove(full_filename) + + # Save the model + save_function(state_dict, os.path.join(save_directory, WEIGHTS_NAME)) + + logger.info(f"Model weights saved in {os.path.join(save_directory, WEIGHTS_NAME)}") + + @classmethod + def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.PathLike]], **kwargs): + r""" + Instantiate a pretrained pytorch model from a pre-trained model configuration. + + The model is set in evaluation mode by default using `model.eval()` (Dropout modules are deactivated). To train + the model, you should first set it back in training mode with `model.train()`. + + The warning *Weights from XXX not initialized from pretrained model* means that the weights of XXX do not come + pretrained with the rest of the model. It is up to you to train those weights with a downstream fine-tuning + task. 
+ + The warning *Weights from XXX not used in YYY* means that the layer XXX is not used by YYY, therefore those + weights are discarded. + + Parameters: + pretrained_model_name_or_path (`str` or `os.PathLike`, *optional*): + Can be either: + + - A string, the *model id* of a pretrained model hosted inside a model repo on huggingface.co. + Valid model ids should have an organization name, like `google/ddpm-celebahq-256`. + - A path to a *directory* containing model weights saved using [`~ModelMixin.save_config`], e.g., + `./my_model_directory/`. + + cache_dir (`Union[str, os.PathLike]`, *optional*): + Path to a directory in which a downloaded pretrained model configuration should be cached if the + standard cache should not be used. + torch_dtype (`str` or `torch.dtype`, *optional*): + Override the default `torch.dtype` and load the model under this dtype. If `"auto"` is passed the dtype + will be automatically derived from the model's weights. + force_download (`bool`, *optional*, defaults to `False`): + Whether or not to force the (re-)download of the model weights and configuration files, overriding the + cached versions if they exist. + resume_download (`bool`, *optional*, defaults to `False`): + Whether or not to delete incompletely received files. Will attempt to resume the download if such a + file exists. + proxies (`Dict[str, str]`, *optional*): + A dictionary of proxy servers to use by protocol or endpoint, e.g., `{'http': 'foo.bar:3128', + 'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request. + output_loading_info(`bool`, *optional*, defaults to `False`): + Whether ot not to also return a dictionary containing missing keys, unexpected keys and error messages. + local_files_only(`bool`, *optional*, defaults to `False`): + Whether or not to only look at local files (i.e., do not try to download the model). + use_auth_token (`str` or *bool*, *optional*): + The token to use as HTTP bearer authorization for remote files. If `True`, will use the token generated + when running `diffusers-cli login` (stored in `~/.huggingface`). + revision (`str`, *optional*, defaults to `"main"`): + The specific model version to use. It can be a branch name, a tag name, or a commit id, since we use a + git-based system for storing models and other artifacts on huggingface.co, so `revision` can be any + identifier allowed by git. + mirror (`str`, *optional*): + Mirror source to accelerate downloads in China. If you are from China and have an accessibility + problem, you can set this option to resolve it. Note that we do not guarantee the timeliness or safety. + Please refer to the mirror site for more information. + + + + Passing `use_auth_token=True`` is required when you want to use a private model. + + + + + + Activate the special ["offline-mode"](https://huggingface.co/diffusers/installation.html#offline-mode) to use + this method in a firewalled environment. 
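A hedged end-to-end sketch of the keyword arguments documented above, using `UNet2DModel` (one of the model classes vendored later in this patch) and the `google/ddpm-celebahq-256` id that the docstring cites; it needs network access and assumes that checkpoint exposes a root-level `config.json` and `diffusion_pytorch_model.bin`:

```python
# Hedged sketch only: requires network access and assumes the referenced repository
# stores root-level model files.
import torch

from local_diffusers.models import UNet2DModel

unet, loading_info = UNet2DModel.from_pretrained(
    "google/ddpm-celebahq-256",
    torch_dtype=torch.float16,       # cast the instantiated model before the weights are loaded
    output_loading_info=True,        # also return missing/unexpected/mismatched key lists
)
print(unet.dtype)                    # torch.float16
print(loading_info["missing_keys"])  # expected to be empty for a matching checkpoint
```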
+ + + + """ + cache_dir = kwargs.pop("cache_dir", DIFFUSERS_CACHE) + ignore_mismatched_sizes = kwargs.pop("ignore_mismatched_sizes", False) + force_download = kwargs.pop("force_download", False) + resume_download = kwargs.pop("resume_download", False) + proxies = kwargs.pop("proxies", None) + output_loading_info = kwargs.pop("output_loading_info", False) + local_files_only = kwargs.pop("local_files_only", False) + use_auth_token = kwargs.pop("use_auth_token", None) + revision = kwargs.pop("revision", None) + from_auto_class = kwargs.pop("_from_auto", False) + torch_dtype = kwargs.pop("torch_dtype", None) + subfolder = kwargs.pop("subfolder", None) + + user_agent = {"file_type": "model", "framework": "pytorch", "from_auto_class": from_auto_class} + + # Load config if we don't provide a configuration + config_path = pretrained_model_name_or_path + model, unused_kwargs = cls.from_config( + config_path, + cache_dir=cache_dir, + return_unused_kwargs=True, + force_download=force_download, + resume_download=resume_download, + proxies=proxies, + local_files_only=local_files_only, + use_auth_token=use_auth_token, + revision=revision, + subfolder=subfolder, + **kwargs, + ) + + if torch_dtype is not None and not isinstance(torch_dtype, torch.dtype): + raise ValueError( + f"{torch_dtype} needs to be of type `torch.dtype`, e.g. `torch.float16`, but is {type(torch_dtype)}." + ) + elif torch_dtype is not None: + model = model.to(torch_dtype) + + model.register_to_config(_name_or_path=pretrained_model_name_or_path) + # This variable will flag if we're loading a sharded checkpoint. In this case the archive file is just the + # Load model + pretrained_model_name_or_path = str(pretrained_model_name_or_path) + if os.path.isdir(pretrained_model_name_or_path): + if os.path.isfile(os.path.join(pretrained_model_name_or_path, WEIGHTS_NAME)): + # Load from a PyTorch checkpoint + model_file = os.path.join(pretrained_model_name_or_path, WEIGHTS_NAME) + elif subfolder is not None and os.path.isfile( + os.path.join(pretrained_model_name_or_path, subfolder, WEIGHTS_NAME) + ): + model_file = os.path.join(pretrained_model_name_or_path, subfolder, WEIGHTS_NAME) + else: + raise EnvironmentError( + f"Error no file named {WEIGHTS_NAME} found in directory {pretrained_model_name_or_path}." + ) + else: + try: + # Load from URL or cache if already cached + model_file = hf_hub_download( + pretrained_model_name_or_path, + filename=WEIGHTS_NAME, + cache_dir=cache_dir, + force_download=force_download, + proxies=proxies, + resume_download=resume_download, + local_files_only=local_files_only, + use_auth_token=use_auth_token, + user_agent=user_agent, + subfolder=subfolder, + revision=revision, + ) + + except RepositoryNotFoundError: + raise EnvironmentError( + f"{pretrained_model_name_or_path} is not a local folder and is not a valid model identifier " + "listed on 'https://huggingface.co/models'\nIf this is a private repository, make sure to pass a " + "token having permission to this repo with `use_auth_token` or log in with `huggingface-cli " + "login` and pass `use_auth_token=True`." + ) + except RevisionNotFoundError: + raise EnvironmentError( + f"{revision} is not a valid git identifier (branch name, tag name or commit id) that exists for " + "this model name. Check the model page at " + f"'https://huggingface.co/{pretrained_model_name_or_path}' for available revisions." + ) + except EntryNotFoundError: + raise EnvironmentError( + f"{pretrained_model_name_or_path} does not appear to have a file named {WEIGHTS_NAME}." 
+ ) + except HTTPError as err: + raise EnvironmentError( + "There was a specific connection error when trying to load" + f" {pretrained_model_name_or_path}:\n{err}" + ) + except ValueError: + raise EnvironmentError( + f"We couldn't connect to '{HUGGINGFACE_CO_RESOLVE_ENDPOINT}' to load this model, couldn't find it" + f" in the cached files and it looks like {pretrained_model_name_or_path} is not the path to a" + f" directory containing a file named {WEIGHTS_NAME} or" + " \nCheckout your internet connection or see how to run the library in" + " offline mode at 'https://huggingface.co/docs/diffusers/installation#offline-mode'." + ) + except EnvironmentError: + raise EnvironmentError( + f"Can't load the model for '{pretrained_model_name_or_path}'. If you were trying to load it from " + "'https://huggingface.co/models', make sure you don't have a local directory with the same name. " + f"Otherwise, make sure '{pretrained_model_name_or_path}' is the correct path to a directory " + f"containing a file named {WEIGHTS_NAME}" + ) + + # restore default dtype + state_dict = load_state_dict(model_file) + model, missing_keys, unexpected_keys, mismatched_keys, error_msgs = cls._load_pretrained_model( + model, + state_dict, + model_file, + pretrained_model_name_or_path, + ignore_mismatched_sizes=ignore_mismatched_sizes, + ) + + # Set model in evaluation mode to deactivate DropOut modules by default + model.eval() + + if output_loading_info: + loading_info = { + "missing_keys": missing_keys, + "unexpected_keys": unexpected_keys, + "mismatched_keys": mismatched_keys, + "error_msgs": error_msgs, + } + return model, loading_info + + return model + + @classmethod + def _load_pretrained_model( + cls, + model, + state_dict, + resolved_archive_file, + pretrained_model_name_or_path, + ignore_mismatched_sizes=False, + ): + # Retrieve missing & unexpected_keys + model_state_dict = model.state_dict() + loaded_keys = [k for k in state_dict.keys()] + + expected_keys = list(model_state_dict.keys()) + + original_loaded_keys = loaded_keys + + missing_keys = list(set(expected_keys) - set(loaded_keys)) + unexpected_keys = list(set(loaded_keys) - set(expected_keys)) + + # Make sure we are able to load base models as well as derived models (with heads) + model_to_load = model + + def _find_mismatched_keys( + state_dict, + model_state_dict, + loaded_keys, + ignore_mismatched_sizes, + ): + mismatched_keys = [] + if ignore_mismatched_sizes: + for checkpoint_key in loaded_keys: + model_key = checkpoint_key + + if ( + model_key in model_state_dict + and state_dict[checkpoint_key].shape != model_state_dict[model_key].shape + ): + mismatched_keys.append( + (checkpoint_key, state_dict[checkpoint_key].shape, model_state_dict[model_key].shape) + ) + del state_dict[checkpoint_key] + return mismatched_keys + + if state_dict is not None: + # Whole checkpoint + mismatched_keys = _find_mismatched_keys( + state_dict, + model_state_dict, + original_loaded_keys, + ignore_mismatched_sizes, + ) + error_msgs = _load_state_dict_into_model(model_to_load, state_dict) + + if len(error_msgs) > 0: + error_msg = "\n\t".join(error_msgs) + if "size mismatch" in error_msg: + error_msg += ( + "\n\tYou may consider adding `ignore_mismatched_sizes=True` in the model `from_pretrained` method." 
+ ) + raise RuntimeError(f"Error(s) in loading state_dict for {model.__class__.__name__}:\n\t{error_msg}") + + if len(unexpected_keys) > 0: + logger.warning( + f"Some weights of the model checkpoint at {pretrained_model_name_or_path} were not used when" + f" initializing {model.__class__.__name__}: {unexpected_keys}\n- This IS expected if you are" + f" initializing {model.__class__.__name__} from the checkpoint of a model trained on another task" + " or with another architecture (e.g. initializing a BertForSequenceClassification model from a" + " BertForPreTraining model).\n- This IS NOT expected if you are initializing" + f" {model.__class__.__name__} from the checkpoint of a model that you expect to be exactly" + " identical (initializing a BertForSequenceClassification model from a" + " BertForSequenceClassification model)." + ) + else: + logger.info(f"All model checkpoint weights were used when initializing {model.__class__.__name__}.\n") + if len(missing_keys) > 0: + logger.warning( + f"Some weights of {model.__class__.__name__} were not initialized from the model checkpoint at" + f" {pretrained_model_name_or_path} and are newly initialized: {missing_keys}\nYou should probably" + " TRAIN this model on a down-stream task to be able to use it for predictions and inference." + ) + elif len(mismatched_keys) == 0: + logger.info( + f"All the weights of {model.__class__.__name__} were initialized from the model checkpoint at" + f" {pretrained_model_name_or_path}.\nIf your task is similar to the task the model of the" + f" checkpoint was trained on, you can already use {model.__class__.__name__} for predictions" + " without further training." + ) + if len(mismatched_keys) > 0: + mismatched_warning = "\n".join( + [ + f"- {key}: found shape {shape1} in the checkpoint and {shape2} in the model instantiated" + for key, shape1, shape2 in mismatched_keys + ] + ) + logger.warning( + f"Some weights of {model.__class__.__name__} were not initialized from the model checkpoint at" + f" {pretrained_model_name_or_path} and are newly initialized because the shapes did not" + f" match:\n{mismatched_warning}\nYou should probably TRAIN this model on a down-stream task to be" + " able to use it for predictions and inference." + ) + + return model, missing_keys, unexpected_keys, mismatched_keys, error_msgs + + @property + def device(self) -> device: + """ + `torch.device`: The device on which the module is (assuming that all the module parameters are on the same + device). + """ + return get_parameter_device(self) + + @property + def dtype(self) -> torch.dtype: + """ + `torch.dtype`: The dtype of the module (assuming that all the module parameters have the same dtype). + """ + return get_parameter_dtype(self) + + def num_parameters(self, only_trainable: bool = False, exclude_embeddings: bool = False) -> int: + """ + Get number of (optionally, trainable or non-embeddings) parameters in the module. + + Args: + only_trainable (`bool`, *optional*, defaults to `False`): + Whether or not to return only the number of trainable parameters + + exclude_embeddings (`bool`, *optional*, defaults to `False`): + Whether or not to return only the number of non-embeddings parameters + + Returns: + `int`: The number of parameters. 
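The model classes in this package combine `ModelMixin` with `ConfigMixin`, which is what makes the `save_pretrained`/`from_pretrained` round trip above work: the config JSON rebuilds the module, then the state dict is loaded into it. A minimal local sketch with an illustrative `TinyModel`, assuming the vendored `local_diffusers` package is importable:

```python
# Illustrative: `TinyModel` is not part of the patch; the real model classes in this
# package are declared the same way.
import tempfile

import torch

from local_diffusers.configuration_utils import ConfigMixin, register_to_config
from local_diffusers.modeling_utils import ModelMixin


class TinyModel(ModelMixin, ConfigMixin):
    config_name = "config.json"

    @register_to_config
    def __init__(self, hidden_size: int = 8):
        super().__init__()
        self.proj = torch.nn.Linear(hidden_size, hidden_size)

    def forward(self, x):
        return self.proj(x)


model = TinyModel(hidden_size=16)
print(model.num_parameters())           # 16 * 16 + 16 = 272

with tempfile.TemporaryDirectory() as tmp:
    model.save_pretrained(tmp)          # writes config.json and diffusion_pytorch_model.bin
    reloaded = TinyModel.from_pretrained(tmp)
    print(reloaded.config.hidden_size)  # 16
    print(reloaded.dtype)               # torch.float32
```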
+ """ + + if exclude_embeddings: + embedding_param_names = [ + f"{name}.weight" + for name, module_type in self.named_modules() + if isinstance(module_type, torch.nn.Embedding) + ] + non_embedding_parameters = [ + parameter for name, parameter in self.named_parameters() if name not in embedding_param_names + ] + return sum(p.numel() for p in non_embedding_parameters if p.requires_grad or not only_trainable) + else: + return sum(p.numel() for p in self.parameters() if p.requires_grad or not only_trainable) + + +def unwrap_model(model: torch.nn.Module) -> torch.nn.Module: + """ + Recursively unwraps a model from potential containers (as used in distributed training). + + Args: + model (`torch.nn.Module`): The model to unwrap. + """ + # since there could be multiple levels of wrapping, unwrap recursively + if hasattr(model, "module"): + return unwrap_model(model.module) + else: + return model diff --git a/local_diffusers/models/__init__.py b/local_diffusers/models/__init__.py new file mode 100644 index 000000000..e0ac5c8d5 --- /dev/null +++ b/local_diffusers/models/__init__.py @@ -0,0 +1,17 @@ +# Copyright 2022 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from .unet_2d import UNet2DModel +from .unet_2d_condition import UNet2DConditionModel +from .vae import AutoencoderKL, VQModel diff --git a/local_diffusers/models/attention.py b/local_diffusers/models/attention.py new file mode 100644 index 000000000..de9c92691 --- /dev/null +++ b/local_diffusers/models/attention.py @@ -0,0 +1,333 @@ +import math +from typing import Optional + +import torch +import torch.nn.functional as F +from torch import nn + + +class AttentionBlock(nn.Module): + """ + An attention block that allows spatial positions to attend to each other. Originally ported from here, but adapted + to the N-d case. + https://github.com/hojonathanho/diffusion/blob/1e0dceb3b3495bbe19116a5e1b3596cd0706c543/diffusion_tf/models/unet.py#L66. + Uses three q, k, v linear layers to compute attention. + + Parameters: + channels (:obj:`int`): The number of channels in the input and output. + num_head_channels (:obj:`int`, *optional*): + The number of channels in each head. If None, then `num_heads` = 1. + num_groups (:obj:`int`, *optional*, defaults to 32): The number of groups to use for group norm. + rescale_output_factor (:obj:`float`, *optional*, defaults to 1.0): The factor to rescale the output by. + eps (:obj:`float`, *optional*, defaults to 1e-5): The epsilon value to use for group norm. 
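+
+    Example (illustrative; channel and spatial sizes below are arbitrary):
+
+    ```py
+    import torch
+
+    attn = AttentionBlock(channels=64, num_head_channels=32)
+    hidden_states = torch.randn(1, 64, 32, 32)
+    out = attn(hidden_states)  # same shape as the input: (1, 64, 32, 32)
+    ```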
+ """ + + def __init__( + self, + channels: int, + num_head_channels: Optional[int] = None, + num_groups: int = 32, + rescale_output_factor: float = 1.0, + eps: float = 1e-5, + ): + super().__init__() + self.channels = channels + + self.num_heads = channels // num_head_channels if num_head_channels is not None else 1 + self.num_head_size = num_head_channels + self.group_norm = nn.GroupNorm(num_channels=channels, num_groups=num_groups, eps=eps, affine=True) + + # define q,k,v as linear layers + self.query = nn.Linear(channels, channels) + self.key = nn.Linear(channels, channels) + self.value = nn.Linear(channels, channels) + + self.rescale_output_factor = rescale_output_factor + self.proj_attn = nn.Linear(channels, channels, 1) + + def transpose_for_scores(self, projection: torch.Tensor) -> torch.Tensor: + new_projection_shape = projection.size()[:-1] + (self.num_heads, -1) + # move heads to 2nd position (B, T, H * D) -> (B, T, H, D) -> (B, H, T, D) + new_projection = projection.view(new_projection_shape).permute(0, 2, 1, 3) + return new_projection + + def forward(self, hidden_states): + residual = hidden_states + batch, channel, height, width = hidden_states.shape + + # norm + hidden_states = self.group_norm(hidden_states) + + hidden_states = hidden_states.view(batch, channel, height * width).transpose(1, 2) + + # proj to q, k, v + query_proj = self.query(hidden_states) + key_proj = self.key(hidden_states) + value_proj = self.value(hidden_states) + + # transpose + query_states = self.transpose_for_scores(query_proj) + key_states = self.transpose_for_scores(key_proj) + value_states = self.transpose_for_scores(value_proj) + + # get scores + scale = 1 / math.sqrt(math.sqrt(self.channels / self.num_heads)) + + attention_scores = torch.matmul(query_states * scale, key_states.transpose(-1, -2) * scale) + attention_probs = torch.softmax(attention_scores.float(), dim=-1).type(attention_scores.dtype) + + # compute attention output + hidden_states = torch.matmul(attention_probs, value_states) + + hidden_states = hidden_states.permute(0, 2, 1, 3).contiguous() + new_hidden_states_shape = hidden_states.size()[:-2] + (self.channels,) + hidden_states = hidden_states.view(new_hidden_states_shape) + + # compute next hidden_states + hidden_states = self.proj_attn(hidden_states) + hidden_states = hidden_states.transpose(-1, -2).reshape(batch, channel, height, width) + + # res connect and rescale + hidden_states = (hidden_states + residual) / self.rescale_output_factor + return hidden_states + + +class SpatialTransformer(nn.Module): + """ + Transformer block for image-like data. First, project the input (aka embedding) and reshape to b, t, d. Then apply + standard transformer action. Finally, reshape to image. + + Parameters: + in_channels (:obj:`int`): The number of channels in the input and output. + n_heads (:obj:`int`): The number of heads to use for multi-head attention. + d_head (:obj:`int`): The number of channels in each head. + depth (:obj:`int`, *optional*, defaults to 1): The number of layers of Transformer blocks to use. + dropout (:obj:`float`, *optional*, defaults to 0.1): The dropout probability to use. + context_dim (:obj:`int`, *optional*): The number of context dimensions to use. 
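+
+    Example (illustrative; sizes are arbitrary, the context mimics a text-encoder output):
+
+    ```py
+    import torch
+
+    transformer = SpatialTransformer(in_channels=320, n_heads=8, d_head=40, context_dim=768)
+    x = torch.randn(1, 320, 64, 64)
+    context = torch.randn(1, 77, 768)
+    out = transformer(x, context=context)  # (1, 320, 64, 64)
+    ```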
+ """ + + def __init__( + self, + in_channels: int, + n_heads: int, + d_head: int, + depth: int = 1, + dropout: float = 0.0, + context_dim: Optional[int] = None, + ): + super().__init__() + self.n_heads = n_heads + self.d_head = d_head + self.in_channels = in_channels + inner_dim = n_heads * d_head + self.norm = torch.nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True) + + self.proj_in = nn.Conv2d(in_channels, inner_dim, kernel_size=1, stride=1, padding=0) + + self.transformer_blocks = nn.ModuleList( + [ + BasicTransformerBlock(inner_dim, n_heads, d_head, dropout=dropout, context_dim=context_dim) + for d in range(depth) + ] + ) + + self.proj_out = nn.Conv2d(inner_dim, in_channels, kernel_size=1, stride=1, padding=0) + + def _set_attention_slice(self, slice_size): + for block in self.transformer_blocks: + block._set_attention_slice(slice_size) + + def forward(self, x, context=None): + # note: if no context is given, cross-attention defaults to self-attention + b, c, h, w = x.shape + x_in = x + x = self.norm(x) + x = self.proj_in(x) + x = x.permute(0, 2, 3, 1).reshape(b, h * w, c) + for block in self.transformer_blocks: + x = block(x, context=context) + x = x.reshape(b, h, w, c).permute(0, 3, 1, 2) + x = self.proj_out(x) + return x + x_in + + +class BasicTransformerBlock(nn.Module): + r""" + A basic Transformer block. + + Parameters: + dim (:obj:`int`): The number of channels in the input and output. + n_heads (:obj:`int`): The number of heads to use for multi-head attention. + d_head (:obj:`int`): The number of channels in each head. + dropout (:obj:`float`, *optional*, defaults to 0.0): The dropout probability to use. + context_dim (:obj:`int`, *optional*): The size of the context vector for cross attention. + gated_ff (:obj:`bool`, *optional*, defaults to :obj:`False`): Whether to use a gated feed-forward network. + checkpoint (:obj:`bool`, *optional*, defaults to :obj:`False`): Whether to use checkpointing. + """ + + def __init__( + self, + dim: int, + n_heads: int, + d_head: int, + dropout=0.0, + context_dim: Optional[int] = None, + gated_ff: bool = True, + checkpoint: bool = True, + ): + super().__init__() + self.attn1 = CrossAttention( + query_dim=dim, heads=n_heads, dim_head=d_head, dropout=dropout + ) # is a self-attention + self.ff = FeedForward(dim, dropout=dropout, glu=gated_ff) + self.attn2 = CrossAttention( + query_dim=dim, context_dim=context_dim, heads=n_heads, dim_head=d_head, dropout=dropout + ) # is self-attn if context is none + self.norm1 = nn.LayerNorm(dim) + self.norm2 = nn.LayerNorm(dim) + self.norm3 = nn.LayerNorm(dim) + self.checkpoint = checkpoint + + def _set_attention_slice(self, slice_size): + self.attn1._slice_size = slice_size + self.attn2._slice_size = slice_size + + def forward(self, x, context=None): + x = x.contiguous() if x.device.type == "mps" else x + x = self.attn1(self.norm1(x)) + x + x = self.attn2(self.norm2(x), context=context) + x + x = self.ff(self.norm3(x)) + x + return x + + +class CrossAttention(nn.Module): + r""" + A cross attention layer. + + Parameters: + query_dim (:obj:`int`): The number of channels in the query. + context_dim (:obj:`int`, *optional*): + The number of channels in the context. If not given, defaults to `query_dim`. + heads (:obj:`int`, *optional*, defaults to 8): The number of heads to use for multi-head attention. + dim_head (:obj:`int`, *optional*, defaults to 64): The number of channels in each head. + dropout (:obj:`float`, *optional*, defaults to 0.0): The dropout probability to use. 
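+
+    Example (illustrative; a self-contained call with arbitrary sizes):
+
+    ```py
+    import torch
+
+    attn = CrossAttention(query_dim=320, context_dim=768, heads=8, dim_head=40)
+    x = torch.randn(2, 4096, 320)      # query tokens
+    context = torch.randn(2, 77, 768)  # e.g. text-encoder hidden states
+    out = attn(x, context=context)     # (2, 4096, 320)
+    ```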
+ """ + + def __init__( + self, query_dim: int, context_dim: Optional[int] = None, heads: int = 8, dim_head: int = 64, dropout: int = 0.0 + ): + super().__init__() + inner_dim = dim_head * heads + context_dim = context_dim if context_dim is not None else query_dim + + self.scale = dim_head**-0.5 + self.heads = heads + # for slice_size > 0 the attention score computation + # is split across the batch axis to save memory + # You can set slice_size with `set_attention_slice` + self._slice_size = None + + self.to_q = nn.Linear(query_dim, inner_dim, bias=False) + self.to_k = nn.Linear(context_dim, inner_dim, bias=False) + self.to_v = nn.Linear(context_dim, inner_dim, bias=False) + + self.to_out = nn.Sequential(nn.Linear(inner_dim, query_dim), nn.Dropout(dropout)) + + def reshape_heads_to_batch_dim(self, tensor): + batch_size, seq_len, dim = tensor.shape + head_size = self.heads + tensor = tensor.reshape(batch_size, seq_len, head_size, dim // head_size) + tensor = tensor.permute(0, 2, 1, 3).reshape(batch_size * head_size, seq_len, dim // head_size) + return tensor + + def reshape_batch_dim_to_heads(self, tensor): + batch_size, seq_len, dim = tensor.shape + head_size = self.heads + tensor = tensor.reshape(batch_size // head_size, head_size, seq_len, dim) + tensor = tensor.permute(0, 2, 1, 3).reshape(batch_size // head_size, seq_len, dim * head_size) + return tensor + + def forward(self, x, context=None, mask=None): + batch_size, sequence_length, dim = x.shape + + q = self.to_q(x) + context = context if context is not None else x + k = self.to_k(context) + v = self.to_v(context) + + q = self.reshape_heads_to_batch_dim(q) + k = self.reshape_heads_to_batch_dim(k) + v = self.reshape_heads_to_batch_dim(v) + + # TODO(PVP) - mask is currently never used. Remember to re-implement when used + + # attention, what we cannot get enough of + hidden_states = self._attention(q, k, v, sequence_length, dim) + + return self.to_out(hidden_states) + + def _attention(self, query, key, value, sequence_length, dim): + batch_size_attention = query.shape[0] + hidden_states = torch.zeros( + (batch_size_attention, sequence_length, dim // self.heads), device=query.device, dtype=query.dtype + ) + slice_size = self._slice_size if self._slice_size is not None else hidden_states.shape[0] + for i in range(hidden_states.shape[0] // slice_size): + start_idx = i * slice_size + end_idx = (i + 1) * slice_size + attn_slice = ( + torch.einsum("b i d, b j d -> b i j", query[start_idx:end_idx], key[start_idx:end_idx]) * self.scale + ) + attn_slice = attn_slice.softmax(dim=-1) + attn_slice = torch.einsum("b i j, b j d -> b i d", attn_slice, value[start_idx:end_idx]) + + hidden_states[start_idx:end_idx] = attn_slice + + # reshape hidden_states + hidden_states = self.reshape_batch_dim_to_heads(hidden_states) + return hidden_states + + +class FeedForward(nn.Module): + r""" + A feed-forward layer. + + Parameters: + dim (:obj:`int`): The number of channels in the input. + dim_out (:obj:`int`, *optional*): The number of channels in the output. If not given, defaults to `dim`. + mult (:obj:`int`, *optional*, defaults to 4): The multiplier to use for the hidden dimension. + glu (:obj:`bool`, *optional*, defaults to :obj:`False`): Whether to use GLU activation. + dropout (:obj:`float`, *optional*, defaults to 0.0): The dropout probability to use. 
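+
+    Example (illustrative):
+
+    ```py
+    import torch
+
+    ff = FeedForward(dim=320)
+    out = ff(torch.randn(2, 77, 320))  # (2, 77, 320)
+    ```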
+ """ + + def __init__( + self, dim: int, dim_out: Optional[int] = None, mult: int = 4, glu: bool = False, dropout: float = 0.0 + ): + super().__init__() + inner_dim = int(dim * mult) + dim_out = dim_out if dim_out is not None else dim + project_in = GEGLU(dim, inner_dim) + + self.net = nn.Sequential(project_in, nn.Dropout(dropout), nn.Linear(inner_dim, dim_out)) + + def forward(self, x): + return self.net(x) + + +# feedforward +class GEGLU(nn.Module): + r""" + A variant of the gated linear unit activation function from https://arxiv.org/abs/2002.05202. + + Parameters: + dim_in (:obj:`int`): The number of channels in the input. + dim_out (:obj:`int`): The number of channels in the output. + """ + + def __init__(self, dim_in: int, dim_out: int): + super().__init__() + self.proj = nn.Linear(dim_in, dim_out * 2) + + def forward(self, x): + x, gate = self.proj(x).chunk(2, dim=-1) + return x * F.gelu(gate) diff --git a/local_diffusers/models/embeddings.py b/local_diffusers/models/embeddings.py new file mode 100644 index 000000000..86ac074c1 --- /dev/null +++ b/local_diffusers/models/embeddings.py @@ -0,0 +1,115 @@ +# Copyright 2022 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import math + +import numpy as np +import torch +from torch import nn + + +def get_timestep_embedding( + timesteps: torch.Tensor, + embedding_dim: int, + flip_sin_to_cos: bool = False, + downscale_freq_shift: float = 1, + scale: float = 1, + max_period: int = 10000, +): + """ + This matches the implementation in Denoising Diffusion Probabilistic Models: Create sinusoidal timestep embeddings. + + :param timesteps: a 1-D Tensor of N indices, one per batch element. + These may be fractional. + :param embedding_dim: the dimension of the output. :param max_period: controls the minimum frequency of the + embeddings. :return: an [N x dim] Tensor of positional embeddings. 
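+
+    Example (illustrative):
+
+    ```py
+    import torch
+
+    # three timesteps, each mapped to a 320-dimensional sinusoidal embedding
+    emb = get_timestep_embedding(torch.tensor([0, 10, 999]), embedding_dim=320)  # (3, 320)
+    ```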
+ """ + assert len(timesteps.shape) == 1, "Timesteps should be a 1d-array" + + half_dim = embedding_dim // 2 + exponent = -math.log(max_period) * torch.arange(start=0, end=half_dim, dtype=torch.float32) + exponent = exponent / (half_dim - downscale_freq_shift) + + emb = torch.exp(exponent).to(device=timesteps.device) + emb = timesteps[:, None].float() * emb[None, :] + + # scale embeddings + emb = scale * emb + + # concat sine and cosine embeddings + emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=-1) + + # flip sine and cosine embeddings + if flip_sin_to_cos: + emb = torch.cat([emb[:, half_dim:], emb[:, :half_dim]], dim=-1) + + # zero pad + if embedding_dim % 2 == 1: + emb = torch.nn.functional.pad(emb, (0, 1, 0, 0)) + return emb + + +class TimestepEmbedding(nn.Module): + def __init__(self, channel: int, time_embed_dim: int, act_fn: str = "silu"): + super().__init__() + + self.linear_1 = nn.Linear(channel, time_embed_dim) + self.act = None + if act_fn == "silu": + self.act = nn.SiLU() + self.linear_2 = nn.Linear(time_embed_dim, time_embed_dim) + + def forward(self, sample): + sample = self.linear_1(sample) + + if self.act is not None: + sample = self.act(sample) + + sample = self.linear_2(sample) + return sample + + +class Timesteps(nn.Module): + def __init__(self, num_channels: int, flip_sin_to_cos: bool, downscale_freq_shift: float): + super().__init__() + self.num_channels = num_channels + self.flip_sin_to_cos = flip_sin_to_cos + self.downscale_freq_shift = downscale_freq_shift + + def forward(self, timesteps): + t_emb = get_timestep_embedding( + timesteps, + self.num_channels, + flip_sin_to_cos=self.flip_sin_to_cos, + downscale_freq_shift=self.downscale_freq_shift, + ) + return t_emb + + +class GaussianFourierProjection(nn.Module): + """Gaussian Fourier embeddings for noise levels.""" + + def __init__(self, embedding_size: int = 256, scale: float = 1.0): + super().__init__() + self.weight = nn.Parameter(torch.randn(embedding_size) * scale, requires_grad=False) + + # to delete later + self.W = nn.Parameter(torch.randn(embedding_size) * scale, requires_grad=False) + + self.weight = self.W + + def forward(self, x): + x = torch.log(x) + x_proj = x[:, None] * self.weight[None, :] * 2 * np.pi + out = torch.cat([torch.sin(x_proj), torch.cos(x_proj)], dim=-1) + return out diff --git a/local_diffusers/models/resnet.py b/local_diffusers/models/resnet.py new file mode 100644 index 000000000..27fae24f7 --- /dev/null +++ b/local_diffusers/models/resnet.py @@ -0,0 +1,483 @@ +from functools import partial + +import numpy as np +import torch +import torch.nn as nn +import torch.nn.functional as F + + +class Upsample2D(nn.Module): + """ + An upsampling layer with an optional convolution. + + :param channels: channels in the inputs and outputs. :param use_conv: a bool determining if a convolution is + applied. :param dims: determines if the signal is 1D, 2D, or 3D. If 3D, then + upsampling occurs in the inner-two dimensions. 
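+
+    Example (illustrative; channel and spatial sizes below are arbitrary):
+
+    ```py
+    import torch
+
+    up = Upsample2D(channels=64, use_conv=True)
+    out = up(torch.randn(1, 64, 16, 16))  # (1, 64, 32, 32)
+    ```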
+ """ + + def __init__(self, channels, use_conv=False, use_conv_transpose=False, out_channels=None, name="conv"): + super().__init__() + self.channels = channels + self.out_channels = out_channels or channels + self.use_conv = use_conv + self.use_conv_transpose = use_conv_transpose + self.name = name + + conv = None + if use_conv_transpose: + conv = nn.ConvTranspose2d(channels, self.out_channels, 4, 2, 1) + elif use_conv: + conv = nn.Conv2d(self.channels, self.out_channels, 3, padding=1) + + # TODO(Suraj, Patrick) - clean up after weight dicts are correctly renamed + if name == "conv": + self.conv = conv + else: + self.Conv2d_0 = conv + + def forward(self, x): + assert x.shape[1] == self.channels + if self.use_conv_transpose: + return self.conv(x) + + x = F.interpolate(x, scale_factor=2.0, mode="nearest") + + # TODO(Suraj, Patrick) - clean up after weight dicts are correctly renamed + if self.use_conv: + if self.name == "conv": + x = self.conv(x) + else: + x = self.Conv2d_0(x) + + return x + + +class Downsample2D(nn.Module): + """ + A downsampling layer with an optional convolution. + + :param channels: channels in the inputs and outputs. :param use_conv: a bool determining if a convolution is + applied. :param dims: determines if the signal is 1D, 2D, or 3D. If 3D, then + downsampling occurs in the inner-two dimensions. + """ + + def __init__(self, channels, use_conv=False, out_channels=None, padding=1, name="conv"): + super().__init__() + self.channels = channels + self.out_channels = out_channels or channels + self.use_conv = use_conv + self.padding = padding + stride = 2 + self.name = name + + if use_conv: + conv = nn.Conv2d(self.channels, self.out_channels, 3, stride=stride, padding=padding) + else: + assert self.channels == self.out_channels + conv = nn.AvgPool2d(kernel_size=stride, stride=stride) + + # TODO(Suraj, Patrick) - clean up after weight dicts are correctly renamed + if name == "conv": + self.Conv2d_0 = conv + self.conv = conv + elif name == "Conv2d_0": + self.conv = conv + else: + self.conv = conv + + def forward(self, x): + assert x.shape[1] == self.channels + if self.use_conv and self.padding == 0: + pad = (0, 1, 0, 1) + x = F.pad(x, pad, mode="constant", value=0) + + assert x.shape[1] == self.channels + x = self.conv(x) + + return x + + +class FirUpsample2D(nn.Module): + def __init__(self, channels=None, out_channels=None, use_conv=False, fir_kernel=(1, 3, 3, 1)): + super().__init__() + out_channels = out_channels if out_channels else channels + if use_conv: + self.Conv2d_0 = nn.Conv2d(channels, out_channels, kernel_size=3, stride=1, padding=1) + self.use_conv = use_conv + self.fir_kernel = fir_kernel + self.out_channels = out_channels + + def _upsample_2d(self, x, weight=None, kernel=None, factor=2, gain=1): + """Fused `upsample_2d()` followed by `Conv2d()`. + + Args: + Padding is performed only once at the beginning, not between the operations. The fused op is considerably more + efficient than performing the same calculation using standard TensorFlow ops. It supports gradients of arbitrary: + order. + x: Input tensor of the shape `[N, C, H, W]` or `[N, H, W, + C]`. + weight: Weight tensor of the shape `[filterH, filterW, inChannels, + outChannels]`. Grouped convolution can be performed by `inChannels = x.shape[0] // numGroups`. + kernel: FIR filter of the shape `[firH, firW]` or `[firN]` + (separable). The default is `[1] * factor`, which corresponds to nearest-neighbor upsampling. + factor: Integer upsampling factor (default: 2). 
gain: Scaling factor for signal magnitude (default: 1.0). + + Returns: + Tensor of the shape `[N, C, H * factor, W * factor]` or `[N, H * factor, W * factor, C]`, and same datatype as + `x`. + """ + + assert isinstance(factor, int) and factor >= 1 + + # Setup filter kernel. + if kernel is None: + kernel = [1] * factor + + # setup kernel + kernel = np.asarray(kernel, dtype=np.float32) + if kernel.ndim == 1: + kernel = np.outer(kernel, kernel) + kernel /= np.sum(kernel) + + kernel = kernel * (gain * (factor**2)) + + if self.use_conv: + convH = weight.shape[2] + convW = weight.shape[3] + inC = weight.shape[1] + + p = (kernel.shape[0] - factor) - (convW - 1) + + stride = (factor, factor) + # Determine data dimensions. + stride = [1, 1, factor, factor] + output_shape = ((x.shape[2] - 1) * factor + convH, (x.shape[3] - 1) * factor + convW) + output_padding = ( + output_shape[0] - (x.shape[2] - 1) * stride[0] - convH, + output_shape[1] - (x.shape[3] - 1) * stride[1] - convW, + ) + assert output_padding[0] >= 0 and output_padding[1] >= 0 + inC = weight.shape[1] + num_groups = x.shape[1] // inC + + # Transpose weights. + weight = torch.reshape(weight, (num_groups, -1, inC, convH, convW)) + weight = weight[..., ::-1, ::-1].permute(0, 2, 1, 3, 4) + weight = torch.reshape(weight, (num_groups * inC, -1, convH, convW)) + + x = F.conv_transpose2d(x, weight, stride=stride, output_padding=output_padding, padding=0) + + x = upfirdn2d_native(x, torch.tensor(kernel, device=x.device), pad=((p + 1) // 2 + factor - 1, p // 2 + 1)) + else: + p = kernel.shape[0] - factor + x = upfirdn2d_native( + x, torch.tensor(kernel, device=x.device), up=factor, pad=((p + 1) // 2 + factor - 1, p // 2) + ) + + return x + + def forward(self, x): + if self.use_conv: + height = self._upsample_2d(x, self.Conv2d_0.weight, kernel=self.fir_kernel) + height = height + self.Conv2d_0.bias.reshape(1, -1, 1, 1) + else: + height = self._upsample_2d(x, kernel=self.fir_kernel, factor=2) + + return height + + +class FirDownsample2D(nn.Module): + def __init__(self, channels=None, out_channels=None, use_conv=False, fir_kernel=(1, 3, 3, 1)): + super().__init__() + out_channels = out_channels if out_channels else channels + if use_conv: + self.Conv2d_0 = nn.Conv2d(channels, out_channels, kernel_size=3, stride=1, padding=1) + self.fir_kernel = fir_kernel + self.use_conv = use_conv + self.out_channels = out_channels + + def _downsample_2d(self, x, weight=None, kernel=None, factor=2, gain=1): + """Fused `Conv2d()` followed by `downsample_2d()`. + + Args: + Padding is performed only once at the beginning, not between the operations. The fused op is considerably more + efficient than performing the same calculation using standard TensorFlow ops. It supports gradients of arbitrary: + order. + x: Input tensor of the shape `[N, C, H, W]` or `[N, H, W, C]`. w: Weight tensor of the shape `[filterH, + filterW, inChannels, outChannels]`. Grouped convolution can be performed by `inChannels = x.shape[0] // + numGroups`. k: FIR filter of the shape `[firH, firW]` or `[firN]` (separable). The default is `[1] * + factor`, which corresponds to average pooling. factor: Integer downsampling factor (default: 2). gain: + Scaling factor for signal magnitude (default: 1.0). + + Returns: + Tensor of the shape `[N, C, H // factor, W // factor]` or `[N, H // factor, W // factor, C]`, and same + datatype as `x`. 
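+
+        Example (illustrative; exercises this helper through the module's `forward`, which passes the
+        module's FIR kernel when `use_conv=False`):
+
+        ```py
+        import torch
+
+        down = FirDownsample2D(channels=64, out_channels=64, use_conv=False)
+        out = down(torch.randn(1, 64, 32, 32))  # (1, 64, 16, 16)
+        ```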
+ """ + + assert isinstance(factor, int) and factor >= 1 + if kernel is None: + kernel = [1] * factor + + # setup kernel + kernel = np.asarray(kernel, dtype=np.float32) + if kernel.ndim == 1: + kernel = np.outer(kernel, kernel) + kernel /= np.sum(kernel) + + kernel = kernel * gain + + if self.use_conv: + _, _, convH, convW = weight.shape + p = (kernel.shape[0] - factor) + (convW - 1) + s = [factor, factor] + x = upfirdn2d_native(x, torch.tensor(kernel, device=x.device), pad=((p + 1) // 2, p // 2)) + x = F.conv2d(x, weight, stride=s, padding=0) + else: + p = kernel.shape[0] - factor + x = upfirdn2d_native(x, torch.tensor(kernel, device=x.device), down=factor, pad=((p + 1) // 2, p // 2)) + + return x + + def forward(self, x): + if self.use_conv: + x = self._downsample_2d(x, weight=self.Conv2d_0.weight, kernel=self.fir_kernel) + x = x + self.Conv2d_0.bias.reshape(1, -1, 1, 1) + else: + x = self._downsample_2d(x, kernel=self.fir_kernel, factor=2) + + return x + + +class ResnetBlock2D(nn.Module): + def __init__( + self, + *, + in_channels, + out_channels=None, + conv_shortcut=False, + dropout=0.0, + temb_channels=512, + groups=32, + groups_out=None, + pre_norm=True, + eps=1e-6, + non_linearity="swish", + time_embedding_norm="default", + kernel=None, + output_scale_factor=1.0, + use_nin_shortcut=None, + up=False, + down=False, + ): + super().__init__() + self.pre_norm = pre_norm + self.pre_norm = True + self.in_channels = in_channels + out_channels = in_channels if out_channels is None else out_channels + self.out_channels = out_channels + self.use_conv_shortcut = conv_shortcut + self.time_embedding_norm = time_embedding_norm + self.up = up + self.down = down + self.output_scale_factor = output_scale_factor + + if groups_out is None: + groups_out = groups + + self.norm1 = torch.nn.GroupNorm(num_groups=groups, num_channels=in_channels, eps=eps, affine=True) + + self.conv1 = torch.nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1) + + if temb_channels is not None: + self.time_emb_proj = torch.nn.Linear(temb_channels, out_channels) + else: + self.time_emb_proj = None + + self.norm2 = torch.nn.GroupNorm(num_groups=groups_out, num_channels=out_channels, eps=eps, affine=True) + self.dropout = torch.nn.Dropout(dropout) + self.conv2 = torch.nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1) + + if non_linearity == "swish": + self.nonlinearity = lambda x: F.silu(x) + elif non_linearity == "mish": + self.nonlinearity = Mish() + elif non_linearity == "silu": + self.nonlinearity = nn.SiLU() + + self.upsample = self.downsample = None + if self.up: + if kernel == "fir": + fir_kernel = (1, 3, 3, 1) + self.upsample = lambda x: upsample_2d(x, kernel=fir_kernel) + elif kernel == "sde_vp": + self.upsample = partial(F.interpolate, scale_factor=2.0, mode="nearest") + else: + self.upsample = Upsample2D(in_channels, use_conv=False) + elif self.down: + if kernel == "fir": + fir_kernel = (1, 3, 3, 1) + self.downsample = lambda x: downsample_2d(x, kernel=fir_kernel) + elif kernel == "sde_vp": + self.downsample = partial(F.avg_pool2d, kernel_size=2, stride=2) + else: + self.downsample = Downsample2D(in_channels, use_conv=False, padding=1, name="op") + + self.use_nin_shortcut = self.in_channels != self.out_channels if use_nin_shortcut is None else use_nin_shortcut + + self.conv_shortcut = None + if self.use_nin_shortcut: + self.conv_shortcut = torch.nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, padding=0) + + def forward(self, x, temb): + hidden_states = x + + # 
make sure hidden states is in float32 + # when running in half-precision + hidden_states = self.norm1(hidden_states.float()).type(hidden_states.dtype) + hidden_states = self.nonlinearity(hidden_states) + + if self.upsample is not None: + x = self.upsample(x) + hidden_states = self.upsample(hidden_states) + elif self.downsample is not None: + x = self.downsample(x) + hidden_states = self.downsample(hidden_states) + + hidden_states = self.conv1(hidden_states) + + if temb is not None: + temb = self.time_emb_proj(self.nonlinearity(temb))[:, :, None, None] + hidden_states = hidden_states + temb + + # make sure hidden states is in float32 + # when running in half-precision + hidden_states = self.norm2(hidden_states.float()).type(hidden_states.dtype) + hidden_states = self.nonlinearity(hidden_states) + + hidden_states = self.dropout(hidden_states) + hidden_states = self.conv2(hidden_states) + + if self.conv_shortcut is not None: + x = self.conv_shortcut(x) + + out = (x + hidden_states) / self.output_scale_factor + + return out + + +class Mish(torch.nn.Module): + def forward(self, x): + return x * torch.tanh(torch.nn.functional.softplus(x)) + + +def upsample_2d(x, kernel=None, factor=2, gain=1): + r"""Upsample2D a batch of 2D images with the given filter. + + Args: + Accepts a batch of 2D images of the shape `[N, C, H, W]` or `[N, H, W, C]` and upsamples each image with the given + filter. The filter is normalized so that if the input pixels are constant, they will be scaled by the specified + `gain`. Pixels outside the image are assumed to be zero, and the filter is padded with zeros so that its shape is a: + multiple of the upsampling factor. + x: Input tensor of the shape `[N, C, H, W]` or `[N, H, W, + C]`. + k: FIR filter of the shape `[firH, firW]` or `[firN]` + (separable). The default is `[1] * factor`, which corresponds to nearest-neighbor upsampling. + factor: Integer upsampling factor (default: 2). gain: Scaling factor for signal magnitude (default: 1.0). + + Returns: + Tensor of the shape `[N, C, H * factor, W * factor]` + """ + assert isinstance(factor, int) and factor >= 1 + if kernel is None: + kernel = [1] * factor + + kernel = np.asarray(kernel, dtype=np.float32) + if kernel.ndim == 1: + kernel = np.outer(kernel, kernel) + kernel /= np.sum(kernel) + + kernel = kernel * (gain * (factor**2)) + p = kernel.shape[0] - factor + return upfirdn2d_native( + x, torch.tensor(kernel, device=x.device), up=factor, pad=((p + 1) // 2 + factor - 1, p // 2) + ) + + +def downsample_2d(x, kernel=None, factor=2, gain=1): + r"""Downsample2D a batch of 2D images with the given filter. + + Args: + Accepts a batch of 2D images of the shape `[N, C, H, W]` or `[N, H, W, C]` and downsamples each image with the + given filter. The filter is normalized so that if the input pixels are constant, they will be scaled by the + specified `gain`. Pixels outside the image are assumed to be zero, and the filter is padded with zeros so that its + shape is a multiple of the downsampling factor. + x: Input tensor of the shape `[N, C, H, W]` or `[N, H, W, + C]`. + kernel: FIR filter of the shape `[firH, firW]` or `[firN]` + (separable). The default is `[1] * factor`, which corresponds to average pooling. + factor: Integer downsampling factor (default: 2). gain: Scaling factor for signal magnitude (default: 1.0). 
+ + Returns: + Tensor of the shape `[N, C, H // factor, W // factor]` + """ + + assert isinstance(factor, int) and factor >= 1 + if kernel is None: + kernel = [1] * factor + + kernel = np.asarray(kernel, dtype=np.float32) + if kernel.ndim == 1: + kernel = np.outer(kernel, kernel) + kernel /= np.sum(kernel) + + kernel = kernel * gain + p = kernel.shape[0] - factor + return upfirdn2d_native(x, torch.tensor(kernel, device=x.device), down=factor, pad=((p + 1) // 2, p // 2)) + + +def upfirdn2d_native(input, kernel, up=1, down=1, pad=(0, 0)): + up_x = up_y = up + down_x = down_y = down + pad_x0 = pad_y0 = pad[0] + pad_x1 = pad_y1 = pad[1] + + _, channel, in_h, in_w = input.shape + input = input.reshape(-1, in_h, in_w, 1) + + _, in_h, in_w, minor = input.shape + kernel_h, kernel_w = kernel.shape + + out = input.view(-1, in_h, 1, in_w, 1, minor) + + # Temporary workaround for mps specific issue: https://github.com/pytorch/pytorch/issues/84535 + if input.device.type == "mps": + out = out.to("cpu") + out = F.pad(out, [0, 0, 0, up_x - 1, 0, 0, 0, up_y - 1]) + out = out.view(-1, in_h * up_y, in_w * up_x, minor) + + out = F.pad(out, [0, 0, max(pad_x0, 0), max(pad_x1, 0), max(pad_y0, 0), max(pad_y1, 0)]) + out = out.to(input.device) # Move back to mps if necessary + out = out[ + :, + max(-pad_y0, 0) : out.shape[1] - max(-pad_y1, 0), + max(-pad_x0, 0) : out.shape[2] - max(-pad_x1, 0), + :, + ] + + out = out.permute(0, 3, 1, 2) + out = out.reshape([-1, 1, in_h * up_y + pad_y0 + pad_y1, in_w * up_x + pad_x0 + pad_x1]) + w = torch.flip(kernel, [0, 1]).view(1, 1, kernel_h, kernel_w) + out = F.conv2d(out, w) + out = out.reshape( + -1, + minor, + in_h * up_y + pad_y0 + pad_y1 - kernel_h + 1, + in_w * up_x + pad_x0 + pad_x1 - kernel_w + 1, + ) + out = out.permute(0, 2, 3, 1) + out = out[:, ::down_y, ::down_x, :] + + out_h = (in_h * up_y + pad_y0 + pad_y1 - kernel_h) // down_y + 1 + out_w = (in_w * up_x + pad_x0 + pad_x1 - kernel_w) // down_x + 1 + + return out.view(-1, channel, out_h, out_w) diff --git a/local_diffusers/models/unet_2d.py b/local_diffusers/models/unet_2d.py new file mode 100644 index 000000000..c3ab621a2 --- /dev/null +++ b/local_diffusers/models/unet_2d.py @@ -0,0 +1,246 @@ +from dataclasses import dataclass +from typing import Optional, Tuple, Union + +import torch +import torch.nn as nn + +from ..configuration_utils import ConfigMixin, register_to_config +from ..modeling_utils import ModelMixin +from ..utils import BaseOutput +from .embeddings import GaussianFourierProjection, TimestepEmbedding, Timesteps +from .unet_blocks import UNetMidBlock2D, get_down_block, get_up_block + + +@dataclass +class UNet2DOutput(BaseOutput): + """ + Args: + sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`): + Hidden states output. Output of last layer of model. + """ + + sample: torch.FloatTensor + + +class UNet2DModel(ModelMixin, ConfigMixin): + r""" + UNet2DModel is a 2D UNet model that takes in a noisy sample and a timestep and returns sample shaped output. + + This model inherits from [`ModelMixin`]. Check the superclass documentation for the generic methods the library + implements for all the model (such as downloading or saving, etc.) + + Parameters: + sample_size (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`, *optional*): + Input sample size. + in_channels (`int`, *optional*, defaults to 3): Number of channels in the input image. + out_channels (`int`, *optional*, defaults to 3): Number of channels in the output. 
+ center_input_sample (`bool`, *optional*, defaults to `False`): Whether to center the input sample. + time_embedding_type (`str`, *optional*, defaults to `"positional"`): Type of time embedding to use. + freq_shift (`int`, *optional*, defaults to 0): Frequency shift for fourier time embedding. + flip_sin_to_cos (`bool`, *optional*, defaults to : + obj:`False`): Whether to flip sin to cos for fourier time embedding. + down_block_types (`Tuple[str]`, *optional*, defaults to : + obj:`("DownBlock2D", "AttnDownBlock2D", "AttnDownBlock2D", "AttnDownBlock2D")`): Tuple of downsample block + types. + up_block_types (`Tuple[str]`, *optional*, defaults to : + obj:`("AttnUpBlock2D", "AttnUpBlock2D", "AttnUpBlock2D", "UpBlock2D")`): Tuple of upsample block types. + block_out_channels (`Tuple[int]`, *optional*, defaults to : + obj:`(224, 448, 672, 896)`): Tuple of block output channels. + layers_per_block (`int`, *optional*, defaults to `2`): The number of layers per block. + mid_block_scale_factor (`float`, *optional*, defaults to `1`): The scale factor for the mid block. + downsample_padding (`int`, *optional*, defaults to `1`): The padding for the downsample convolution. + act_fn (`str`, *optional*, defaults to `"silu"`): The activation function to use. + attention_head_dim (`int`, *optional*, defaults to `8`): The attention head dimension. + norm_num_groups (`int`, *optional*, defaults to `32`): The number of groups for the normalization. + norm_eps (`float`, *optional*, defaults to `1e-5`): The epsilon for the normalization. + """ + + @register_to_config + def __init__( + self, + sample_size: Optional[int] = None, + in_channels: int = 3, + out_channels: int = 3, + center_input_sample: bool = False, + time_embedding_type: str = "positional", + freq_shift: int = 0, + flip_sin_to_cos: bool = True, + down_block_types: Tuple[str] = ("DownBlock2D", "AttnDownBlock2D", "AttnDownBlock2D", "AttnDownBlock2D"), + up_block_types: Tuple[str] = ("AttnUpBlock2D", "AttnUpBlock2D", "AttnUpBlock2D", "UpBlock2D"), + block_out_channels: Tuple[int] = (224, 448, 672, 896), + layers_per_block: int = 2, + mid_block_scale_factor: float = 1, + downsample_padding: int = 1, + act_fn: str = "silu", + attention_head_dim: int = 8, + norm_num_groups: int = 32, + norm_eps: float = 1e-5, + ): + super().__init__() + + self.sample_size = sample_size + time_embed_dim = block_out_channels[0] * 4 + + # input + self.conv_in = nn.Conv2d(in_channels, block_out_channels[0], kernel_size=3, padding=(1, 1)) + + # time + if time_embedding_type == "fourier": + self.time_proj = GaussianFourierProjection(embedding_size=block_out_channels[0], scale=16) + timestep_input_dim = 2 * block_out_channels[0] + elif time_embedding_type == "positional": + self.time_proj = Timesteps(block_out_channels[0], flip_sin_to_cos, freq_shift) + timestep_input_dim = block_out_channels[0] + + self.time_embedding = TimestepEmbedding(timestep_input_dim, time_embed_dim) + + self.down_blocks = nn.ModuleList([]) + self.mid_block = None + self.up_blocks = nn.ModuleList([]) + + # down + output_channel = block_out_channels[0] + for i, down_block_type in enumerate(down_block_types): + input_channel = output_channel + output_channel = block_out_channels[i] + is_final_block = i == len(block_out_channels) - 1 + + down_block = get_down_block( + down_block_type, + num_layers=layers_per_block, + in_channels=input_channel, + out_channels=output_channel, + temb_channels=time_embed_dim, + add_downsample=not is_final_block, + resnet_eps=norm_eps, + resnet_act_fn=act_fn, + 
attn_num_head_channels=attention_head_dim, + downsample_padding=downsample_padding, + ) + self.down_blocks.append(down_block) + + # mid + self.mid_block = UNetMidBlock2D( + in_channels=block_out_channels[-1], + temb_channels=time_embed_dim, + resnet_eps=norm_eps, + resnet_act_fn=act_fn, + output_scale_factor=mid_block_scale_factor, + resnet_time_scale_shift="default", + attn_num_head_channels=attention_head_dim, + resnet_groups=norm_num_groups, + ) + + # up + reversed_block_out_channels = list(reversed(block_out_channels)) + output_channel = reversed_block_out_channels[0] + for i, up_block_type in enumerate(up_block_types): + prev_output_channel = output_channel + output_channel = reversed_block_out_channels[i] + input_channel = reversed_block_out_channels[min(i + 1, len(block_out_channels) - 1)] + + is_final_block = i == len(block_out_channels) - 1 + + up_block = get_up_block( + up_block_type, + num_layers=layers_per_block + 1, + in_channels=input_channel, + out_channels=output_channel, + prev_output_channel=prev_output_channel, + temb_channels=time_embed_dim, + add_upsample=not is_final_block, + resnet_eps=norm_eps, + resnet_act_fn=act_fn, + attn_num_head_channels=attention_head_dim, + ) + self.up_blocks.append(up_block) + prev_output_channel = output_channel + + # out + num_groups_out = norm_num_groups if norm_num_groups is not None else min(block_out_channels[0] // 4, 32) + self.conv_norm_out = nn.GroupNorm(num_channels=block_out_channels[0], num_groups=num_groups_out, eps=norm_eps) + self.conv_act = nn.SiLU() + self.conv_out = nn.Conv2d(block_out_channels[0], out_channels, 3, padding=1) + + def forward( + self, + sample: torch.FloatTensor, + timestep: Union[torch.Tensor, float, int], + return_dict: bool = True, + ) -> Union[UNet2DOutput, Tuple]: + """r + Args: + sample (`torch.FloatTensor`): (batch, channel, height, width) noisy inputs tensor + timestep (`torch.FloatTensor` or `float` or `int): (batch) timesteps + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~models.unet_2d.UNet2DOutput`] instead of a plain tuple. + + Returns: + [`~models.unet_2d.UNet2DOutput`] or `tuple`: [`~models.unet_2d.UNet2DOutput`] if `return_dict` is True, + otherwise a `tuple`. When returning a tuple, the first element is the sample tensor. + """ + # 0. center input if necessary + if self.config.center_input_sample: + sample = 2 * sample - 1.0 + + # 1. time + timesteps = timestep + if not torch.is_tensor(timesteps): + timesteps = torch.tensor([timesteps], dtype=torch.long, device=sample.device) + elif torch.is_tensor(timesteps) and len(timesteps.shape) == 0: + timesteps = timesteps[None].to(sample.device) + + # broadcast to batch dimension in a way that's compatible with ONNX/Core ML + timesteps = timesteps * torch.ones(sample.shape[0], dtype=timesteps.dtype, device=timesteps.device) + + t_emb = self.time_proj(timesteps) + emb = self.time_embedding(t_emb) + + # 2. pre-process + skip_sample = sample + sample = self.conv_in(sample) + + # 3. down + down_block_res_samples = (sample,) + for downsample_block in self.down_blocks: + if hasattr(downsample_block, "skip_conv"): + sample, res_samples, skip_sample = downsample_block( + hidden_states=sample, temb=emb, skip_sample=skip_sample + ) + else: + sample, res_samples = downsample_block(hidden_states=sample, temb=emb) + + down_block_res_samples += res_samples + + # 4. mid + sample = self.mid_block(sample, emb) + + # 5. 
up + skip_sample = None + for upsample_block in self.up_blocks: + res_samples = down_block_res_samples[-len(upsample_block.resnets) :] + down_block_res_samples = down_block_res_samples[: -len(upsample_block.resnets)] + + if hasattr(upsample_block, "skip_conv"): + sample, skip_sample = upsample_block(sample, res_samples, emb, skip_sample) + else: + sample = upsample_block(sample, res_samples, emb) + + # 6. post-process + # make sure hidden states is in float32 + # when running in half-precision + sample = self.conv_norm_out(sample.float()).type(sample.dtype) + sample = self.conv_act(sample) + sample = self.conv_out(sample) + + if skip_sample is not None: + sample += skip_sample + + if self.config.time_embedding_type == "fourier": + timesteps = timesteps.reshape((sample.shape[0], *([1] * len(sample.shape[1:])))) + sample = sample / timesteps + + if not return_dict: + return (sample,) + + return UNet2DOutput(sample=sample) diff --git a/local_diffusers/models/unet_2d_condition.py b/local_diffusers/models/unet_2d_condition.py new file mode 100644 index 000000000..92caaca92 --- /dev/null +++ b/local_diffusers/models/unet_2d_condition.py @@ -0,0 +1,270 @@ +from dataclasses import dataclass +from typing import Optional, Tuple, Union + +import torch +import torch.nn as nn + +from ..configuration_utils import ConfigMixin, register_to_config +from ..modeling_utils import ModelMixin +from ..utils import BaseOutput +from .embeddings import TimestepEmbedding, Timesteps +from .unet_blocks import UNetMidBlock2DCrossAttn, get_down_block, get_up_block + + +@dataclass +class UNet2DConditionOutput(BaseOutput): + """ + Args: + sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`): + Hidden states conditioned on `encoder_hidden_states` input. Output of last layer of model. + """ + + sample: torch.FloatTensor + + +class UNet2DConditionModel(ModelMixin, ConfigMixin): + r""" + UNet2DConditionModel is a conditional 2D UNet model that takes in a noisy sample, conditional state, and a timestep + and returns sample shaped output. + + This model inherits from [`ModelMixin`]. Check the superclass documentation for the generic methods the library + implements for all the model (such as downloading or saving, etc.) + + Parameters: + sample_size (`int`, *optional*): The size of the input sample. + in_channels (`int`, *optional*, defaults to 4): The number of channels in the input sample. + out_channels (`int`, *optional*, defaults to 4): The number of channels in the output. + center_input_sample (`bool`, *optional*, defaults to `False`): Whether to center the input sample. + flip_sin_to_cos (`bool`, *optional*, defaults to `False`): + Whether to flip the sin to cos in the time embedding. + freq_shift (`int`, *optional*, defaults to 0): The frequency shift to apply to the time embedding. + down_block_types (`Tuple[str]`, *optional*, defaults to `("CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "DownBlock2D")`): + The tuple of downsample blocks to use. + up_block_types (`Tuple[str]`, *optional*, defaults to `("UpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D",)`): + The tuple of upsample blocks to use. + block_out_channels (`Tuple[int]`, *optional*, defaults to `(320, 640, 1280, 1280)`): + The tuple of output channels for each block. + layers_per_block (`int`, *optional*, defaults to 2): The number of layers per block. + downsample_padding (`int`, *optional*, defaults to 1): The padding to use for the downsampling convolution. 
+ mid_block_scale_factor (`float`, *optional*, defaults to 1.0): The scale factor to use for the mid block. + act_fn (`str`, *optional*, defaults to `"silu"`): The activation function to use. + norm_num_groups (`int`, *optional*, defaults to 32): The number of groups to use for the normalization. + norm_eps (`float`, *optional*, defaults to 1e-5): The epsilon to use for the normalization. + cross_attention_dim (`int`, *optional*, defaults to 1280): The dimension of the cross attention features. + attention_head_dim (`int`, *optional*, defaults to 8): The dimension of the attention heads. + """ + + @register_to_config + def __init__( + self, + sample_size: Optional[int] = None, + in_channels: int = 4, + out_channels: int = 4, + center_input_sample: bool = False, + flip_sin_to_cos: bool = True, + freq_shift: int = 0, + down_block_types: Tuple[str] = ( + "CrossAttnDownBlock2D", + "CrossAttnDownBlock2D", + "CrossAttnDownBlock2D", + "DownBlock2D", + ), + up_block_types: Tuple[str] = ("UpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D"), + block_out_channels: Tuple[int] = (320, 640, 1280, 1280), + layers_per_block: int = 2, + downsample_padding: int = 1, + mid_block_scale_factor: float = 1, + act_fn: str = "silu", + norm_num_groups: int = 32, + norm_eps: float = 1e-5, + cross_attention_dim: int = 1280, + attention_head_dim: int = 8, + ): + super().__init__() + + self.sample_size = sample_size + time_embed_dim = block_out_channels[0] * 4 + + # input + self.conv_in = nn.Conv2d(in_channels, block_out_channels[0], kernel_size=3, padding=(1, 1)) + + # time + self.time_proj = Timesteps(block_out_channels[0], flip_sin_to_cos, freq_shift) + timestep_input_dim = block_out_channels[0] + + self.time_embedding = TimestepEmbedding(timestep_input_dim, time_embed_dim) + + self.down_blocks = nn.ModuleList([]) + self.mid_block = None + self.up_blocks = nn.ModuleList([]) + + # down + output_channel = block_out_channels[0] + for i, down_block_type in enumerate(down_block_types): + input_channel = output_channel + output_channel = block_out_channels[i] + is_final_block = i == len(block_out_channels) - 1 + + down_block = get_down_block( + down_block_type, + num_layers=layers_per_block, + in_channels=input_channel, + out_channels=output_channel, + temb_channels=time_embed_dim, + add_downsample=not is_final_block, + resnet_eps=norm_eps, + resnet_act_fn=act_fn, + cross_attention_dim=cross_attention_dim, + attn_num_head_channels=attention_head_dim, + downsample_padding=downsample_padding, + ) + self.down_blocks.append(down_block) + + # mid + self.mid_block = UNetMidBlock2DCrossAttn( + in_channels=block_out_channels[-1], + temb_channels=time_embed_dim, + resnet_eps=norm_eps, + resnet_act_fn=act_fn, + output_scale_factor=mid_block_scale_factor, + resnet_time_scale_shift="default", + cross_attention_dim=cross_attention_dim, + attn_num_head_channels=attention_head_dim, + resnet_groups=norm_num_groups, + ) + + # up + reversed_block_out_channels = list(reversed(block_out_channels)) + output_channel = reversed_block_out_channels[0] + for i, up_block_type in enumerate(up_block_types): + prev_output_channel = output_channel + output_channel = reversed_block_out_channels[i] + input_channel = reversed_block_out_channels[min(i + 1, len(block_out_channels) - 1)] + + is_final_block = i == len(block_out_channels) - 1 + + up_block = get_up_block( + up_block_type, + num_layers=layers_per_block + 1, + in_channels=input_channel, + out_channels=output_channel, + prev_output_channel=prev_output_channel, + 
temb_channels=time_embed_dim, + add_upsample=not is_final_block, + resnet_eps=norm_eps, + resnet_act_fn=act_fn, + cross_attention_dim=cross_attention_dim, + attn_num_head_channels=attention_head_dim, + ) + self.up_blocks.append(up_block) + prev_output_channel = output_channel + + # out + self.conv_norm_out = nn.GroupNorm(num_channels=block_out_channels[0], num_groups=norm_num_groups, eps=norm_eps) + self.conv_act = nn.SiLU() + self.conv_out = nn.Conv2d(block_out_channels[0], out_channels, 3, padding=1) + + def set_attention_slice(self, slice_size): + if slice_size is not None and self.config.attention_head_dim % slice_size != 0: + raise ValueError( + f"Make sure slice_size {slice_size} is a divisor of " + f"the number of heads used in cross_attention {self.config.attention_head_dim}" + ) + if slice_size is not None and slice_size > self.config.attention_head_dim: + raise ValueError( + f"Chunk_size {slice_size} has to be smaller or equal to " + f"the number of heads used in cross_attention {self.config.attention_head_dim}" + ) + + for block in self.down_blocks: + if hasattr(block, "attentions") and block.attentions is not None: + block.set_attention_slice(slice_size) + + self.mid_block.set_attention_slice(slice_size) + + for block in self.up_blocks: + if hasattr(block, "attentions") and block.attentions is not None: + block.set_attention_slice(slice_size) + + def forward( + self, + sample: torch.FloatTensor, + timestep: Union[torch.Tensor, float, int], + encoder_hidden_states: torch.Tensor, + return_dict: bool = True, + ) -> Union[UNet2DConditionOutput, Tuple]: + """r + Args: + sample (`torch.FloatTensor`): (batch, channel, height, width) noisy inputs tensor + timestep (`torch.FloatTensor` or `float` or `int): (batch) timesteps + encoder_hidden_states (`torch.FloatTensor`): (batch, channel, height, width) encoder hidden states + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`models.unet_2d_condition.UNet2DConditionOutput`] instead of a plain tuple. + + Returns: + [`~models.unet_2d_condition.UNet2DConditionOutput`] or `tuple`: + [`~models.unet_2d_condition.UNet2DConditionOutput`] if `return_dict` is True, otherwise a `tuple`. When + returning a tuple, the first element is the sample tensor. + """ + # 0. center input if necessary + if self.config.center_input_sample: + sample = 2 * sample - 1.0 + + # 1. time + timesteps = timestep + if not torch.is_tensor(timesteps): + timesteps = torch.tensor([timesteps], dtype=torch.long, device=sample.device) + elif torch.is_tensor(timesteps) and len(timesteps.shape) == 0: + timesteps = timesteps.to(dtype=torch.float32) + timesteps = timesteps[None].to(device=sample.device) + + # broadcast to batch dimension in a way that's compatible with ONNX/Core ML + timesteps = timesteps.expand(sample.shape[0]) + + t_emb = self.time_proj(timesteps) + emb = self.time_embedding(t_emb) + + # 2. pre-process + sample = self.conv_in(sample) + + # 3. down + down_block_res_samples = (sample,) + for downsample_block in self.down_blocks: + if hasattr(downsample_block, "attentions") and downsample_block.attentions is not None: + sample, res_samples = downsample_block( + hidden_states=sample, temb=emb, encoder_hidden_states=encoder_hidden_states + ) + else: + sample, res_samples = downsample_block(hidden_states=sample, temb=emb) + + down_block_res_samples += res_samples + + # 4. mid + sample = self.mid_block(sample, emb, encoder_hidden_states=encoder_hidden_states) + + # 5. 
up + for upsample_block in self.up_blocks: + res_samples = down_block_res_samples[-len(upsample_block.resnets) :] + down_block_res_samples = down_block_res_samples[: -len(upsample_block.resnets)] + + if hasattr(upsample_block, "attentions") and upsample_block.attentions is not None: + sample = upsample_block( + hidden_states=sample, + temb=emb, + res_hidden_states_tuple=res_samples, + encoder_hidden_states=encoder_hidden_states, + ) + else: + sample = upsample_block(hidden_states=sample, temb=emb, res_hidden_states_tuple=res_samples) + + # 6. post-process + # make sure hidden states is in float32 + # when running in half-precision + sample = self.conv_norm_out(sample.float()).type(sample.dtype) + sample = self.conv_act(sample) + sample = self.conv_out(sample) + + if not return_dict: + return (sample,) + + return UNet2DConditionOutput(sample=sample) diff --git a/local_diffusers/models/unet_blocks.py b/local_diffusers/models/unet_blocks.py new file mode 100644 index 000000000..9e0621653 --- /dev/null +++ b/local_diffusers/models/unet_blocks.py @@ -0,0 +1,1481 @@ +# Copyright 2022 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and + +import numpy as np + +# limitations under the License. +import torch +from torch import nn + +from .attention import AttentionBlock, SpatialTransformer +from .resnet import Downsample2D, FirDownsample2D, FirUpsample2D, ResnetBlock2D, Upsample2D + + +def get_down_block( + down_block_type, + num_layers, + in_channels, + out_channels, + temb_channels, + add_downsample, + resnet_eps, + resnet_act_fn, + attn_num_head_channels, + cross_attention_dim=None, + downsample_padding=None, +): + down_block_type = down_block_type[7:] if down_block_type.startswith("UNetRes") else down_block_type + if down_block_type == "DownBlock2D": + return DownBlock2D( + num_layers=num_layers, + in_channels=in_channels, + out_channels=out_channels, + temb_channels=temb_channels, + add_downsample=add_downsample, + resnet_eps=resnet_eps, + resnet_act_fn=resnet_act_fn, + downsample_padding=downsample_padding, + ) + elif down_block_type == "AttnDownBlock2D": + return AttnDownBlock2D( + num_layers=num_layers, + in_channels=in_channels, + out_channels=out_channels, + temb_channels=temb_channels, + add_downsample=add_downsample, + resnet_eps=resnet_eps, + resnet_act_fn=resnet_act_fn, + downsample_padding=downsample_padding, + attn_num_head_channels=attn_num_head_channels, + ) + elif down_block_type == "CrossAttnDownBlock2D": + if cross_attention_dim is None: + raise ValueError("cross_attention_dim must be specified for CrossAttnDownBlock2D") + return CrossAttnDownBlock2D( + num_layers=num_layers, + in_channels=in_channels, + out_channels=out_channels, + temb_channels=temb_channels, + add_downsample=add_downsample, + resnet_eps=resnet_eps, + resnet_act_fn=resnet_act_fn, + downsample_padding=downsample_padding, + cross_attention_dim=cross_attention_dim, + attn_num_head_channels=attn_num_head_channels, + ) + elif down_block_type == "SkipDownBlock2D": + return SkipDownBlock2D( + num_layers=num_layers, 
+ in_channels=in_channels, + out_channels=out_channels, + temb_channels=temb_channels, + add_downsample=add_downsample, + resnet_eps=resnet_eps, + resnet_act_fn=resnet_act_fn, + downsample_padding=downsample_padding, + ) + elif down_block_type == "AttnSkipDownBlock2D": + return AttnSkipDownBlock2D( + num_layers=num_layers, + in_channels=in_channels, + out_channels=out_channels, + temb_channels=temb_channels, + add_downsample=add_downsample, + resnet_eps=resnet_eps, + resnet_act_fn=resnet_act_fn, + downsample_padding=downsample_padding, + attn_num_head_channels=attn_num_head_channels, + ) + elif down_block_type == "DownEncoderBlock2D": + return DownEncoderBlock2D( + num_layers=num_layers, + in_channels=in_channels, + out_channels=out_channels, + add_downsample=add_downsample, + resnet_eps=resnet_eps, + resnet_act_fn=resnet_act_fn, + downsample_padding=downsample_padding, + ) + + +def get_up_block( + up_block_type, + num_layers, + in_channels, + out_channels, + prev_output_channel, + temb_channels, + add_upsample, + resnet_eps, + resnet_act_fn, + attn_num_head_channels, + cross_attention_dim=None, +): + up_block_type = up_block_type[7:] if up_block_type.startswith("UNetRes") else up_block_type + if up_block_type == "UpBlock2D": + return UpBlock2D( + num_layers=num_layers, + in_channels=in_channels, + out_channels=out_channels, + prev_output_channel=prev_output_channel, + temb_channels=temb_channels, + add_upsample=add_upsample, + resnet_eps=resnet_eps, + resnet_act_fn=resnet_act_fn, + ) + elif up_block_type == "CrossAttnUpBlock2D": + if cross_attention_dim is None: + raise ValueError("cross_attention_dim must be specified for CrossAttnUpBlock2D") + return CrossAttnUpBlock2D( + num_layers=num_layers, + in_channels=in_channels, + out_channels=out_channels, + prev_output_channel=prev_output_channel, + temb_channels=temb_channels, + add_upsample=add_upsample, + resnet_eps=resnet_eps, + resnet_act_fn=resnet_act_fn, + cross_attention_dim=cross_attention_dim, + attn_num_head_channels=attn_num_head_channels, + ) + elif up_block_type == "AttnUpBlock2D": + return AttnUpBlock2D( + num_layers=num_layers, + in_channels=in_channels, + out_channels=out_channels, + prev_output_channel=prev_output_channel, + temb_channels=temb_channels, + add_upsample=add_upsample, + resnet_eps=resnet_eps, + resnet_act_fn=resnet_act_fn, + attn_num_head_channels=attn_num_head_channels, + ) + elif up_block_type == "SkipUpBlock2D": + return SkipUpBlock2D( + num_layers=num_layers, + in_channels=in_channels, + out_channels=out_channels, + prev_output_channel=prev_output_channel, + temb_channels=temb_channels, + add_upsample=add_upsample, + resnet_eps=resnet_eps, + resnet_act_fn=resnet_act_fn, + ) + elif up_block_type == "AttnSkipUpBlock2D": + return AttnSkipUpBlock2D( + num_layers=num_layers, + in_channels=in_channels, + out_channels=out_channels, + prev_output_channel=prev_output_channel, + temb_channels=temb_channels, + add_upsample=add_upsample, + resnet_eps=resnet_eps, + resnet_act_fn=resnet_act_fn, + attn_num_head_channels=attn_num_head_channels, + ) + elif up_block_type == "UpDecoderBlock2D": + return UpDecoderBlock2D( + num_layers=num_layers, + in_channels=in_channels, + out_channels=out_channels, + add_upsample=add_upsample, + resnet_eps=resnet_eps, + resnet_act_fn=resnet_act_fn, + ) + raise ValueError(f"{up_block_type} does not exist.") + + +class UNetMidBlock2D(nn.Module): + def __init__( + self, + in_channels: int, + temb_channels: int, + dropout: float = 0.0, + num_layers: int = 1, + resnet_eps: float = 1e-6, + 
resnet_time_scale_shift: str = "default", + resnet_act_fn: str = "swish", + resnet_groups: int = 32, + resnet_pre_norm: bool = True, + attn_num_head_channels=1, + attention_type="default", + output_scale_factor=1.0, + **kwargs, + ): + super().__init__() + + self.attention_type = attention_type + resnet_groups = resnet_groups if resnet_groups is not None else min(in_channels // 4, 32) + + # there is always at least one resnet + resnets = [ + ResnetBlock2D( + in_channels=in_channels, + out_channels=in_channels, + temb_channels=temb_channels, + eps=resnet_eps, + groups=resnet_groups, + dropout=dropout, + time_embedding_norm=resnet_time_scale_shift, + non_linearity=resnet_act_fn, + output_scale_factor=output_scale_factor, + pre_norm=resnet_pre_norm, + ) + ] + attentions = [] + + for _ in range(num_layers): + attentions.append( + AttentionBlock( + in_channels, + num_head_channels=attn_num_head_channels, + rescale_output_factor=output_scale_factor, + eps=resnet_eps, + num_groups=resnet_groups, + ) + ) + resnets.append( + ResnetBlock2D( + in_channels=in_channels, + out_channels=in_channels, + temb_channels=temb_channels, + eps=resnet_eps, + groups=resnet_groups, + dropout=dropout, + time_embedding_norm=resnet_time_scale_shift, + non_linearity=resnet_act_fn, + output_scale_factor=output_scale_factor, + pre_norm=resnet_pre_norm, + ) + ) + + self.attentions = nn.ModuleList(attentions) + self.resnets = nn.ModuleList(resnets) + + def forward(self, hidden_states, temb=None, encoder_states=None): + hidden_states = self.resnets[0](hidden_states, temb) + for attn, resnet in zip(self.attentions, self.resnets[1:]): + if self.attention_type == "default": + hidden_states = attn(hidden_states) + else: + hidden_states = attn(hidden_states, encoder_states) + hidden_states = resnet(hidden_states, temb) + + return hidden_states + + +class UNetMidBlock2DCrossAttn(nn.Module): + def __init__( + self, + in_channels: int, + temb_channels: int, + dropout: float = 0.0, + num_layers: int = 1, + resnet_eps: float = 1e-6, + resnet_time_scale_shift: str = "default", + resnet_act_fn: str = "swish", + resnet_groups: int = 32, + resnet_pre_norm: bool = True, + attn_num_head_channels=1, + attention_type="default", + output_scale_factor=1.0, + cross_attention_dim=1280, + **kwargs, + ): + super().__init__() + + self.attention_type = attention_type + self.attn_num_head_channels = attn_num_head_channels + resnet_groups = resnet_groups if resnet_groups is not None else min(in_channels // 4, 32) + + # there is always at least one resnet + resnets = [ + ResnetBlock2D( + in_channels=in_channels, + out_channels=in_channels, + temb_channels=temb_channels, + eps=resnet_eps, + groups=resnet_groups, + dropout=dropout, + time_embedding_norm=resnet_time_scale_shift, + non_linearity=resnet_act_fn, + output_scale_factor=output_scale_factor, + pre_norm=resnet_pre_norm, + ) + ] + attentions = [] + + for _ in range(num_layers): + attentions.append( + SpatialTransformer( + in_channels, + attn_num_head_channels, + in_channels // attn_num_head_channels, + depth=1, + context_dim=cross_attention_dim, + ) + ) + resnets.append( + ResnetBlock2D( + in_channels=in_channels, + out_channels=in_channels, + temb_channels=temb_channels, + eps=resnet_eps, + groups=resnet_groups, + dropout=dropout, + time_embedding_norm=resnet_time_scale_shift, + non_linearity=resnet_act_fn, + output_scale_factor=output_scale_factor, + pre_norm=resnet_pre_norm, + ) + ) + + self.attentions = nn.ModuleList(attentions) + self.resnets = nn.ModuleList(resnets) + + def 
set_attention_slice(self, slice_size): + if slice_size is not None and self.attn_num_head_channels % slice_size != 0: + raise ValueError( + f"Make sure slice_size {slice_size} is a divisor of " + f"the number of heads used in cross_attention {self.attn_num_head_channels}" + ) + if slice_size is not None and slice_size > self.attn_num_head_channels: + raise ValueError( + f"Chunk_size {slice_size} has to be smaller or equal to " + f"the number of heads used in cross_attention {self.attn_num_head_channels}" + ) + + for attn in self.attentions: + attn._set_attention_slice(slice_size) + + def forward(self, hidden_states, temb=None, encoder_hidden_states=None): + hidden_states = self.resnets[0](hidden_states, temb) + for attn, resnet in zip(self.attentions, self.resnets[1:]): + hidden_states = attn(hidden_states, encoder_hidden_states) + hidden_states = resnet(hidden_states, temb) + + return hidden_states + + +class AttnDownBlock2D(nn.Module): + def __init__( + self, + in_channels: int, + out_channels: int, + temb_channels: int, + dropout: float = 0.0, + num_layers: int = 1, + resnet_eps: float = 1e-6, + resnet_time_scale_shift: str = "default", + resnet_act_fn: str = "swish", + resnet_groups: int = 32, + resnet_pre_norm: bool = True, + attn_num_head_channels=1, + attention_type="default", + output_scale_factor=1.0, + downsample_padding=1, + add_downsample=True, + ): + super().__init__() + resnets = [] + attentions = [] + + self.attention_type = attention_type + + for i in range(num_layers): + in_channels = in_channels if i == 0 else out_channels + resnets.append( + ResnetBlock2D( + in_channels=in_channels, + out_channels=out_channels, + temb_channels=temb_channels, + eps=resnet_eps, + groups=resnet_groups, + dropout=dropout, + time_embedding_norm=resnet_time_scale_shift, + non_linearity=resnet_act_fn, + output_scale_factor=output_scale_factor, + pre_norm=resnet_pre_norm, + ) + ) + attentions.append( + AttentionBlock( + out_channels, + num_head_channels=attn_num_head_channels, + rescale_output_factor=output_scale_factor, + eps=resnet_eps, + ) + ) + + self.attentions = nn.ModuleList(attentions) + self.resnets = nn.ModuleList(resnets) + + if add_downsample: + self.downsamplers = nn.ModuleList( + [ + Downsample2D( + in_channels, use_conv=True, out_channels=out_channels, padding=downsample_padding, name="op" + ) + ] + ) + else: + self.downsamplers = None + + def forward(self, hidden_states, temb=None): + output_states = () + + for resnet, attn in zip(self.resnets, self.attentions): + hidden_states = resnet(hidden_states, temb) + hidden_states = attn(hidden_states) + output_states += (hidden_states,) + + if self.downsamplers is not None: + for downsampler in self.downsamplers: + hidden_states = downsampler(hidden_states) + + output_states += (hidden_states,) + + return hidden_states, output_states + + +class CrossAttnDownBlock2D(nn.Module): + def __init__( + self, + in_channels: int, + out_channels: int, + temb_channels: int, + dropout: float = 0.0, + num_layers: int = 1, + resnet_eps: float = 1e-6, + resnet_time_scale_shift: str = "default", + resnet_act_fn: str = "swish", + resnet_groups: int = 32, + resnet_pre_norm: bool = True, + attn_num_head_channels=1, + cross_attention_dim=1280, + attention_type="default", + output_scale_factor=1.0, + downsample_padding=1, + add_downsample=True, + ): + super().__init__() + resnets = [] + attentions = [] + + self.attention_type = attention_type + self.attn_num_head_channels = attn_num_head_channels + + for i in range(num_layers): + in_channels = in_channels if 
i == 0 else out_channels + resnets.append( + ResnetBlock2D( + in_channels=in_channels, + out_channels=out_channels, + temb_channels=temb_channels, + eps=resnet_eps, + groups=resnet_groups, + dropout=dropout, + time_embedding_norm=resnet_time_scale_shift, + non_linearity=resnet_act_fn, + output_scale_factor=output_scale_factor, + pre_norm=resnet_pre_norm, + ) + ) + attentions.append( + SpatialTransformer( + out_channels, + attn_num_head_channels, + out_channels // attn_num_head_channels, + depth=1, + context_dim=cross_attention_dim, + ) + ) + self.attentions = nn.ModuleList(attentions) + self.resnets = nn.ModuleList(resnets) + + if add_downsample: + self.downsamplers = nn.ModuleList( + [ + Downsample2D( + in_channels, use_conv=True, out_channels=out_channels, padding=downsample_padding, name="op" + ) + ] + ) + else: + self.downsamplers = None + + def set_attention_slice(self, slice_size): + if slice_size is not None and self.attn_num_head_channels % slice_size != 0: + raise ValueError( + f"Make sure slice_size {slice_size} is a divisor of " + f"the number of heads used in cross_attention {self.attn_num_head_channels}" + ) + if slice_size is not None and slice_size > self.attn_num_head_channels: + raise ValueError( + f"Chunk_size {slice_size} has to be smaller or equal to " + f"the number of heads used in cross_attention {self.attn_num_head_channels}" + ) + + for attn in self.attentions: + attn._set_attention_slice(slice_size) + + def forward(self, hidden_states, temb=None, encoder_hidden_states=None): + output_states = () + + for resnet, attn in zip(self.resnets, self.attentions): + hidden_states = resnet(hidden_states, temb) + hidden_states = attn(hidden_states, context=encoder_hidden_states) + output_states += (hidden_states,) + + if self.downsamplers is not None: + for downsampler in self.downsamplers: + hidden_states = downsampler(hidden_states) + + output_states += (hidden_states,) + + return hidden_states, output_states + + +class DownBlock2D(nn.Module): + def __init__( + self, + in_channels: int, + out_channels: int, + temb_channels: int, + dropout: float = 0.0, + num_layers: int = 1, + resnet_eps: float = 1e-6, + resnet_time_scale_shift: str = "default", + resnet_act_fn: str = "swish", + resnet_groups: int = 32, + resnet_pre_norm: bool = True, + output_scale_factor=1.0, + add_downsample=True, + downsample_padding=1, + ): + super().__init__() + resnets = [] + + for i in range(num_layers): + in_channels = in_channels if i == 0 else out_channels + resnets.append( + ResnetBlock2D( + in_channels=in_channels, + out_channels=out_channels, + temb_channels=temb_channels, + eps=resnet_eps, + groups=resnet_groups, + dropout=dropout, + time_embedding_norm=resnet_time_scale_shift, + non_linearity=resnet_act_fn, + output_scale_factor=output_scale_factor, + pre_norm=resnet_pre_norm, + ) + ) + + self.resnets = nn.ModuleList(resnets) + + if add_downsample: + self.downsamplers = nn.ModuleList( + [ + Downsample2D( + in_channels, use_conv=True, out_channels=out_channels, padding=downsample_padding, name="op" + ) + ] + ) + else: + self.downsamplers = None + + def forward(self, hidden_states, temb=None): + output_states = () + + for resnet in self.resnets: + hidden_states = resnet(hidden_states, temb) + output_states += (hidden_states,) + + if self.downsamplers is not None: + for downsampler in self.downsamplers: + hidden_states = downsampler(hidden_states) + + output_states += (hidden_states,) + + return hidden_states, output_states + + +class DownEncoderBlock2D(nn.Module): + def __init__( + self, + 
in_channels: int, + out_channels: int, + dropout: float = 0.0, + num_layers: int = 1, + resnet_eps: float = 1e-6, + resnet_time_scale_shift: str = "default", + resnet_act_fn: str = "swish", + resnet_groups: int = 32, + resnet_pre_norm: bool = True, + output_scale_factor=1.0, + add_downsample=True, + downsample_padding=1, + ): + super().__init__() + resnets = [] + + for i in range(num_layers): + in_channels = in_channels if i == 0 else out_channels + resnets.append( + ResnetBlock2D( + in_channels=in_channels, + out_channels=out_channels, + temb_channels=None, + eps=resnet_eps, + groups=resnet_groups, + dropout=dropout, + time_embedding_norm=resnet_time_scale_shift, + non_linearity=resnet_act_fn, + output_scale_factor=output_scale_factor, + pre_norm=resnet_pre_norm, + ) + ) + + self.resnets = nn.ModuleList(resnets) + + if add_downsample: + self.downsamplers = nn.ModuleList( + [ + Downsample2D( + in_channels, use_conv=True, out_channels=out_channels, padding=downsample_padding, name="op" + ) + ] + ) + else: + self.downsamplers = None + + def forward(self, hidden_states): + for resnet in self.resnets: + hidden_states = resnet(hidden_states, temb=None) + + if self.downsamplers is not None: + for downsampler in self.downsamplers: + hidden_states = downsampler(hidden_states) + + return hidden_states + + +class AttnDownEncoderBlock2D(nn.Module): + def __init__( + self, + in_channels: int, + out_channels: int, + dropout: float = 0.0, + num_layers: int = 1, + resnet_eps: float = 1e-6, + resnet_time_scale_shift: str = "default", + resnet_act_fn: str = "swish", + resnet_groups: int = 32, + resnet_pre_norm: bool = True, + attn_num_head_channels=1, + output_scale_factor=1.0, + add_downsample=True, + downsample_padding=1, + ): + super().__init__() + resnets = [] + attentions = [] + + for i in range(num_layers): + in_channels = in_channels if i == 0 else out_channels + resnets.append( + ResnetBlock2D( + in_channels=in_channels, + out_channels=out_channels, + temb_channels=None, + eps=resnet_eps, + groups=resnet_groups, + dropout=dropout, + time_embedding_norm=resnet_time_scale_shift, + non_linearity=resnet_act_fn, + output_scale_factor=output_scale_factor, + pre_norm=resnet_pre_norm, + ) + ) + attentions.append( + AttentionBlock( + out_channels, + num_head_channels=attn_num_head_channels, + rescale_output_factor=output_scale_factor, + eps=resnet_eps, + num_groups=resnet_groups, + ) + ) + + self.attentions = nn.ModuleList(attentions) + self.resnets = nn.ModuleList(resnets) + + if add_downsample: + self.downsamplers = nn.ModuleList( + [ + Downsample2D( + in_channels, use_conv=True, out_channels=out_channels, padding=downsample_padding, name="op" + ) + ] + ) + else: + self.downsamplers = None + + def forward(self, hidden_states): + for resnet, attn in zip(self.resnets, self.attentions): + hidden_states = resnet(hidden_states, temb=None) + hidden_states = attn(hidden_states) + + if self.downsamplers is not None: + for downsampler in self.downsamplers: + hidden_states = downsampler(hidden_states) + + return hidden_states + + +class AttnSkipDownBlock2D(nn.Module): + def __init__( + self, + in_channels: int, + out_channels: int, + temb_channels: int, + dropout: float = 0.0, + num_layers: int = 1, + resnet_eps: float = 1e-6, + resnet_time_scale_shift: str = "default", + resnet_act_fn: str = "swish", + resnet_pre_norm: bool = True, + attn_num_head_channels=1, + attention_type="default", + output_scale_factor=np.sqrt(2.0), + downsample_padding=1, + add_downsample=True, + ): + super().__init__() + self.attentions = 
nn.ModuleList([]) + self.resnets = nn.ModuleList([]) + + self.attention_type = attention_type + + for i in range(num_layers): + in_channels = in_channels if i == 0 else out_channels + self.resnets.append( + ResnetBlock2D( + in_channels=in_channels, + out_channels=out_channels, + temb_channels=temb_channels, + eps=resnet_eps, + groups=min(in_channels // 4, 32), + groups_out=min(out_channels // 4, 32), + dropout=dropout, + time_embedding_norm=resnet_time_scale_shift, + non_linearity=resnet_act_fn, + output_scale_factor=output_scale_factor, + pre_norm=resnet_pre_norm, + ) + ) + self.attentions.append( + AttentionBlock( + out_channels, + num_head_channels=attn_num_head_channels, + rescale_output_factor=output_scale_factor, + eps=resnet_eps, + ) + ) + + if add_downsample: + self.resnet_down = ResnetBlock2D( + in_channels=out_channels, + out_channels=out_channels, + temb_channels=temb_channels, + eps=resnet_eps, + groups=min(out_channels // 4, 32), + dropout=dropout, + time_embedding_norm=resnet_time_scale_shift, + non_linearity=resnet_act_fn, + output_scale_factor=output_scale_factor, + pre_norm=resnet_pre_norm, + use_nin_shortcut=True, + down=True, + kernel="fir", + ) + self.downsamplers = nn.ModuleList([FirDownsample2D(in_channels, out_channels=out_channels)]) + self.skip_conv = nn.Conv2d(3, out_channels, kernel_size=(1, 1), stride=(1, 1)) + else: + self.resnet_down = None + self.downsamplers = None + self.skip_conv = None + + def forward(self, hidden_states, temb=None, skip_sample=None): + output_states = () + + for resnet, attn in zip(self.resnets, self.attentions): + hidden_states = resnet(hidden_states, temb) + hidden_states = attn(hidden_states) + output_states += (hidden_states,) + + if self.downsamplers is not None: + hidden_states = self.resnet_down(hidden_states, temb) + for downsampler in self.downsamplers: + skip_sample = downsampler(skip_sample) + + hidden_states = self.skip_conv(skip_sample) + hidden_states + + output_states += (hidden_states,) + + return hidden_states, output_states, skip_sample + + +class SkipDownBlock2D(nn.Module): + def __init__( + self, + in_channels: int, + out_channels: int, + temb_channels: int, + dropout: float = 0.0, + num_layers: int = 1, + resnet_eps: float = 1e-6, + resnet_time_scale_shift: str = "default", + resnet_act_fn: str = "swish", + resnet_pre_norm: bool = True, + output_scale_factor=np.sqrt(2.0), + add_downsample=True, + downsample_padding=1, + ): + super().__init__() + self.resnets = nn.ModuleList([]) + + for i in range(num_layers): + in_channels = in_channels if i == 0 else out_channels + self.resnets.append( + ResnetBlock2D( + in_channels=in_channels, + out_channels=out_channels, + temb_channels=temb_channels, + eps=resnet_eps, + groups=min(in_channels // 4, 32), + groups_out=min(out_channels // 4, 32), + dropout=dropout, + time_embedding_norm=resnet_time_scale_shift, + non_linearity=resnet_act_fn, + output_scale_factor=output_scale_factor, + pre_norm=resnet_pre_norm, + ) + ) + + if add_downsample: + self.resnet_down = ResnetBlock2D( + in_channels=out_channels, + out_channels=out_channels, + temb_channels=temb_channels, + eps=resnet_eps, + groups=min(out_channels // 4, 32), + dropout=dropout, + time_embedding_norm=resnet_time_scale_shift, + non_linearity=resnet_act_fn, + output_scale_factor=output_scale_factor, + pre_norm=resnet_pre_norm, + use_nin_shortcut=True, + down=True, + kernel="fir", + ) + self.downsamplers = nn.ModuleList([FirDownsample2D(in_channels, out_channels=out_channels)]) + self.skip_conv = nn.Conv2d(3, out_channels, 
kernel_size=(1, 1), stride=(1, 1)) + else: + self.resnet_down = None + self.downsamplers = None + self.skip_conv = None + + def forward(self, hidden_states, temb=None, skip_sample=None): + output_states = () + + for resnet in self.resnets: + hidden_states = resnet(hidden_states, temb) + output_states += (hidden_states,) + + if self.downsamplers is not None: + hidden_states = self.resnet_down(hidden_states, temb) + for downsampler in self.downsamplers: + skip_sample = downsampler(skip_sample) + + hidden_states = self.skip_conv(skip_sample) + hidden_states + + output_states += (hidden_states,) + + return hidden_states, output_states, skip_sample + + +class AttnUpBlock2D(nn.Module): + def __init__( + self, + in_channels: int, + prev_output_channel: int, + out_channels: int, + temb_channels: int, + dropout: float = 0.0, + num_layers: int = 1, + resnet_eps: float = 1e-6, + resnet_time_scale_shift: str = "default", + resnet_act_fn: str = "swish", + resnet_groups: int = 32, + resnet_pre_norm: bool = True, + attention_type="default", + attn_num_head_channels=1, + output_scale_factor=1.0, + add_upsample=True, + ): + super().__init__() + resnets = [] + attentions = [] + + self.attention_type = attention_type + + for i in range(num_layers): + res_skip_channels = in_channels if (i == num_layers - 1) else out_channels + resnet_in_channels = prev_output_channel if i == 0 else out_channels + + resnets.append( + ResnetBlock2D( + in_channels=resnet_in_channels + res_skip_channels, + out_channels=out_channels, + temb_channels=temb_channels, + eps=resnet_eps, + groups=resnet_groups, + dropout=dropout, + time_embedding_norm=resnet_time_scale_shift, + non_linearity=resnet_act_fn, + output_scale_factor=output_scale_factor, + pre_norm=resnet_pre_norm, + ) + ) + attentions.append( + AttentionBlock( + out_channels, + num_head_channels=attn_num_head_channels, + rescale_output_factor=output_scale_factor, + eps=resnet_eps, + ) + ) + + self.attentions = nn.ModuleList(attentions) + self.resnets = nn.ModuleList(resnets) + + if add_upsample: + self.upsamplers = nn.ModuleList([Upsample2D(out_channels, use_conv=True, out_channels=out_channels)]) + else: + self.upsamplers = None + + def forward(self, hidden_states, res_hidden_states_tuple, temb=None): + for resnet, attn in zip(self.resnets, self.attentions): + + # pop res hidden states + res_hidden_states = res_hidden_states_tuple[-1] + res_hidden_states_tuple = res_hidden_states_tuple[:-1] + hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1) + + hidden_states = resnet(hidden_states, temb) + hidden_states = attn(hidden_states) + + if self.upsamplers is not None: + for upsampler in self.upsamplers: + hidden_states = upsampler(hidden_states) + + return hidden_states + + +class CrossAttnUpBlock2D(nn.Module): + def __init__( + self, + in_channels: int, + out_channels: int, + prev_output_channel: int, + temb_channels: int, + dropout: float = 0.0, + num_layers: int = 1, + resnet_eps: float = 1e-6, + resnet_time_scale_shift: str = "default", + resnet_act_fn: str = "swish", + resnet_groups: int = 32, + resnet_pre_norm: bool = True, + attn_num_head_channels=1, + cross_attention_dim=1280, + attention_type="default", + output_scale_factor=1.0, + downsample_padding=1, + add_upsample=True, + ): + super().__init__() + resnets = [] + attentions = [] + + self.attention_type = attention_type + self.attn_num_head_channels = attn_num_head_channels + + for i in range(num_layers): + res_skip_channels = in_channels if (i == num_layers - 1) else out_channels + resnet_in_channels 
= prev_output_channel if i == 0 else out_channels + + resnets.append( + ResnetBlock2D( + in_channels=resnet_in_channels + res_skip_channels, + out_channels=out_channels, + temb_channels=temb_channels, + eps=resnet_eps, + groups=resnet_groups, + dropout=dropout, + time_embedding_norm=resnet_time_scale_shift, + non_linearity=resnet_act_fn, + output_scale_factor=output_scale_factor, + pre_norm=resnet_pre_norm, + ) + ) + attentions.append( + SpatialTransformer( + out_channels, + attn_num_head_channels, + out_channels // attn_num_head_channels, + depth=1, + context_dim=cross_attention_dim, + ) + ) + self.attentions = nn.ModuleList(attentions) + self.resnets = nn.ModuleList(resnets) + + if add_upsample: + self.upsamplers = nn.ModuleList([Upsample2D(out_channels, use_conv=True, out_channels=out_channels)]) + else: + self.upsamplers = None + + def set_attention_slice(self, slice_size): + if slice_size is not None and self.attn_num_head_channels % slice_size != 0: + raise ValueError( + f"Make sure slice_size {slice_size} is a divisor of " + f"the number of heads used in cross_attention {self.attn_num_head_channels}" + ) + if slice_size is not None and slice_size > self.attn_num_head_channels: + raise ValueError( + f"Chunk_size {slice_size} has to be smaller or equal to " + f"the number of heads used in cross_attention {self.attn_num_head_channels}" + ) + + for attn in self.attentions: + attn._set_attention_slice(slice_size) + + def forward(self, hidden_states, res_hidden_states_tuple, temb=None, encoder_hidden_states=None): + for resnet, attn in zip(self.resnets, self.attentions): + + # pop res hidden states + res_hidden_states = res_hidden_states_tuple[-1] + res_hidden_states_tuple = res_hidden_states_tuple[:-1] + hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1) + + hidden_states = resnet(hidden_states, temb) + hidden_states = attn(hidden_states, context=encoder_hidden_states) + + if self.upsamplers is not None: + for upsampler in self.upsamplers: + hidden_states = upsampler(hidden_states) + + return hidden_states + + +class UpBlock2D(nn.Module): + def __init__( + self, + in_channels: int, + prev_output_channel: int, + out_channels: int, + temb_channels: int, + dropout: float = 0.0, + num_layers: int = 1, + resnet_eps: float = 1e-6, + resnet_time_scale_shift: str = "default", + resnet_act_fn: str = "swish", + resnet_groups: int = 32, + resnet_pre_norm: bool = True, + output_scale_factor=1.0, + add_upsample=True, + ): + super().__init__() + resnets = [] + + for i in range(num_layers): + res_skip_channels = in_channels if (i == num_layers - 1) else out_channels + resnet_in_channels = prev_output_channel if i == 0 else out_channels + + resnets.append( + ResnetBlock2D( + in_channels=resnet_in_channels + res_skip_channels, + out_channels=out_channels, + temb_channels=temb_channels, + eps=resnet_eps, + groups=resnet_groups, + dropout=dropout, + time_embedding_norm=resnet_time_scale_shift, + non_linearity=resnet_act_fn, + output_scale_factor=output_scale_factor, + pre_norm=resnet_pre_norm, + ) + ) + + self.resnets = nn.ModuleList(resnets) + + if add_upsample: + self.upsamplers = nn.ModuleList([Upsample2D(out_channels, use_conv=True, out_channels=out_channels)]) + else: + self.upsamplers = None + + def forward(self, hidden_states, res_hidden_states_tuple, temb=None): + for resnet in self.resnets: + + # pop res hidden states + res_hidden_states = res_hidden_states_tuple[-1] + res_hidden_states_tuple = res_hidden_states_tuple[:-1] + hidden_states = torch.cat([hidden_states, 
res_hidden_states], dim=1) + + hidden_states = resnet(hidden_states, temb) + + if self.upsamplers is not None: + for upsampler in self.upsamplers: + hidden_states = upsampler(hidden_states) + + return hidden_states + + +class UpDecoderBlock2D(nn.Module): + def __init__( + self, + in_channels: int, + out_channels: int, + dropout: float = 0.0, + num_layers: int = 1, + resnet_eps: float = 1e-6, + resnet_time_scale_shift: str = "default", + resnet_act_fn: str = "swish", + resnet_groups: int = 32, + resnet_pre_norm: bool = True, + output_scale_factor=1.0, + add_upsample=True, + ): + super().__init__() + resnets = [] + + for i in range(num_layers): + input_channels = in_channels if i == 0 else out_channels + + resnets.append( + ResnetBlock2D( + in_channels=input_channels, + out_channels=out_channels, + temb_channels=None, + eps=resnet_eps, + groups=resnet_groups, + dropout=dropout, + time_embedding_norm=resnet_time_scale_shift, + non_linearity=resnet_act_fn, + output_scale_factor=output_scale_factor, + pre_norm=resnet_pre_norm, + ) + ) + + self.resnets = nn.ModuleList(resnets) + + if add_upsample: + self.upsamplers = nn.ModuleList([Upsample2D(out_channels, use_conv=True, out_channels=out_channels)]) + else: + self.upsamplers = None + + def forward(self, hidden_states): + for resnet in self.resnets: + hidden_states = resnet(hidden_states, temb=None) + + if self.upsamplers is not None: + for upsampler in self.upsamplers: + hidden_states = upsampler(hidden_states) + + return hidden_states + + +class AttnUpDecoderBlock2D(nn.Module): + def __init__( + self, + in_channels: int, + out_channels: int, + dropout: float = 0.0, + num_layers: int = 1, + resnet_eps: float = 1e-6, + resnet_time_scale_shift: str = "default", + resnet_act_fn: str = "swish", + resnet_groups: int = 32, + resnet_pre_norm: bool = True, + attn_num_head_channels=1, + output_scale_factor=1.0, + add_upsample=True, + ): + super().__init__() + resnets = [] + attentions = [] + + for i in range(num_layers): + input_channels = in_channels if i == 0 else out_channels + + resnets.append( + ResnetBlock2D( + in_channels=input_channels, + out_channels=out_channels, + temb_channels=None, + eps=resnet_eps, + groups=resnet_groups, + dropout=dropout, + time_embedding_norm=resnet_time_scale_shift, + non_linearity=resnet_act_fn, + output_scale_factor=output_scale_factor, + pre_norm=resnet_pre_norm, + ) + ) + attentions.append( + AttentionBlock( + out_channels, + num_head_channels=attn_num_head_channels, + rescale_output_factor=output_scale_factor, + eps=resnet_eps, + num_groups=resnet_groups, + ) + ) + + self.attentions = nn.ModuleList(attentions) + self.resnets = nn.ModuleList(resnets) + + if add_upsample: + self.upsamplers = nn.ModuleList([Upsample2D(out_channels, use_conv=True, out_channels=out_channels)]) + else: + self.upsamplers = None + + def forward(self, hidden_states): + for resnet, attn in zip(self.resnets, self.attentions): + hidden_states = resnet(hidden_states, temb=None) + hidden_states = attn(hidden_states) + + if self.upsamplers is not None: + for upsampler in self.upsamplers: + hidden_states = upsampler(hidden_states) + + return hidden_states + + +class AttnSkipUpBlock2D(nn.Module): + def __init__( + self, + in_channels: int, + prev_output_channel: int, + out_channels: int, + temb_channels: int, + dropout: float = 0.0, + num_layers: int = 1, + resnet_eps: float = 1e-6, + resnet_time_scale_shift: str = "default", + resnet_act_fn: str = "swish", + resnet_pre_norm: bool = True, + attn_num_head_channels=1, + attention_type="default", + 
output_scale_factor=np.sqrt(2.0), + upsample_padding=1, + add_upsample=True, + ): + super().__init__() + self.attentions = nn.ModuleList([]) + self.resnets = nn.ModuleList([]) + + self.attention_type = attention_type + + for i in range(num_layers): + res_skip_channels = in_channels if (i == num_layers - 1) else out_channels + resnet_in_channels = prev_output_channel if i == 0 else out_channels + + self.resnets.append( + ResnetBlock2D( + in_channels=resnet_in_channels + res_skip_channels, + out_channels=out_channels, + temb_channels=temb_channels, + eps=resnet_eps, + groups=min(resnet_in_channels + res_skip_channels // 4, 32), + groups_out=min(out_channels // 4, 32), + dropout=dropout, + time_embedding_norm=resnet_time_scale_shift, + non_linearity=resnet_act_fn, + output_scale_factor=output_scale_factor, + pre_norm=resnet_pre_norm, + ) + ) + + self.attentions.append( + AttentionBlock( + out_channels, + num_head_channels=attn_num_head_channels, + rescale_output_factor=output_scale_factor, + eps=resnet_eps, + ) + ) + + self.upsampler = FirUpsample2D(in_channels, out_channels=out_channels) + if add_upsample: + self.resnet_up = ResnetBlock2D( + in_channels=out_channels, + out_channels=out_channels, + temb_channels=temb_channels, + eps=resnet_eps, + groups=min(out_channels // 4, 32), + groups_out=min(out_channels // 4, 32), + dropout=dropout, + time_embedding_norm=resnet_time_scale_shift, + non_linearity=resnet_act_fn, + output_scale_factor=output_scale_factor, + pre_norm=resnet_pre_norm, + use_nin_shortcut=True, + up=True, + kernel="fir", + ) + self.skip_conv = nn.Conv2d(out_channels, 3, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) + self.skip_norm = torch.nn.GroupNorm( + num_groups=min(out_channels // 4, 32), num_channels=out_channels, eps=resnet_eps, affine=True + ) + self.act = nn.SiLU() + else: + self.resnet_up = None + self.skip_conv = None + self.skip_norm = None + self.act = None + + def forward(self, hidden_states, res_hidden_states_tuple, temb=None, skip_sample=None): + for resnet in self.resnets: + # pop res hidden states + res_hidden_states = res_hidden_states_tuple[-1] + res_hidden_states_tuple = res_hidden_states_tuple[:-1] + hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1) + + hidden_states = resnet(hidden_states, temb) + + hidden_states = self.attentions[0](hidden_states) + + if skip_sample is not None: + skip_sample = self.upsampler(skip_sample) + else: + skip_sample = 0 + + if self.resnet_up is not None: + skip_sample_states = self.skip_norm(hidden_states) + skip_sample_states = self.act(skip_sample_states) + skip_sample_states = self.skip_conv(skip_sample_states) + + skip_sample = skip_sample + skip_sample_states + + hidden_states = self.resnet_up(hidden_states, temb) + + return hidden_states, skip_sample + + +class SkipUpBlock2D(nn.Module): + def __init__( + self, + in_channels: int, + prev_output_channel: int, + out_channels: int, + temb_channels: int, + dropout: float = 0.0, + num_layers: int = 1, + resnet_eps: float = 1e-6, + resnet_time_scale_shift: str = "default", + resnet_act_fn: str = "swish", + resnet_pre_norm: bool = True, + output_scale_factor=np.sqrt(2.0), + add_upsample=True, + upsample_padding=1, + ): + super().__init__() + self.resnets = nn.ModuleList([]) + + for i in range(num_layers): + res_skip_channels = in_channels if (i == num_layers - 1) else out_channels + resnet_in_channels = prev_output_channel if i == 0 else out_channels + + self.resnets.append( + ResnetBlock2D( + in_channels=resnet_in_channels + res_skip_channels, + 
out_channels=out_channels, + temb_channels=temb_channels, + eps=resnet_eps, + groups=min((resnet_in_channels + res_skip_channels) // 4, 32), + groups_out=min(out_channels // 4, 32), + dropout=dropout, + time_embedding_norm=resnet_time_scale_shift, + non_linearity=resnet_act_fn, + output_scale_factor=output_scale_factor, + pre_norm=resnet_pre_norm, + ) + ) + + self.upsampler = FirUpsample2D(in_channels, out_channels=out_channels) + if add_upsample: + self.resnet_up = ResnetBlock2D( + in_channels=out_channels, + out_channels=out_channels, + temb_channels=temb_channels, + eps=resnet_eps, + groups=min(out_channels // 4, 32), + groups_out=min(out_channels // 4, 32), + dropout=dropout, + time_embedding_norm=resnet_time_scale_shift, + non_linearity=resnet_act_fn, + output_scale_factor=output_scale_factor, + pre_norm=resnet_pre_norm, + use_nin_shortcut=True, + up=True, + kernel="fir", + ) + self.skip_conv = nn.Conv2d(out_channels, 3, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) + self.skip_norm = torch.nn.GroupNorm( + num_groups=min(out_channels // 4, 32), num_channels=out_channels, eps=resnet_eps, affine=True + ) + self.act = nn.SiLU() + else: + self.resnet_up = None + self.skip_conv = None + self.skip_norm = None + self.act = None + + def forward(self, hidden_states, res_hidden_states_tuple, temb=None, skip_sample=None): + for resnet in self.resnets: + # pop res hidden states + res_hidden_states = res_hidden_states_tuple[-1] + res_hidden_states_tuple = res_hidden_states_tuple[:-1] + hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1) + + hidden_states = resnet(hidden_states, temb) + + if skip_sample is not None: + skip_sample = self.upsampler(skip_sample) + else: + skip_sample = 0 + + if self.resnet_up is not None: + skip_sample_states = self.skip_norm(hidden_states) + skip_sample_states = self.act(skip_sample_states) + skip_sample_states = self.skip_conv(skip_sample_states) + + skip_sample = skip_sample + skip_sample_states + + hidden_states = self.resnet_up(hidden_states, temb) + + return hidden_states, skip_sample diff --git a/local_diffusers/models/vae.py b/local_diffusers/models/vae.py new file mode 100644 index 000000000..82748cb5b --- /dev/null +++ b/local_diffusers/models/vae.py @@ -0,0 +1,581 @@ +from dataclasses import dataclass +from typing import Optional, Tuple, Union + +import numpy as np +import torch +import torch.nn as nn + +from ..configuration_utils import ConfigMixin, register_to_config +from ..modeling_utils import ModelMixin +from ..utils import BaseOutput +from .unet_blocks import UNetMidBlock2D, get_down_block, get_up_block + + +@dataclass +class DecoderOutput(BaseOutput): + """ + Output of decoding method. + + Args: + sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`): + Decoded output sample of the model. Output of the last layer of the model. + """ + + sample: torch.FloatTensor + + +@dataclass +class VQEncoderOutput(BaseOutput): + """ + Output of VQModel encoding method. + + Args: + latents (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`): + Encoded output sample of the model. Output of the last layer of the model. + """ + + latents: torch.FloatTensor + + +@dataclass +class AutoencoderKLOutput(BaseOutput): + """ + Output of AutoencoderKL encoding method. + + Args: + latent_dist (`DiagonalGaussianDistribution`): + Encoded outputs of `Encoder` represented as the mean and logvar of `DiagonalGaussianDistribution`. 
+ `DiagonalGaussianDistribution` allows for sampling latents from the distribution. + """ + + latent_dist: "DiagonalGaussianDistribution" + + +class Encoder(nn.Module): + def __init__( + self, + in_channels=3, + out_channels=3, + down_block_types=("DownEncoderBlock2D",), + block_out_channels=(64,), + layers_per_block=2, + act_fn="silu", + double_z=True, + ): + super().__init__() + self.layers_per_block = layers_per_block + + self.conv_in = torch.nn.Conv2d(in_channels, block_out_channels[0], kernel_size=3, stride=1, padding=1) + + self.mid_block = None + self.down_blocks = nn.ModuleList([]) + + # down + output_channel = block_out_channels[0] + for i, down_block_type in enumerate(down_block_types): + input_channel = output_channel + output_channel = block_out_channels[i] + is_final_block = i == len(block_out_channels) - 1 + + down_block = get_down_block( + down_block_type, + num_layers=self.layers_per_block, + in_channels=input_channel, + out_channels=output_channel, + add_downsample=not is_final_block, + resnet_eps=1e-6, + downsample_padding=0, + resnet_act_fn=act_fn, + attn_num_head_channels=None, + temb_channels=None, + ) + self.down_blocks.append(down_block) + + # mid + self.mid_block = UNetMidBlock2D( + in_channels=block_out_channels[-1], + resnet_eps=1e-6, + resnet_act_fn=act_fn, + output_scale_factor=1, + resnet_time_scale_shift="default", + attn_num_head_channels=None, + resnet_groups=32, + temb_channels=None, + ) + + # out + num_groups_out = 32 + self.conv_norm_out = nn.GroupNorm(num_channels=block_out_channels[-1], num_groups=num_groups_out, eps=1e-6) + self.conv_act = nn.SiLU() + + conv_out_channels = 2 * out_channels if double_z else out_channels + self.conv_out = nn.Conv2d(block_out_channels[-1], conv_out_channels, 3, padding=1) + + def forward(self, x): + sample = x + sample = self.conv_in(sample) + + # down + for down_block in self.down_blocks: + sample = down_block(sample) + + # middle + sample = self.mid_block(sample) + + # post-process + sample = self.conv_norm_out(sample) + sample = self.conv_act(sample) + sample = self.conv_out(sample) + + return sample + + +class Decoder(nn.Module): + def __init__( + self, + in_channels=3, + out_channels=3, + up_block_types=("UpDecoderBlock2D",), + block_out_channels=(64,), + layers_per_block=2, + act_fn="silu", + ): + super().__init__() + self.layers_per_block = layers_per_block + + self.conv_in = nn.Conv2d(in_channels, block_out_channels[-1], kernel_size=3, stride=1, padding=1) + + self.mid_block = None + self.up_blocks = nn.ModuleList([]) + + # mid + self.mid_block = UNetMidBlock2D( + in_channels=block_out_channels[-1], + resnet_eps=1e-6, + resnet_act_fn=act_fn, + output_scale_factor=1, + resnet_time_scale_shift="default", + attn_num_head_channels=None, + resnet_groups=32, + temb_channels=None, + ) + + # up + reversed_block_out_channels = list(reversed(block_out_channels)) + output_channel = reversed_block_out_channels[0] + for i, up_block_type in enumerate(up_block_types): + prev_output_channel = output_channel + output_channel = reversed_block_out_channels[i] + + is_final_block = i == len(block_out_channels) - 1 + + up_block = get_up_block( + up_block_type, + num_layers=self.layers_per_block + 1, + in_channels=prev_output_channel, + out_channels=output_channel, + prev_output_channel=None, + add_upsample=not is_final_block, + resnet_eps=1e-6, + resnet_act_fn=act_fn, + attn_num_head_channels=None, + temb_channels=None, + ) + self.up_blocks.append(up_block) + prev_output_channel = output_channel + + # out + num_groups_out = 32 + 
self.conv_norm_out = nn.GroupNorm(num_channels=block_out_channels[0], num_groups=num_groups_out, eps=1e-6) + self.conv_act = nn.SiLU() + self.conv_out = nn.Conv2d(block_out_channels[0], out_channels, 3, padding=1) + + def forward(self, z): + sample = z + sample = self.conv_in(sample) + + # middle + sample = self.mid_block(sample) + + # up + for up_block in self.up_blocks: + sample = up_block(sample) + + # post-process + sample = self.conv_norm_out(sample) + sample = self.conv_act(sample) + sample = self.conv_out(sample) + + return sample + + +class VectorQuantizer(nn.Module): + """ + Improved version over VectorQuantizer, can be used as a drop-in replacement. Mostly avoids costly matrix + multiplications and allows for post-hoc remapping of indices. + """ + + # NOTE: due to a bug the beta term was applied to the wrong term. for + # backwards compatibility we use the buggy version by default, but you can + # specify legacy=False to fix it. + def __init__(self, n_e, e_dim, beta, remap=None, unknown_index="random", sane_index_shape=False, legacy=True): + super().__init__() + self.n_e = n_e + self.e_dim = e_dim + self.beta = beta + self.legacy = legacy + + self.embedding = nn.Embedding(self.n_e, self.e_dim) + self.embedding.weight.data.uniform_(-1.0 / self.n_e, 1.0 / self.n_e) + + self.remap = remap + if self.remap is not None: + self.register_buffer("used", torch.tensor(np.load(self.remap))) + self.re_embed = self.used.shape[0] + self.unknown_index = unknown_index # "random" or "extra" or integer + if self.unknown_index == "extra": + self.unknown_index = self.re_embed + self.re_embed = self.re_embed + 1 + print( + f"Remapping {self.n_e} indices to {self.re_embed} indices. " + f"Using {self.unknown_index} for unknown indices." + ) + else: + self.re_embed = n_e + + self.sane_index_shape = sane_index_shape + + def remap_to_used(self, inds): + ishape = inds.shape + assert len(ishape) > 1 + inds = inds.reshape(ishape[0], -1) + used = self.used.to(inds) + match = (inds[:, :, None] == used[None, None, ...]).long() + new = match.argmax(-1) + unknown = match.sum(2) < 1 + if self.unknown_index == "random": + new[unknown] = torch.randint(0, self.re_embed, size=new[unknown].shape).to(device=new.device) + else: + new[unknown] = self.unknown_index + return new.reshape(ishape) + + def unmap_to_all(self, inds): + ishape = inds.shape + assert len(ishape) > 1 + inds = inds.reshape(ishape[0], -1) + used = self.used.to(inds) + if self.re_embed > self.used.shape[0]: # extra token + inds[inds >= self.used.shape[0]] = 0 # simply set to zero + back = torch.gather(used[None, :][inds.shape[0] * [0], :], 1, inds) + return back.reshape(ishape) + + def forward(self, z): + # reshape z -> (batch, height, width, channel) and flatten + z = z.permute(0, 2, 3, 1).contiguous() + z_flattened = z.view(-1, self.e_dim) + # distances from z to embeddings e_j (z - e)^2 = z^2 + e^2 - 2 e * z + + d = ( + torch.sum(z_flattened**2, dim=1, keepdim=True) + + torch.sum(self.embedding.weight**2, dim=1) + - 2 * torch.einsum("bd,dn->bn", z_flattened, self.embedding.weight.t()) + ) + + min_encoding_indices = torch.argmin(d, dim=1) + z_q = self.embedding(min_encoding_indices).view(z.shape) + perplexity = None + min_encodings = None + + # compute loss for embedding + if not self.legacy: + loss = self.beta * torch.mean((z_q.detach() - z) ** 2) + torch.mean((z_q - z.detach()) ** 2) + else: + loss = torch.mean((z_q.detach() - z) ** 2) + self.beta * torch.mean((z_q - z.detach()) ** 2) + + # preserve gradients + z_q = z + (z_q - z).detach() + + # 
reshape back to match original input shape + z_q = z_q.permute(0, 3, 1, 2).contiguous() + + if self.remap is not None: + min_encoding_indices = min_encoding_indices.reshape(z.shape[0], -1) # add batch axis + min_encoding_indices = self.remap_to_used(min_encoding_indices) + min_encoding_indices = min_encoding_indices.reshape(-1, 1) # flatten + + if self.sane_index_shape: + min_encoding_indices = min_encoding_indices.reshape(z_q.shape[0], z_q.shape[2], z_q.shape[3]) + + return z_q, loss, (perplexity, min_encodings, min_encoding_indices) + + def get_codebook_entry(self, indices, shape): + # shape specifying (batch, height, width, channel) + if self.remap is not None: + indices = indices.reshape(shape[0], -1) # add batch axis + indices = self.unmap_to_all(indices) + indices = indices.reshape(-1) # flatten again + + # get quantized latent vectors + z_q = self.embedding(indices) + + if shape is not None: + z_q = z_q.view(shape) + # reshape back to match original input shape + z_q = z_q.permute(0, 3, 1, 2).contiguous() + + return z_q + + +class DiagonalGaussianDistribution(object): + def __init__(self, parameters, deterministic=False): + self.parameters = parameters + self.mean, self.logvar = torch.chunk(parameters, 2, dim=1) + self.logvar = torch.clamp(self.logvar, -30.0, 20.0) + self.deterministic = deterministic + self.std = torch.exp(0.5 * self.logvar) + self.var = torch.exp(self.logvar) + if self.deterministic: + self.var = self.std = torch.zeros_like(self.mean).to(device=self.parameters.device) + + def sample(self, generator: Optional[torch.Generator] = None) -> torch.FloatTensor: + device = self.parameters.device + sample_device = "cpu" if device.type == "mps" else device + sample = torch.randn(self.mean.shape, generator=generator, device=sample_device).to(device) + x = self.mean + self.std * sample + return x + + def kl(self, other=None): + if self.deterministic: + return torch.Tensor([0.0]) + else: + if other is None: + return 0.5 * torch.sum(torch.pow(self.mean, 2) + self.var - 1.0 - self.logvar, dim=[1, 2, 3]) + else: + return 0.5 * torch.sum( + torch.pow(self.mean - other.mean, 2) / other.var + + self.var / other.var + - 1.0 + - self.logvar + + other.logvar, + dim=[1, 2, 3], + ) + + def nll(self, sample, dims=[1, 2, 3]): + if self.deterministic: + return torch.Tensor([0.0]) + logtwopi = np.log(2.0 * np.pi) + return 0.5 * torch.sum(logtwopi + self.logvar + torch.pow(sample - self.mean, 2) / self.var, dim=dims) + + def mode(self): + return self.mean + + +class VQModel(ModelMixin, ConfigMixin): + r"""VQ-VAE model from the paper Neural Discrete Representation Learning by Aaron van den Oord, Oriol Vinyals and Koray + Kavukcuoglu. + + This model inherits from [`ModelMixin`]. Check the superclass documentation for the generic methods the library + implements for all the model (such as downloading or saving, etc.) + + Parameters: + in_channels (int, *optional*, defaults to 3): Number of channels in the input image. + out_channels (int, *optional*, defaults to 3): Number of channels in the output. + down_block_types (`Tuple[str]`, *optional*, defaults to : + obj:`("DownEncoderBlock2D",)`): Tuple of downsample block types. + up_block_types (`Tuple[str]`, *optional*, defaults to : + obj:`("UpDecoderBlock2D",)`): Tuple of upsample block types. + block_out_channels (`Tuple[int]`, *optional*, defaults to : + obj:`(64,)`): Tuple of block output channels. + act_fn (`str`, *optional*, defaults to `"silu"`): The activation function to use. 
+ latent_channels (`int`, *optional*, defaults to `3`): Number of channels in the latent space. + sample_size (`int`, *optional*, defaults to `32`): TODO + num_vq_embeddings (`int`, *optional*, defaults to `256`): Number of codebook vectors in the VQ-VAE. + """ + + @register_to_config + def __init__( + self, + in_channels: int = 3, + out_channels: int = 3, + down_block_types: Tuple[str] = ("DownEncoderBlock2D",), + up_block_types: Tuple[str] = ("UpDecoderBlock2D",), + block_out_channels: Tuple[int] = (64,), + layers_per_block: int = 1, + act_fn: str = "silu", + latent_channels: int = 3, + sample_size: int = 32, + num_vq_embeddings: int = 256, + ): + super().__init__() + + # pass init params to Encoder + self.encoder = Encoder( + in_channels=in_channels, + out_channels=latent_channels, + down_block_types=down_block_types, + block_out_channels=block_out_channels, + layers_per_block=layers_per_block, + act_fn=act_fn, + double_z=False, + ) + + self.quant_conv = torch.nn.Conv2d(latent_channels, latent_channels, 1) + self.quantize = VectorQuantizer( + num_vq_embeddings, latent_channels, beta=0.25, remap=None, sane_index_shape=False + ) + self.post_quant_conv = torch.nn.Conv2d(latent_channels, latent_channels, 1) + + # pass init params to Decoder + self.decoder = Decoder( + in_channels=latent_channels, + out_channels=out_channels, + up_block_types=up_block_types, + block_out_channels=block_out_channels, + layers_per_block=layers_per_block, + act_fn=act_fn, + ) + + def encode(self, x: torch.FloatTensor, return_dict: bool = True) -> VQEncoderOutput: + h = self.encoder(x) + h = self.quant_conv(h) + + if not return_dict: + return (h,) + + return VQEncoderOutput(latents=h) + + def decode( + self, h: torch.FloatTensor, force_not_quantize: bool = False, return_dict: bool = True + ) -> Union[DecoderOutput, torch.FloatTensor]: + # also go through quantization layer + if not force_not_quantize: + quant, emb_loss, info = self.quantize(h) + else: + quant = h + quant = self.post_quant_conv(quant) + dec = self.decoder(quant) + + if not return_dict: + return (dec,) + + return DecoderOutput(sample=dec) + + def forward(self, sample: torch.FloatTensor, return_dict: bool = True) -> Union[DecoderOutput, torch.FloatTensor]: + r""" + Args: + sample (`torch.FloatTensor`): Input sample. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`DecoderOutput`] instead of a plain tuple. + """ + x = sample + h = self.encode(x).latents + dec = self.decode(h).sample + + if not return_dict: + return (dec,) + + return DecoderOutput(sample=dec) + + +class AutoencoderKL(ModelMixin, ConfigMixin): + r"""Variational Autoencoder (VAE) model with KL loss from the paper Auto-Encoding Variational Bayes by Diederik P. Kingma + and Max Welling. + + This model inherits from [`ModelMixin`]. Check the superclass documentation for the generic methods the library + implements for all the model (such as downloading or saving, etc.) + + Parameters: + in_channels (int, *optional*, defaults to 3): Number of channels in the input image. + out_channels (int, *optional*, defaults to 3): Number of channels in the output. + down_block_types (`Tuple[str]`, *optional*, defaults to : + obj:`("DownEncoderBlock2D",)`): Tuple of downsample block types. + up_block_types (`Tuple[str]`, *optional*, defaults to : + obj:`("UpDecoderBlock2D",)`): Tuple of upsample block types. + block_out_channels (`Tuple[int]`, *optional*, defaults to : + obj:`(64,)`): Tuple of block output channels. 
+ act_fn (`str`, *optional*, defaults to `"silu"`): The activation function to use. + latent_channels (`int`, *optional*, defaults to `4`): Number of channels in the latent space. + sample_size (`int`, *optional*, defaults to `32`): TODO + """ + + @register_to_config + def __init__( + self, + in_channels: int = 3, + out_channels: int = 3, + down_block_types: Tuple[str] = ("DownEncoderBlock2D",), + up_block_types: Tuple[str] = ("UpDecoderBlock2D",), + block_out_channels: Tuple[int] = (64,), + layers_per_block: int = 1, + act_fn: str = "silu", + latent_channels: int = 4, + sample_size: int = 32, + ): + super().__init__() + + # pass init params to Encoder + self.encoder = Encoder( + in_channels=in_channels, + out_channels=latent_channels, + down_block_types=down_block_types, + block_out_channels=block_out_channels, + layers_per_block=layers_per_block, + act_fn=act_fn, + double_z=True, + ) + + # pass init params to Decoder + self.decoder = Decoder( + in_channels=latent_channels, + out_channels=out_channels, + up_block_types=up_block_types, + block_out_channels=block_out_channels, + layers_per_block=layers_per_block, + act_fn=act_fn, + ) + + self.quant_conv = torch.nn.Conv2d(2 * latent_channels, 2 * latent_channels, 1) + self.post_quant_conv = torch.nn.Conv2d(latent_channels, latent_channels, 1) + + def encode(self, x: torch.FloatTensor, return_dict: bool = True) -> AutoencoderKLOutput: + h = self.encoder(x) + moments = self.quant_conv(h) + posterior = DiagonalGaussianDistribution(moments) + + if not return_dict: + return (posterior,) + + return AutoencoderKLOutput(latent_dist=posterior) + + def decode(self, z: torch.FloatTensor, return_dict: bool = True) -> Union[DecoderOutput, torch.FloatTensor]: + z = self.post_quant_conv(z) + dec = self.decoder(z) + + if not return_dict: + return (dec,) + + return DecoderOutput(sample=dec) + + def forward( + self, sample: torch.FloatTensor, sample_posterior: bool = False, return_dict: bool = True + ) -> Union[DecoderOutput, torch.FloatTensor]: + r""" + Args: + sample (`torch.FloatTensor`): Input sample. + sample_posterior (`bool`, *optional*, defaults to `False`): + Whether to sample from the posterior. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`DecoderOutput`] instead of a plain tuple. + """ + x = sample + posterior = self.encode(x).latent_dist + if sample_posterior: + z = posterior.sample() + else: + z = posterior.mode() + dec = self.decode(z).sample + + if not return_dict: + return (dec,) + + return DecoderOutput(sample=dec) diff --git a/local_diffusers/onnx_utils.py b/local_diffusers/onnx_utils.py new file mode 100644 index 000000000..e840565dd --- /dev/null +++ b/local_diffusers/onnx_utils.py @@ -0,0 +1,189 @@ +# coding=utf-8 +# Copyright 2022 The HuggingFace Inc. team. +# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
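+#
+# Illustrative usage sketch for the OnnxRuntimeModel wrapper defined below.
+# The model path and the "sample" input name are assumptions for the example
+# only:
+#
+#     import numpy as np
+#     from local_diffusers.onnx_utils import OnnxRuntimeModel
+#
+#     # Load an exported ONNX graph with the default CPUExecutionProvider.
+#     session = OnnxRuntimeModel.load_model("./onnx/model.onnx")
+#     model = OnnxRuntimeModel(session, model_save_dir="./onnx")
+#
+#     # Keyword arguments are converted to numpy arrays and forwarded to
+#     # onnxruntime; the raw list of output arrays is returned.
+#     outputs = model(sample=np.zeros((1, 4, 64, 64), dtype=np.float32))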
+ + +import os +import shutil +from pathlib import Path +from typing import Optional, Union + +import numpy as np + +from huggingface_hub import hf_hub_download + +from .utils import is_onnx_available, logging + + +if is_onnx_available(): + import onnxruntime as ort + + +ONNX_WEIGHTS_NAME = "model.onnx" + + +logger = logging.get_logger(__name__) + + +class OnnxRuntimeModel: + base_model_prefix = "onnx_model" + + def __init__(self, model=None, **kwargs): + logger.info("`diffusers.OnnxRuntimeModel` is experimental and might change in the future.") + self.model = model + self.model_save_dir = kwargs.get("model_save_dir", None) + self.latest_model_name = kwargs.get("latest_model_name", "model.onnx") + + def __call__(self, **kwargs): + inputs = {k: np.array(v) for k, v in kwargs.items()} + return self.model.run(None, inputs) + + @staticmethod + def load_model(path: Union[str, Path], provider=None): + """ + Loads an ONNX Inference session with an ExecutionProvider. Default provider is `CPUExecutionProvider` + + Arguments: + path (`str` or `Path`): + Directory from which to load + provider(`str`, *optional*): + Onnxruntime execution provider to use for loading the model, defaults to `CPUExecutionProvider` + """ + if provider is None: + logger.info("No onnxruntime provider specified, using CPUExecutionProvider") + provider = "CPUExecutionProvider" + + return ort.InferenceSession(path, providers=[provider]) + + def _save_pretrained(self, save_directory: Union[str, Path], file_name: Optional[str] = None, **kwargs): + """ + Save a model and its configuration file to a directory, so that it can be re-loaded using the + [`~optimum.onnxruntime.modeling_ort.ORTModel.from_pretrained`] class method. It will always save the + latest_model_name. + + Arguments: + save_directory (`str` or `Path`): + Directory where to save the model file. + file_name(`str`, *optional*): + Overwrites the default model file name from `"model.onnx"` to `file_name`. This allows you to save the + model with a different name. + """ + model_file_name = file_name if file_name is not None else ONNX_WEIGHTS_NAME + + src_path = self.model_save_dir.joinpath(self.latest_model_name) + dst_path = Path(save_directory).joinpath(model_file_name) + if not src_path.samefile(dst_path): + shutil.copyfile(src_path, dst_path) + + def save_pretrained( + self, + save_directory: Union[str, os.PathLike], + **kwargs, + ): + """ + Save a model to a directory, so that it can be re-loaded using the [`~OnnxModel.from_pretrained`] class + method.: + + Arguments: + save_directory (`str` or `os.PathLike`): + Directory to which to save. Will be created if it doesn't exist. + """ + if os.path.isfile(save_directory): + logger.error(f"Provided path ({save_directory}) should be a directory, not a file") + return + + os.makedirs(save_directory, exist_ok=True) + + # saving model weights/files + self._save_pretrained(save_directory, **kwargs) + + @classmethod + def _from_pretrained( + cls, + model_id: Union[str, Path], + use_auth_token: Optional[Union[bool, str, None]] = None, + revision: Optional[Union[str, None]] = None, + force_download: bool = False, + cache_dir: Optional[str] = None, + file_name: Optional[str] = None, + provider: Optional[str] = None, + **kwargs, + ): + """ + Load a model from a directory or the HF Hub. + + Arguments: + model_id (`str` or `Path`): + Directory from which to load + use_auth_token (`str` or `bool`): + Is needed to load models from a private or gated repository + revision (`str`): + Revision is the specific model version to use. 
It can be a branch name, a tag name, or a commit id + cache_dir (`Union[str, Path]`, *optional*): + Path to a directory in which a downloaded pretrained model configuration should be cached if the + standard cache should not be used. + force_download (`bool`, *optional*, defaults to `False`): + Whether or not to force the (re-)download of the model weights and configuration files, overriding the + cached versions if they exist. + file_name(`str`): + Overwrites the default model file name from `"model.onnx"` to `file_name`. This allows you to load + different model files from the same repository or directory. + provider(`str`): + The ONNX runtime provider, e.g. `CPUExecutionProvider` or `CUDAExecutionProvider`. + kwargs (`Dict`, *optional*): + kwargs will be passed to the model during initialization + """ + model_file_name = file_name if file_name is not None else ONNX_WEIGHTS_NAME + # load model from local directory + if os.path.isdir(model_id): + model = OnnxRuntimeModel.load_model(os.path.join(model_id, model_file_name), provider=provider) + kwargs["model_save_dir"] = Path(model_id) + # load model from hub + else: + # download model + model_cache_path = hf_hub_download( + repo_id=model_id, + filename=model_file_name, + use_auth_token=use_auth_token, + revision=revision, + cache_dir=cache_dir, + force_download=force_download, + ) + kwargs["model_save_dir"] = Path(model_cache_path).parent + kwargs["latest_model_name"] = Path(model_cache_path).name + model = OnnxRuntimeModel.load_model(model_cache_path, provider=provider) + return cls(model=model, **kwargs) + + @classmethod + def from_pretrained( + cls, + model_id: Union[str, Path], + force_download: bool = True, + use_auth_token: Optional[str] = None, + cache_dir: Optional[str] = None, + **model_kwargs, + ): + revision = None + if len(str(model_id).split("@")) == 2: + model_id, revision = model_id.split("@") + + return cls._from_pretrained( + model_id=model_id, + revision=revision, + cache_dir=cache_dir, + force_download=force_download, + use_auth_token=use_auth_token, + **model_kwargs, + ) diff --git a/local_diffusers/optimization.py b/local_diffusers/optimization.py new file mode 100644 index 000000000..e7b836b4a --- /dev/null +++ b/local_diffusers/optimization.py @@ -0,0 +1,275 @@ +# coding=utf-8 +# Copyright 2022 The HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
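+# ---------------------------------------------------------------------------
+# Illustrative sketch (not part of the upstream code): loading and running the
+# OnnxRuntimeModel wrapper added above in local_diffusers/onnx_utils.py. The
+# local directory name and the `input_ids` input name are placeholder
+# assumptions; real input names come from the exported ONNX graph, and
+# onnxruntime must be installed.
+def _example_onnx_runtime_model():
+    import numpy as np
+
+    from local_diffusers.onnx_utils import OnnxRuntimeModel
+
+    # Loads "<directory>/model.onnx" with the CPU execution provider.
+    model = OnnxRuntimeModel.from_pretrained("./my_onnx_model", provider="CPUExecutionProvider")
+    # __call__ casts keyword arguments to numpy arrays and runs the session.
+    outputs = model(input_ids=np.ones((1, 8), dtype=np.int64))
+    return outputs
+# ---------------------------------------------------------------------------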
+"""PyTorch optimization for diffusion models.""" + +import math +from enum import Enum +from typing import Optional, Union + +from torch.optim import Optimizer +from torch.optim.lr_scheduler import LambdaLR + +from .utils import logging + + +logger = logging.get_logger(__name__) + + +class SchedulerType(Enum): + LINEAR = "linear" + COSINE = "cosine" + COSINE_WITH_RESTARTS = "cosine_with_restarts" + POLYNOMIAL = "polynomial" + CONSTANT = "constant" + CONSTANT_WITH_WARMUP = "constant_with_warmup" + + +def get_constant_schedule(optimizer: Optimizer, last_epoch: int = -1): + """ + Create a schedule with a constant learning rate, using the learning rate set in optimizer. + + Args: + optimizer ([`~torch.optim.Optimizer`]): + The optimizer for which to schedule the learning rate. + last_epoch (`int`, *optional*, defaults to -1): + The index of the last epoch when resuming training. + + Return: + `torch.optim.lr_scheduler.LambdaLR` with the appropriate schedule. + """ + return LambdaLR(optimizer, lambda _: 1, last_epoch=last_epoch) + + +def get_constant_schedule_with_warmup(optimizer: Optimizer, num_warmup_steps: int, last_epoch: int = -1): + """ + Create a schedule with a constant learning rate preceded by a warmup period during which the learning rate + increases linearly between 0 and the initial lr set in the optimizer. + + Args: + optimizer ([`~torch.optim.Optimizer`]): + The optimizer for which to schedule the learning rate. + num_warmup_steps (`int`): + The number of steps for the warmup phase. + last_epoch (`int`, *optional*, defaults to -1): + The index of the last epoch when resuming training. + + Return: + `torch.optim.lr_scheduler.LambdaLR` with the appropriate schedule. + """ + + def lr_lambda(current_step: int): + if current_step < num_warmup_steps: + return float(current_step) / float(max(1.0, num_warmup_steps)) + return 1.0 + + return LambdaLR(optimizer, lr_lambda, last_epoch=last_epoch) + + +def get_linear_schedule_with_warmup(optimizer, num_warmup_steps, num_training_steps, last_epoch=-1): + """ + Create a schedule with a learning rate that decreases linearly from the initial lr set in the optimizer to 0, after + a warmup period during which it increases linearly from 0 to the initial lr set in the optimizer. + + Args: + optimizer ([`~torch.optim.Optimizer`]): + The optimizer for which to schedule the learning rate. + num_warmup_steps (`int`): + The number of steps for the warmup phase. + num_training_steps (`int`): + The total number of training steps. + last_epoch (`int`, *optional*, defaults to -1): + The index of the last epoch when resuming training. + + Return: + `torch.optim.lr_scheduler.LambdaLR` with the appropriate schedule. + """ + + def lr_lambda(current_step: int): + if current_step < num_warmup_steps: + return float(current_step) / float(max(1, num_warmup_steps)) + return max( + 0.0, float(num_training_steps - current_step) / float(max(1, num_training_steps - num_warmup_steps)) + ) + + return LambdaLR(optimizer, lr_lambda, last_epoch) + + +def get_cosine_schedule_with_warmup( + optimizer: Optimizer, num_warmup_steps: int, num_training_steps: int, num_cycles: float = 0.5, last_epoch: int = -1 +): + """ + Create a schedule with a learning rate that decreases following the values of the cosine function between the + initial lr set in the optimizer to 0, after a warmup period during which it increases linearly between 0 and the + initial lr set in the optimizer. + + Args: + optimizer ([`~torch.optim.Optimizer`]): + The optimizer for which to schedule the learning rate. 
+ num_warmup_steps (`int`): + The number of steps for the warmup phase. + num_training_steps (`int`): + The total number of training steps. + num_cycles (`float`, *optional*, defaults to 0.5): + The number of waves in the cosine schedule (the defaults is to just decrease from the max value to 0 + following a half-cosine). + last_epoch (`int`, *optional*, defaults to -1): + The index of the last epoch when resuming training. + + Return: + `torch.optim.lr_scheduler.LambdaLR` with the appropriate schedule. + """ + + def lr_lambda(current_step): + if current_step < num_warmup_steps: + return float(current_step) / float(max(1, num_warmup_steps)) + progress = float(current_step - num_warmup_steps) / float(max(1, num_training_steps - num_warmup_steps)) + return max(0.0, 0.5 * (1.0 + math.cos(math.pi * float(num_cycles) * 2.0 * progress))) + + return LambdaLR(optimizer, lr_lambda, last_epoch) + + +def get_cosine_with_hard_restarts_schedule_with_warmup( + optimizer: Optimizer, num_warmup_steps: int, num_training_steps: int, num_cycles: int = 1, last_epoch: int = -1 +): + """ + Create a schedule with a learning rate that decreases following the values of the cosine function between the + initial lr set in the optimizer to 0, with several hard restarts, after a warmup period during which it increases + linearly between 0 and the initial lr set in the optimizer. + + Args: + optimizer ([`~torch.optim.Optimizer`]): + The optimizer for which to schedule the learning rate. + num_warmup_steps (`int`): + The number of steps for the warmup phase. + num_training_steps (`int`): + The total number of training steps. + num_cycles (`int`, *optional*, defaults to 1): + The number of hard restarts to use. + last_epoch (`int`, *optional*, defaults to -1): + The index of the last epoch when resuming training. + + Return: + `torch.optim.lr_scheduler.LambdaLR` with the appropriate schedule. + """ + + def lr_lambda(current_step): + if current_step < num_warmup_steps: + return float(current_step) / float(max(1, num_warmup_steps)) + progress = float(current_step - num_warmup_steps) / float(max(1, num_training_steps - num_warmup_steps)) + if progress >= 1.0: + return 0.0 + return max(0.0, 0.5 * (1.0 + math.cos(math.pi * ((float(num_cycles) * progress) % 1.0)))) + + return LambdaLR(optimizer, lr_lambda, last_epoch) + + +def get_polynomial_decay_schedule_with_warmup( + optimizer, num_warmup_steps, num_training_steps, lr_end=1e-7, power=1.0, last_epoch=-1 +): + """ + Create a schedule with a learning rate that decreases as a polynomial decay from the initial lr set in the + optimizer to end lr defined by *lr_end*, after a warmup period during which it increases linearly from 0 to the + initial lr set in the optimizer. + + Args: + optimizer ([`~torch.optim.Optimizer`]): + The optimizer for which to schedule the learning rate. + num_warmup_steps (`int`): + The number of steps for the warmup phase. + num_training_steps (`int`): + The total number of training steps. + lr_end (`float`, *optional*, defaults to 1e-7): + The end LR. + power (`float`, *optional*, defaults to 1.0): + Power factor. + last_epoch (`int`, *optional*, defaults to -1): + The index of the last epoch when resuming training. + + Note: *power* defaults to 1.0 as in the fairseq implementation, which in turn is based on the original BERT + implementation at + https://github.com/google-research/bert/blob/f39e881b169b9d53bea03d2d341b31707a6c052b/optimization.py#L37 + + Return: + `torch.optim.lr_scheduler.LambdaLR` with the appropriate schedule. 
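+
+    Example (an illustrative sketch; the optimizer and step counts are assumptions):
+
+    ```py
+    >>> import torch
+
+    >>> model = torch.nn.Linear(4, 4)
+    >>> optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)
+    >>> lr_scheduler = get_polynomial_decay_schedule_with_warmup(
+    ...     optimizer, num_warmup_steps=100, num_training_steps=1000, lr_end=1e-6, power=2.0
+    ... )
+    >>> # `get_scheduler("polynomial", optimizer, num_warmup_steps=100, num_training_steps=1000)`
+    >>> # dispatches to this function with the default `lr_end` and `power`.
+    >>> for _ in range(3):
+    ...     optimizer.step()
+    ...     lr_scheduler.step()
+    ```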
+ + """ + + lr_init = optimizer.defaults["lr"] + if not (lr_init > lr_end): + raise ValueError(f"lr_end ({lr_end}) must be be smaller than initial lr ({lr_init})") + + def lr_lambda(current_step: int): + if current_step < num_warmup_steps: + return float(current_step) / float(max(1, num_warmup_steps)) + elif current_step > num_training_steps: + return lr_end / lr_init # as LambdaLR multiplies by lr_init + else: + lr_range = lr_init - lr_end + decay_steps = num_training_steps - num_warmup_steps + pct_remaining = 1 - (current_step - num_warmup_steps) / decay_steps + decay = lr_range * pct_remaining**power + lr_end + return decay / lr_init # as LambdaLR multiplies by lr_init + + return LambdaLR(optimizer, lr_lambda, last_epoch) + + +TYPE_TO_SCHEDULER_FUNCTION = { + SchedulerType.LINEAR: get_linear_schedule_with_warmup, + SchedulerType.COSINE: get_cosine_schedule_with_warmup, + SchedulerType.COSINE_WITH_RESTARTS: get_cosine_with_hard_restarts_schedule_with_warmup, + SchedulerType.POLYNOMIAL: get_polynomial_decay_schedule_with_warmup, + SchedulerType.CONSTANT: get_constant_schedule, + SchedulerType.CONSTANT_WITH_WARMUP: get_constant_schedule_with_warmup, +} + + +def get_scheduler( + name: Union[str, SchedulerType], + optimizer: Optimizer, + num_warmup_steps: Optional[int] = None, + num_training_steps: Optional[int] = None, +): + """ + Unified API to get any scheduler from its name. + + Args: + name (`str` or `SchedulerType`): + The name of the scheduler to use. + optimizer (`torch.optim.Optimizer`): + The optimizer that will be used during training. + num_warmup_steps (`int`, *optional*): + The number of warmup steps to do. This is not required by all schedulers (hence the argument being + optional), the function will raise an error if it's unset and the scheduler type requires it. + num_training_steps (`int``, *optional*): + The number of training steps to do. This is not required by all schedulers (hence the argument being + optional), the function will raise an error if it's unset and the scheduler type requires it. + """ + name = SchedulerType(name) + schedule_func = TYPE_TO_SCHEDULER_FUNCTION[name] + if name == SchedulerType.CONSTANT: + return schedule_func(optimizer) + + # All other schedulers require `num_warmup_steps` + if num_warmup_steps is None: + raise ValueError(f"{name} requires `num_warmup_steps`, please provide that argument.") + + if name == SchedulerType.CONSTANT_WITH_WARMUP: + return schedule_func(optimizer, num_warmup_steps=num_warmup_steps) + + # All other schedulers require `num_training_steps` + if num_training_steps is None: + raise ValueError(f"{name} requires `num_training_steps`, please provide that argument.") + + return schedule_func(optimizer, num_warmup_steps=num_warmup_steps, num_training_steps=num_training_steps) diff --git a/local_diffusers/pipeline_utils.py b/local_diffusers/pipeline_utils.py new file mode 100644 index 000000000..84ee9e20f --- /dev/null +++ b/local_diffusers/pipeline_utils.py @@ -0,0 +1,417 @@ +# coding=utf-8 +# Copyright 2022 The HuggingFace Inc. team. +# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import importlib +import inspect +import os +from dataclasses import dataclass +from typing import List, Optional, Union + +import numpy as np +import torch + +import diffusers +import PIL +from huggingface_hub import snapshot_download +from PIL import Image +from tqdm.auto import tqdm + +from .configuration_utils import ConfigMixin +from .utils import DIFFUSERS_CACHE, BaseOutput, logging + + +INDEX_FILE = "diffusion_pytorch_model.bin" + + +logger = logging.get_logger(__name__) + + +LOADABLE_CLASSES = { + "diffusers": { + "ModelMixin": ["save_pretrained", "from_pretrained"], + "SchedulerMixin": ["save_config", "from_config"], + "DiffusionPipeline": ["save_pretrained", "from_pretrained"], + "OnnxRuntimeModel": ["save_pretrained", "from_pretrained"], + }, + "transformers": { + "PreTrainedTokenizer": ["save_pretrained", "from_pretrained"], + "PreTrainedTokenizerFast": ["save_pretrained", "from_pretrained"], + "PreTrainedModel": ["save_pretrained", "from_pretrained"], + "FeatureExtractionMixin": ["save_pretrained", "from_pretrained"], + }, +} + +ALL_IMPORTABLE_CLASSES = {} +for library in LOADABLE_CLASSES: + ALL_IMPORTABLE_CLASSES.update(LOADABLE_CLASSES[library]) + + +@dataclass +class ImagePipelineOutput(BaseOutput): + """ + Output class for image pipelines. + + Args: + images (`List[PIL.Image.Image]` or `np.ndarray`) + List of denoised PIL images of length `batch_size` or numpy array of shape `(batch_size, height, width, + num_channels)`. PIL images or numpy array present the denoised images of the diffusion pipeline. + """ + + images: Union[List[PIL.Image.Image], np.ndarray] + + +class DiffusionPipeline(ConfigMixin): + r""" + Base class for all models. + + [`DiffusionPipeline`] takes care of storing all components (models, schedulers, processors) for diffusion pipelines + and handles methods for loading, downloading and saving models as well as a few methods common to all pipelines to: + + - move all PyTorch modules to the device of your choice + - enabling/disabling the progress bar for the denoising iteration + + Class attributes: + + - **config_name** ([`str`]) -- name of the config file that will store the class and module names of all + compenents of the diffusion pipeline. + """ + config_name = "model_index.json" + + def register_modules(self, **kwargs): + # import it here to avoid circular import + from diffusers import pipelines + + for name, module in kwargs.items(): + # retrive library + library = module.__module__.split(".")[0] + + # check if the module is a pipeline module + pipeline_dir = module.__module__.split(".")[-2] + path = module.__module__.split(".") + is_pipeline_module = pipeline_dir in path and hasattr(pipelines, pipeline_dir) + + # if library is not in LOADABLE_CLASSES, then it is a custom module. + # Or if it's a pipeline module, then the module is inside the pipeline + # folder so we set the library to module name. 
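+            # (Illustrative note: e.g. register_modules(unet=UNet2DConditionModel(...))
+            # ends up recording {"unet": ("diffusers", "UNet2DConditionModel")} in the
+            # model_index.json written by `save_config`, which `from_pretrained` reads back.)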
+ if library not in LOADABLE_CLASSES or is_pipeline_module: + library = pipeline_dir + + # retrive class_name + class_name = module.__class__.__name__ + + register_dict = {name: (library, class_name)} + + # save model index config + self.register_to_config(**register_dict) + + # set models + setattr(self, name, module) + + def save_pretrained(self, save_directory: Union[str, os.PathLike]): + """ + Save all variables of the pipeline that can be saved and loaded as well as the pipelines configuration file to + a directory. A pipeline variable can be saved and loaded if its class implements both a save and loading + method. The pipeline can easily be re-loaded using the `[`~DiffusionPipeline.from_pretrained`]` class method. + + Arguments: + save_directory (`str` or `os.PathLike`): + Directory to which to save. Will be created if it doesn't exist. + """ + self.save_config(save_directory) + + model_index_dict = dict(self.config) + model_index_dict.pop("_class_name") + model_index_dict.pop("_diffusers_version") + model_index_dict.pop("_module", None) + + for pipeline_component_name in model_index_dict.keys(): + sub_model = getattr(self, pipeline_component_name) + model_cls = sub_model.__class__ + + save_method_name = None + # search for the model's base class in LOADABLE_CLASSES + for library_name, library_classes in LOADABLE_CLASSES.items(): + library = importlib.import_module(library_name) + for base_class, save_load_methods in library_classes.items(): + class_candidate = getattr(library, base_class) + if issubclass(model_cls, class_candidate): + # if we found a suitable base class in LOADABLE_CLASSES then grab its save method + save_method_name = save_load_methods[0] + break + if save_method_name is not None: + break + + save_method = getattr(sub_model, save_method_name) + save_method(os.path.join(save_directory, pipeline_component_name)) + + def to(self, torch_device: Optional[Union[str, torch.device]] = None): + if torch_device is None: + return self + + module_names, _ = self.extract_init_dict(dict(self.config)) + for name in module_names.keys(): + module = getattr(self, name) + if isinstance(module, torch.nn.Module): + module.to(torch_device) + return self + + @property + def device(self) -> torch.device: + r""" + Returns: + `torch.device`: The torch device on which the pipeline is located. + """ + module_names, _ = self.extract_init_dict(dict(self.config)) + for name in module_names.keys(): + module = getattr(self, name) + if isinstance(module, torch.nn.Module): + return module.device + return torch.device("cpu") + + @classmethod + def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.PathLike]], **kwargs): + r""" + Instantiate a PyTorch diffusion pipeline from pre-trained pipeline weights. + + The pipeline is set in evaluation mode by default using `model.eval()` (Dropout modules are deactivated). + + The warning *Weights from XXX not initialized from pretrained model* means that the weights of XXX do not come + pretrained with the rest of the model. It is up to you to train those weights with a downstream fine-tuning + task. + + The warning *Weights from XXX not used in YYY* means that the layer XXX is not used by YYY, therefore those + weights are discarded. 
+ + Parameters: + pretrained_model_name_or_path (`str` or `os.PathLike`, *optional*): + Can be either: + + - A string, the *repo id* of a pretrained pipeline hosted inside a model repo on + https://huggingface.co/ Valid repo ids have to be located under a user or organization name, like + `CompVis/ldm-text2im-large-256`. + - A path to a *directory* containing pipeline weights saved using + [`~DiffusionPipeline.save_pretrained`], e.g., `./my_pipeline_directory/`. + torch_dtype (`str` or `torch.dtype`, *optional*): + Override the default `torch.dtype` and load the model under this dtype. If `"auto"` is passed the dtype + will be automatically derived from the model's weights. + force_download (`bool`, *optional*, defaults to `False`): + Whether or not to force the (re-)download of the model weights and configuration files, overriding the + cached versions if they exist. + resume_download (`bool`, *optional*, defaults to `False`): + Whether or not to delete incompletely received files. Will attempt to resume the download if such a + file exists. + proxies (`Dict[str, str]`, *optional*): + A dictionary of proxy servers to use by protocol or endpoint, e.g., `{'http': 'foo.bar:3128', + 'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request. + output_loading_info(`bool`, *optional*, defaults to `False`): + Whether ot not to also return a dictionary containing missing keys, unexpected keys and error messages. + local_files_only(`bool`, *optional*, defaults to `False`): + Whether or not to only look at local files (i.e., do not try to download the model). + use_auth_token (`str` or *bool*, *optional*): + The token to use as HTTP bearer authorization for remote files. If `True`, will use the token generated + when running `huggingface-cli login` (stored in `~/.huggingface`). + revision (`str`, *optional*, defaults to `"main"`): + The specific model version to use. It can be a branch name, a tag name, or a commit id, since we use a + git-based system for storing models and other artifacts on huggingface.co, so `revision` can be any + identifier allowed by git. + mirror (`str`, *optional*): + Mirror source to accelerate downloads in China. If you are from China and have an accessibility + problem, you can set this option to resolve it. Note that we do not guarantee the timeliness or safety. + Please refer to the mirror site for more information. specify the folder name here. + + kwargs (remaining dictionary of keyword arguments, *optional*): + Can be used to overwrite load - and saveable variables - *i.e.* the pipeline components - of the + speficic pipeline class. The overritten components are then directly passed to the pipelines `__init__` + method. See example below for more information. + + + + Passing `use_auth_token=True`` is required when you want to use a private model, *e.g.* + `"CompVis/stable-diffusion-v1-4"` + + + + + + Activate the special ["offline-mode"](https://huggingface.co/diffusers/installation.html#offline-mode) to use + this method in a firewalled environment. + + + + Examples: + + ```py + >>> from diffusers import DiffusionPipeline + + >>> # Download pipeline from huggingface.co and cache. 
+ >>> pipeline = DiffusionPipeline.from_pretrained("CompVis/ldm-text2im-large-256") + + >>> # Download pipeline that requires an authorization token + >>> # For more information on access tokens, please refer to this section + >>> # of the documentation](https://huggingface.co/docs/hub/security-tokens) + >>> pipeline = DiffusionPipeline.from_pretrained("CompVis/stable-diffusion-v1-4", use_auth_token=True) + + >>> # Download pipeline, but overwrite scheduler + >>> from diffusers import LMSDiscreteScheduler + + >>> scheduler = LMSDiscreteScheduler(beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear") + >>> pipeline = DiffusionPipeline.from_pretrained( + ... "CompVis/stable-diffusion-v1-4", scheduler=scheduler, use_auth_token=True + ... ) + ``` + """ + cache_dir = kwargs.pop("cache_dir", DIFFUSERS_CACHE) + resume_download = kwargs.pop("resume_download", False) + proxies = kwargs.pop("proxies", None) + local_files_only = kwargs.pop("local_files_only", False) + use_auth_token = kwargs.pop("use_auth_token", None) + revision = kwargs.pop("revision", None) + torch_dtype = kwargs.pop("torch_dtype", None) + provider = kwargs.pop("provider", None) + + # 1. Download the checkpoints and configs + # use snapshot download here to get it working from from_pretrained + if not os.path.isdir(pretrained_model_name_or_path): + cached_folder = snapshot_download( + pretrained_model_name_or_path, + cache_dir=cache_dir, + resume_download=resume_download, + proxies=proxies, + local_files_only=local_files_only, + use_auth_token=use_auth_token, + revision=revision, + ) + else: + cached_folder = pretrained_model_name_or_path + + config_dict = cls.get_config_dict(cached_folder) + + # 2. Load the pipeline class, if using custom module then load it from the hub + # if we load from explicit class, let's use it + if cls != DiffusionPipeline: + pipeline_class = cls + else: + diffusers_module = importlib.import_module(cls.__module__.split(".")[0]) + pipeline_class = getattr(diffusers_module, config_dict["_class_name"]) + + # some modules can be passed directly to the init + # in this case they are already instantiated in `kwargs` + # extract them here + expected_modules = set(inspect.signature(pipeline_class.__init__).parameters.keys()) + passed_class_obj = {k: kwargs.pop(k) for k in expected_modules if k in kwargs} + + init_dict, _ = pipeline_class.extract_init_dict(config_dict, **kwargs) + + init_kwargs = {} + + # import it here to avoid circular import + from diffusers import pipelines + + # 3. Load each module in the pipeline + for name, (library_name, class_name) in init_dict.items(): + is_pipeline_module = hasattr(pipelines, library_name) + loaded_sub_model = None + + # if the model is in a pipeline module, then we load it from the pipeline + if name in passed_class_obj: + # 1. 
check that passed_class_obj has correct parent class + if not is_pipeline_module: + library = importlib.import_module(library_name) + class_obj = getattr(library, class_name) + importable_classes = LOADABLE_CLASSES[library_name] + class_candidates = {c: getattr(library, c) for c in importable_classes.keys()} + + expected_class_obj = None + for class_name, class_candidate in class_candidates.items(): + if issubclass(class_obj, class_candidate): + expected_class_obj = class_candidate + + if not issubclass(passed_class_obj[name].__class__, expected_class_obj): + raise ValueError( + f"{passed_class_obj[name]} is of type: {type(passed_class_obj[name])}, but should be" + f" {expected_class_obj}" + ) + else: + logger.warn( + f"You have passed a non-standard module {passed_class_obj[name]}. We cannot verify whether it" + " has the correct type" + ) + + # set passed class object + loaded_sub_model = passed_class_obj[name] + elif is_pipeline_module: + pipeline_module = getattr(pipelines, library_name) + class_obj = getattr(pipeline_module, class_name) + importable_classes = ALL_IMPORTABLE_CLASSES + class_candidates = {c: class_obj for c in importable_classes.keys()} + else: + # else we just import it from the library. + library = importlib.import_module(library_name) + class_obj = getattr(library, class_name) + importable_classes = LOADABLE_CLASSES[library_name] + class_candidates = {c: getattr(library, c) for c in importable_classes.keys()} + + if loaded_sub_model is None: + load_method_name = None + for class_name, class_candidate in class_candidates.items(): + if issubclass(class_obj, class_candidate): + load_method_name = importable_classes[class_name][1] + + load_method = getattr(class_obj, load_method_name) + + loading_kwargs = {} + if issubclass(class_obj, torch.nn.Module): + loading_kwargs["torch_dtype"] = torch_dtype + if issubclass(class_obj, diffusers.OnnxRuntimeModel): + loading_kwargs["provider"] = provider + + # check if the module is in a subdirectory + if os.path.isdir(os.path.join(cached_folder, name)): + loaded_sub_model = load_method(os.path.join(cached_folder, name), **loading_kwargs) + else: + # else load from the root directory + loaded_sub_model = load_method(cached_folder, **loading_kwargs) + + init_kwargs[name] = loaded_sub_model # UNet(...), # DiffusionSchedule(...) + + # 4. Instantiate the pipeline + model = pipeline_class(**init_kwargs) + return model + + @staticmethod + def numpy_to_pil(images): + """ + Convert a numpy image or a batch of images to a PIL image. + """ + if images.ndim == 3: + images = images[None, ...] + images = (images * 255).round().astype("uint8") + pil_images = [Image.fromarray(image) for image in images] + + return pil_images + + def progress_bar(self, iterable): + if not hasattr(self, "_progress_bar_config"): + self._progress_bar_config = {} + elif not isinstance(self._progress_bar_config, dict): + raise ValueError( + f"`self._progress_bar_config` should be of type `dict`, but is {type(self._progress_bar_config)}." 
+ ) + + return tqdm(iterable, **self._progress_bar_config) + + def set_progress_bar_config(self, **kwargs): + self._progress_bar_config = kwargs diff --git a/local_diffusers/pipelines/__init__.py b/local_diffusers/pipelines/__init__.py new file mode 100644 index 000000000..3e2aeb4fb --- /dev/null +++ b/local_diffusers/pipelines/__init__.py @@ -0,0 +1,19 @@ +from ..utils import is_onnx_available, is_transformers_available +from .ddim import DDIMPipeline +from .ddpm import DDPMPipeline +from .latent_diffusion_uncond import LDMPipeline +from .pndm import PNDMPipeline +from .score_sde_ve import ScoreSdeVePipeline +from .stochastic_karras_ve import KarrasVePipeline + + +if is_transformers_available(): + from .latent_diffusion import LDMTextToImagePipeline + from .stable_diffusion import ( + StableDiffusionImg2ImgPipeline, + StableDiffusionInpaintPipeline, + StableDiffusionPipeline, + ) + +if is_transformers_available() and is_onnx_available(): + from .stable_diffusion import StableDiffusionOnnxPipeline diff --git a/local_diffusers/pipelines/ddim/__init__.py b/local_diffusers/pipelines/ddim/__init__.py new file mode 100644 index 000000000..8fd31868a --- /dev/null +++ b/local_diffusers/pipelines/ddim/__init__.py @@ -0,0 +1,2 @@ +# flake8: noqa +from .pipeline_ddim import DDIMPipeline diff --git a/local_diffusers/pipelines/ddim/pipeline_ddim.py b/local_diffusers/pipelines/ddim/pipeline_ddim.py new file mode 100644 index 000000000..33f6064db --- /dev/null +++ b/local_diffusers/pipelines/ddim/pipeline_ddim.py @@ -0,0 +1,117 @@ +# Copyright 2022 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and + +# limitations under the License. + + +import warnings +from typing import Optional, Tuple, Union + +import torch + +from ...pipeline_utils import DiffusionPipeline, ImagePipelineOutput + + +class DDIMPipeline(DiffusionPipeline): + r""" + This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the + library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.) + + Parameters: + unet ([`UNet2DModel`]): U-Net architecture to denoise the encoded image. + scheduler ([`SchedulerMixin`]): + A scheduler to be used in combination with `unet` to denoise the encoded image. Can be one of + [`DDPMScheduler`], or [`DDIMScheduler`]. + """ + + def __init__(self, unet, scheduler): + super().__init__() + scheduler = scheduler.set_format("pt") + self.register_modules(unet=unet, scheduler=scheduler) + + @torch.no_grad() + def __call__( + self, + batch_size: int = 1, + generator: Optional[torch.Generator] = None, + eta: float = 0.0, + num_inference_steps: int = 50, + output_type: Optional[str] = "pil", + return_dict: bool = True, + **kwargs, + ) -> Union[ImagePipelineOutput, Tuple]: + r""" + Args: + batch_size (`int`, *optional*, defaults to 1): + The number of images to generate. 
+ generator (`torch.Generator`, *optional*): + A [torch generator](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make generation + deterministic. + eta (`float`, *optional*, defaults to 0.0): + The eta parameter which controls the scale of the variance (0 is DDIM and 1 is one type of DDPM). + num_inference_steps (`int`, *optional*, defaults to 50): + The number of denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. + output_type (`str`, *optional*, defaults to `"pil"`): + The output format of the generate image. Choose between + [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `nd.array`. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~pipeline_utils.ImagePipelineOutput`] instead of a plain tuple. + + Returns: + [`~pipeline_utils.ImagePipelineOutput`] or `tuple`: [`~pipelines.utils.ImagePipelineOutput`] if + `return_dict` is True, otherwise a `tuple. When returning a tuple, the first element is a list with the + generated images. + """ + + if "torch_device" in kwargs: + device = kwargs.pop("torch_device") + warnings.warn( + "`torch_device` is deprecated as an input argument to `__call__` and will be removed in v0.3.0." + " Consider using `pipe.to(torch_device)` instead." + ) + + # Set device as before (to be removed in 0.3.0) + if device is None: + device = "cuda" if torch.cuda.is_available() else "cpu" + self.to(device) + + # eta corresponds to η in paper and should be between [0, 1] + + # Sample gaussian noise to begin loop + image = torch.randn( + (batch_size, self.unet.in_channels, self.unet.sample_size, self.unet.sample_size), + generator=generator, + ) + image = image.to(self.device) + + # set step values + self.scheduler.set_timesteps(num_inference_steps) + + for t in self.progress_bar(self.scheduler.timesteps): + # 1. predict noise model_output + model_output = self.unet(image, t).sample + + # 2. predict previous mean of image x_t-1 and add variance depending on eta + # do x_t -> x_t-1 + image = self.scheduler.step(model_output, t, image, eta).prev_sample + + image = (image / 2 + 0.5).clamp(0, 1) + image = image.cpu().permute(0, 2, 3, 1).numpy() + if output_type == "pil": + image = self.numpy_to_pil(image) + + if not return_dict: + return (image,) + + return ImagePipelineOutput(images=image) diff --git a/local_diffusers/pipelines/ddpm/__init__.py b/local_diffusers/pipelines/ddpm/__init__.py new file mode 100644 index 000000000..8889bdae1 --- /dev/null +++ b/local_diffusers/pipelines/ddpm/__init__.py @@ -0,0 +1,2 @@ +# flake8: noqa +from .pipeline_ddpm import DDPMPipeline diff --git a/local_diffusers/pipelines/ddpm/pipeline_ddpm.py b/local_diffusers/pipelines/ddpm/pipeline_ddpm.py new file mode 100644 index 000000000..71103bbe4 --- /dev/null +++ b/local_diffusers/pipelines/ddpm/pipeline_ddpm.py @@ -0,0 +1,106 @@ +# Copyright 2022 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and + +# limitations under the License. 
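+# ---------------------------------------------------------------------------
+# Illustrative sketch (not part of the upstream code): sampling with the
+# DDIMPipeline added above in ../ddim/pipeline_ddim.py. The checkpoint id
+# "some-org/ddim-unconditional-64" is a placeholder assumption; any repo with a
+# compatible `unet` and `scheduler` should work.
+def _example_ddim_sampling():
+    import torch
+
+    from local_diffusers.pipelines.ddim import DDIMPipeline
+
+    pipe = DDIMPipeline.from_pretrained("some-org/ddim-unconditional-64")
+    generator = torch.Generator().manual_seed(0)
+    # eta=0.0 keeps DDIM sampling deterministic; it usually needs far fewer
+    # steps than the 1000 hard-coded below in DDPMPipeline.
+    output = pipe(batch_size=1, generator=generator, eta=0.0, num_inference_steps=50)
+    output.images[0].save("ddim_sample.png")
+# ---------------------------------------------------------------------------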
+ + +import warnings +from typing import Optional, Tuple, Union + +import torch + +from ...pipeline_utils import DiffusionPipeline, ImagePipelineOutput + + +class DDPMPipeline(DiffusionPipeline): + r""" + This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the + library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.) + + Parameters: + unet ([`UNet2DModel`]): U-Net architecture to denoise the encoded image. + scheduler ([`SchedulerMixin`]): + A scheduler to be used in combination with `unet` to denoise the encoded image. Can be one of + [`DDPMScheduler`], or [`DDIMScheduler`]. + """ + + def __init__(self, unet, scheduler): + super().__init__() + scheduler = scheduler.set_format("pt") + self.register_modules(unet=unet, scheduler=scheduler) + + @torch.no_grad() + def __call__( + self, + batch_size: int = 1, + generator: Optional[torch.Generator] = None, + output_type: Optional[str] = "pil", + return_dict: bool = True, + **kwargs, + ) -> Union[ImagePipelineOutput, Tuple]: + r""" + Args: + batch_size (`int`, *optional*, defaults to 1): + The number of images to generate. + generator (`torch.Generator`, *optional*): + A [torch generator](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make generation + deterministic. + output_type (`str`, *optional*, defaults to `"pil"`): + The output format of the generate image. Choose between + [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `nd.array`. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~pipeline_utils.ImagePipelineOutput`] instead of a plain tuple. + + Returns: + [`~pipeline_utils.ImagePipelineOutput`] or `tuple`: [`~pipelines.utils.ImagePipelineOutput`] if + `return_dict` is True, otherwise a `tuple. When returning a tuple, the first element is a list with the + generated images. + """ + if "torch_device" in kwargs: + device = kwargs.pop("torch_device") + warnings.warn( + "`torch_device` is deprecated as an input argument to `__call__` and will be removed in v0.3.0." + " Consider using `pipe.to(torch_device)` instead." + ) + + # Set device as before (to be removed in 0.3.0) + if device is None: + device = "cuda" if torch.cuda.is_available() else "cpu" + self.to(device) + + # Sample gaussian noise to begin loop + image = torch.randn( + (batch_size, self.unet.in_channels, self.unet.sample_size, self.unet.sample_size), + generator=generator, + ) + image = image.to(self.device) + + # set step values + self.scheduler.set_timesteps(1000) + + for t in self.progress_bar(self.scheduler.timesteps): + # 1. predict noise model_output + model_output = self.unet(image, t).sample + + # 2. 
compute previous image: x_t -> t_t-1 + image = self.scheduler.step(model_output, t, image, generator=generator).prev_sample + + image = (image / 2 + 0.5).clamp(0, 1) + image = image.cpu().permute(0, 2, 3, 1).numpy() + if output_type == "pil": + image = self.numpy_to_pil(image) + + if not return_dict: + return (image,) + + return ImagePipelineOutput(images=image) diff --git a/local_diffusers/pipelines/latent_diffusion/__init__.py b/local_diffusers/pipelines/latent_diffusion/__init__.py new file mode 100644 index 000000000..c481b38cf --- /dev/null +++ b/local_diffusers/pipelines/latent_diffusion/__init__.py @@ -0,0 +1,6 @@ +# flake8: noqa +from ...utils import is_transformers_available + + +if is_transformers_available(): + from .pipeline_latent_diffusion import LDMBertModel, LDMTextToImagePipeline diff --git a/local_diffusers/pipelines/latent_diffusion/pipeline_latent_diffusion.py b/local_diffusers/pipelines/latent_diffusion/pipeline_latent_diffusion.py new file mode 100644 index 000000000..b39840f24 --- /dev/null +++ b/local_diffusers/pipelines/latent_diffusion/pipeline_latent_diffusion.py @@ -0,0 +1,705 @@ +import inspect +import warnings +from typing import List, Optional, Tuple, Union + +import torch +import torch.nn as nn +import torch.utils.checkpoint + +from transformers.activations import ACT2FN +from transformers.configuration_utils import PretrainedConfig +from transformers.modeling_outputs import BaseModelOutput +from transformers.modeling_utils import PreTrainedModel +from transformers.tokenization_utils import PreTrainedTokenizer +from transformers.utils import logging + +from ...models import AutoencoderKL, UNet2DConditionModel, UNet2DModel, VQModel +from ...pipeline_utils import DiffusionPipeline, ImagePipelineOutput +from ...schedulers import DDIMScheduler, LMSDiscreteScheduler, PNDMScheduler + + +class LDMTextToImagePipeline(DiffusionPipeline): + r""" + This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the + library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.) + + Parameters: + vqvae ([`VQModel`]): + Vector-quantized (VQ) Model to encode and decode images to and from latent representations. + bert ([`LDMBertModel`]): + Text-encoder model based on [BERT](ttps://huggingface.co/docs/transformers/model_doc/bert) architecture. + tokenizer (`transformers.BertTokenizer`): + Tokenizer of class + [BertTokenizer](https://huggingface.co/docs/transformers/model_doc/bert#transformers.BertTokenizer). + unet ([`UNet2DConditionModel`]): Conditional U-Net architecture to denoise the encoded image latents. + scheduler ([`SchedulerMixin`]): + A scheduler to be used in combination with `unet` to denoise the encoded image latens. Can be one of + [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`]. 
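+
+    Example (an illustrative sketch; the checkpoint id is the one referenced in the
+    `DiffusionPipeline.from_pretrained` docstring above):
+
+    ```py
+    >>> from local_diffusers.pipelines.latent_diffusion import LDMTextToImagePipeline
+
+    >>> pipe = LDMTextToImagePipeline.from_pretrained("CompVis/ldm-text2im-large-256")
+    >>> # guidance_scale > 1.0 enables classifier-free guidance (two UNet passes per step).
+    >>> image = pipe("a painting of a red panda", guidance_scale=6.0, num_inference_steps=50).images[0]
+    >>> image.save("ldm_red_panda.png")
+    ```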
+ """ + + def __init__( + self, + vqvae: Union[VQModel, AutoencoderKL], + bert: PreTrainedModel, + tokenizer: PreTrainedTokenizer, + unet: Union[UNet2DModel, UNet2DConditionModel], + scheduler: Union[DDIMScheduler, PNDMScheduler, LMSDiscreteScheduler], + ): + super().__init__() + scheduler = scheduler.set_format("pt") + self.register_modules(vqvae=vqvae, bert=bert, tokenizer=tokenizer, unet=unet, scheduler=scheduler) + + @torch.no_grad() + def __call__( + self, + prompt: Union[str, List[str]], + height: Optional[int] = 256, + width: Optional[int] = 256, + num_inference_steps: Optional[int] = 50, + guidance_scale: Optional[float] = 1.0, + eta: Optional[float] = 0.0, + generator: Optional[torch.Generator] = None, + output_type: Optional[str] = "pil", + return_dict: bool = True, + **kwargs, + ) -> Union[Tuple, ImagePipelineOutput]: + r""" + Args: + prompt (`str` or `List[str]`): + The prompt or prompts to guide the image generation. + height (`int`, *optional*, defaults to 256): + The height in pixels of the generated image. + width (`int`, *optional*, defaults to 256): + The width in pixels of the generated image. + num_inference_steps (`int`, *optional*, defaults to 50): + The number of denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. + guidance_scale (`float`, *optional*, defaults to 1.0): + Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598). + `guidance_scale` is defined as `w` of equation 2. of [Imagen + Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale > + 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt` at + the, usually at the expense of lower image quality. + generator (`torch.Generator`, *optional*): + A [torch generator](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make generation + deterministic. + output_type (`str`, *optional*, defaults to `"pil"`): + The output format of the generate image. Choose between + [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `nd.array`. + return_dict (`bool`, *optional*): + Whether or not to return a [`~pipeline_utils.ImagePipelineOutput`] instead of a plain tuple. + + Returns: + [`~pipeline_utils.ImagePipelineOutput`] or `tuple`: [`~pipelines.utils.ImagePipelineOutput`] if + `return_dict` is True, otherwise a `tuple. When returning a tuple, the first element is a list with the + generated images. + """ + if "torch_device" in kwargs: + device = kwargs.pop("torch_device") + warnings.warn( + "`torch_device` is deprecated as an input argument to `__call__` and will be removed in v0.3.0." + " Consider using `pipe.to(torch_device)` instead." 
+ ) + + # Set device as before (to be removed in 0.3.0) + if device is None: + device = "cuda" if torch.cuda.is_available() else "cpu" + self.to(device) + + if isinstance(prompt, str): + batch_size = 1 + elif isinstance(prompt, list): + batch_size = len(prompt) + else: + raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") + + if height % 8 != 0 or width % 8 != 0: + raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.") + + # get unconditional embeddings for classifier free guidance + if guidance_scale != 1.0: + uncond_input = self.tokenizer([""] * batch_size, padding="max_length", max_length=77, return_tensors="pt") + uncond_embeddings = self.bert(uncond_input.input_ids.to(self.device))[0] + + # get prompt text embeddings + text_input = self.tokenizer(prompt, padding="max_length", max_length=77, return_tensors="pt") + text_embeddings = self.bert(text_input.input_ids.to(self.device))[0] + + latents = torch.randn( + (batch_size, self.unet.in_channels, height // 8, width // 8), + generator=generator, + ) + latents = latents.to(self.device) + + self.scheduler.set_timesteps(num_inference_steps) + + # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature + accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys()) + + extra_kwargs = {} + if accepts_eta: + extra_kwargs["eta"] = eta + + for t in self.progress_bar(self.scheduler.timesteps): + if guidance_scale == 1.0: + # guidance_scale of 1 means no guidance + latents_input = latents + context = text_embeddings + else: + # For classifier free guidance, we need to do two forward passes. + # Here we concatenate the unconditional and text embeddings into a single batch + # to avoid doing two forward passes + latents_input = torch.cat([latents] * 2) + context = torch.cat([uncond_embeddings, text_embeddings]) + + # predict the noise residual + noise_pred = self.unet(latents_input, t, encoder_hidden_states=context).sample + # perform guidance + if guidance_scale != 1.0: + noise_pred_uncond, noise_prediction_text = noise_pred.chunk(2) + noise_pred = noise_pred_uncond + guidance_scale * (noise_prediction_text - noise_pred_uncond) + + # compute the previous noisy sample x_t -> x_t-1 + latents = self.scheduler.step(noise_pred, t, latents, **extra_kwargs).prev_sample + + # scale and decode the image latents with vae + latents = 1 / 0.18215 * latents + image = self.vqvae.decode(latents).sample + + image = (image / 2 + 0.5).clamp(0, 1) + image = image.cpu().permute(0, 2, 3, 1).numpy() + if output_type == "pil": + image = self.numpy_to_pil(image) + + if not return_dict: + return (image,) + + return ImagePipelineOutput(images=image) + + +################################################################################ +# Code for the text transformer model +################################################################################ +""" PyTorch LDMBERT model.""" + + +logger = logging.get_logger(__name__) + +LDMBERT_PRETRAINED_MODEL_ARCHIVE_LIST = [ + "ldm-bert", + # See all LDMBert models at https://huggingface.co/models?filter=ldmbert +] + + +LDMBERT_PRETRAINED_CONFIG_ARCHIVE_MAP = { + "ldm-bert": "https://huggingface.co/ldm-bert/resolve/main/config.json", +} + + +""" LDMBERT model configuration""" + + +class LDMBertConfig(PretrainedConfig): + model_type = "ldmbert" + keys_to_ignore_at_inference = ["past_key_values"] + attribute_map = {"num_attention_heads": "encoder_attention_heads", "hidden_size": "d_model"} + + 
def __init__( + self, + vocab_size=30522, + max_position_embeddings=77, + encoder_layers=32, + encoder_ffn_dim=5120, + encoder_attention_heads=8, + head_dim=64, + encoder_layerdrop=0.0, + activation_function="gelu", + d_model=1280, + dropout=0.1, + attention_dropout=0.0, + activation_dropout=0.0, + init_std=0.02, + classifier_dropout=0.0, + scale_embedding=False, + use_cache=True, + pad_token_id=0, + **kwargs, + ): + self.vocab_size = vocab_size + self.max_position_embeddings = max_position_embeddings + self.d_model = d_model + self.encoder_ffn_dim = encoder_ffn_dim + self.encoder_layers = encoder_layers + self.encoder_attention_heads = encoder_attention_heads + self.head_dim = head_dim + self.dropout = dropout + self.attention_dropout = attention_dropout + self.activation_dropout = activation_dropout + self.activation_function = activation_function + self.init_std = init_std + self.encoder_layerdrop = encoder_layerdrop + self.classifier_dropout = classifier_dropout + self.use_cache = use_cache + self.num_hidden_layers = encoder_layers + self.scale_embedding = scale_embedding # scale factor will be sqrt(d_model) if True + + super().__init__(pad_token_id=pad_token_id, **kwargs) + + +def _expand_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int] = None): + """ + Expands attention_mask from `[bsz, seq_len]` to `[bsz, 1, tgt_seq_len, src_seq_len]`. + """ + bsz, src_len = mask.size() + tgt_len = tgt_len if tgt_len is not None else src_len + + expanded_mask = mask[:, None, None, :].expand(bsz, 1, tgt_len, src_len).to(dtype) + + inverted_mask = 1.0 - expanded_mask + + return inverted_mask.masked_fill(inverted_mask.to(torch.bool), torch.finfo(dtype).min) + + +# Copied from transformers.models.bart.modeling_bart.BartAttention with Bart->LDMBert +class LDMBertAttention(nn.Module): + """Multi-headed attention from 'Attention Is All You Need' paper""" + + def __init__( + self, + embed_dim: int, + num_heads: int, + head_dim: int, + dropout: float = 0.0, + is_decoder: bool = False, + bias: bool = False, + ): + super().__init__() + self.embed_dim = embed_dim + self.num_heads = num_heads + self.dropout = dropout + self.head_dim = head_dim + self.inner_dim = head_dim * num_heads + + self.scaling = self.head_dim**-0.5 + self.is_decoder = is_decoder + + self.k_proj = nn.Linear(embed_dim, self.inner_dim, bias=bias) + self.v_proj = nn.Linear(embed_dim, self.inner_dim, bias=bias) + self.q_proj = nn.Linear(embed_dim, self.inner_dim, bias=bias) + self.out_proj = nn.Linear(self.inner_dim, embed_dim) + + def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int): + return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous() + + def forward( + self, + hidden_states: torch.Tensor, + key_value_states: Optional[torch.Tensor] = None, + past_key_value: Optional[Tuple[torch.Tensor]] = None, + attention_mask: Optional[torch.Tensor] = None, + layer_head_mask: Optional[torch.Tensor] = None, + output_attentions: bool = False, + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + """Input shape: Batch x Time x Channel""" + + # if key_value_states are provided this layer is used as a cross-attention layer + # for the decoder + is_cross_attention = key_value_states is not None + + bsz, tgt_len, _ = hidden_states.size() + + # get query proj + query_states = self.q_proj(hidden_states) * self.scaling + # get key, value proj + if is_cross_attention and past_key_value is not None: + # reuse k,v, cross_attentions + key_states = past_key_value[0] + 
value_states = past_key_value[1] + elif is_cross_attention: + # cross_attentions + key_states = self._shape(self.k_proj(key_value_states), -1, bsz) + value_states = self._shape(self.v_proj(key_value_states), -1, bsz) + elif past_key_value is not None: + # reuse k, v, self_attention + key_states = self._shape(self.k_proj(hidden_states), -1, bsz) + value_states = self._shape(self.v_proj(hidden_states), -1, bsz) + key_states = torch.cat([past_key_value[0], key_states], dim=2) + value_states = torch.cat([past_key_value[1], value_states], dim=2) + else: + # self_attention + key_states = self._shape(self.k_proj(hidden_states), -1, bsz) + value_states = self._shape(self.v_proj(hidden_states), -1, bsz) + + if self.is_decoder: + # if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states. + # Further calls to cross_attention layer can then reuse all cross-attention + # key/value_states (first "if" case) + # if uni-directional self-attention (decoder) save Tuple(torch.Tensor, torch.Tensor) of + # all previous decoder key/value_states. Further calls to uni-directional self-attention + # can concat previous decoder key/value_states to current projected key/value_states (third "elif" case) + # if encoder bi-directional self-attention `past_key_value` is always `None` + past_key_value = (key_states, value_states) + + proj_shape = (bsz * self.num_heads, -1, self.head_dim) + query_states = self._shape(query_states, tgt_len, bsz).view(*proj_shape) + key_states = key_states.view(*proj_shape) + value_states = value_states.view(*proj_shape) + + src_len = key_states.size(1) + attn_weights = torch.bmm(query_states, key_states.transpose(1, 2)) + + if attn_weights.size() != (bsz * self.num_heads, tgt_len, src_len): + raise ValueError( + f"Attention weights should be of size {(bsz * self.num_heads, tgt_len, src_len)}, but is" + f" {attn_weights.size()}" + ) + + if attention_mask is not None: + if attention_mask.size() != (bsz, 1, tgt_len, src_len): + raise ValueError( + f"Attention mask should be of size {(bsz, 1, tgt_len, src_len)}, but is {attention_mask.size()}" + ) + attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len) + attention_mask + attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len) + + attn_weights = nn.functional.softmax(attn_weights, dim=-1) + + if layer_head_mask is not None: + if layer_head_mask.size() != (self.num_heads,): + raise ValueError( + f"Head mask for a single layer should be of size {(self.num_heads,)}, but is" + f" {layer_head_mask.size()}" + ) + attn_weights = layer_head_mask.view(1, -1, 1, 1) * attn_weights.view(bsz, self.num_heads, tgt_len, src_len) + attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len) + + if output_attentions: + # this operation is a bit awkward, but it's required to + # make sure that attn_weights keeps its gradient. 
+ # In order to do so, attn_weights have to be reshaped + # twice and have to be reused in the following + attn_weights_reshaped = attn_weights.view(bsz, self.num_heads, tgt_len, src_len) + attn_weights = attn_weights_reshaped.view(bsz * self.num_heads, tgt_len, src_len) + else: + attn_weights_reshaped = None + + attn_probs = nn.functional.dropout(attn_weights, p=self.dropout, training=self.training) + + attn_output = torch.bmm(attn_probs, value_states) + + if attn_output.size() != (bsz * self.num_heads, tgt_len, self.head_dim): + raise ValueError( + f"`attn_output` should be of size {(bsz, self.num_heads, tgt_len, self.head_dim)}, but is" + f" {attn_output.size()}" + ) + + attn_output = attn_output.view(bsz, self.num_heads, tgt_len, self.head_dim) + attn_output = attn_output.transpose(1, 2) + + # Use the `embed_dim` from the config (stored in the class) rather than `hidden_state` because `attn_output` can be + # partitioned aross GPUs when using tensor-parallelism. + attn_output = attn_output.reshape(bsz, tgt_len, self.inner_dim) + + attn_output = self.out_proj(attn_output) + + return attn_output, attn_weights_reshaped, past_key_value + + +class LDMBertEncoderLayer(nn.Module): + def __init__(self, config: LDMBertConfig): + super().__init__() + self.embed_dim = config.d_model + self.self_attn = LDMBertAttention( + embed_dim=self.embed_dim, + num_heads=config.encoder_attention_heads, + head_dim=config.head_dim, + dropout=config.attention_dropout, + ) + self.self_attn_layer_norm = nn.LayerNorm(self.embed_dim) + self.dropout = config.dropout + self.activation_fn = ACT2FN[config.activation_function] + self.activation_dropout = config.activation_dropout + self.fc1 = nn.Linear(self.embed_dim, config.encoder_ffn_dim) + self.fc2 = nn.Linear(config.encoder_ffn_dim, self.embed_dim) + self.final_layer_norm = nn.LayerNorm(self.embed_dim) + + def forward( + self, + hidden_states: torch.FloatTensor, + attention_mask: torch.FloatTensor, + layer_head_mask: torch.FloatTensor, + output_attentions: Optional[bool] = False, + ) -> Tuple[torch.FloatTensor, Optional[torch.FloatTensor]]: + """ + Args: + hidden_states (`torch.FloatTensor`): input to the layer of shape `(seq_len, batch, embed_dim)` + attention_mask (`torch.FloatTensor`): attention mask of size + `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values. + layer_head_mask (`torch.FloatTensor`): mask for attention heads in a given layer of size + `(encoder_attention_heads,)`. + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under + returned tensors for more detail. 
+ """ + residual = hidden_states + hidden_states = self.self_attn_layer_norm(hidden_states) + hidden_states, attn_weights, _ = self.self_attn( + hidden_states=hidden_states, + attention_mask=attention_mask, + layer_head_mask=layer_head_mask, + output_attentions=output_attentions, + ) + hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training) + hidden_states = residual + hidden_states + + residual = hidden_states + hidden_states = self.final_layer_norm(hidden_states) + hidden_states = self.activation_fn(self.fc1(hidden_states)) + hidden_states = nn.functional.dropout(hidden_states, p=self.activation_dropout, training=self.training) + hidden_states = self.fc2(hidden_states) + hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training) + hidden_states = residual + hidden_states + + if hidden_states.dtype == torch.float16 and ( + torch.isinf(hidden_states).any() or torch.isnan(hidden_states).any() + ): + clamp_value = torch.finfo(hidden_states.dtype).max - 1000 + hidden_states = torch.clamp(hidden_states, min=-clamp_value, max=clamp_value) + + outputs = (hidden_states,) + + if output_attentions: + outputs += (attn_weights,) + + return outputs + + +# Copied from transformers.models.bart.modeling_bart.BartPretrainedModel with Bart->LDMBert +class LDMBertPreTrainedModel(PreTrainedModel): + config_class = LDMBertConfig + base_model_prefix = "model" + supports_gradient_checkpointing = True + _keys_to_ignore_on_load_unexpected = [r"encoder\.version", r"decoder\.version"] + + def _init_weights(self, module): + std = self.config.init_std + if isinstance(module, nn.Linear): + module.weight.data.normal_(mean=0.0, std=std) + if module.bias is not None: + module.bias.data.zero_() + elif isinstance(module, nn.Embedding): + module.weight.data.normal_(mean=0.0, std=std) + if module.padding_idx is not None: + module.weight.data[module.padding_idx].zero_() + + def _set_gradient_checkpointing(self, module, value=False): + if isinstance(module, (LDMBertEncoder,)): + module.gradient_checkpointing = value + + @property + def dummy_inputs(self): + pad_token = self.config.pad_token_id + input_ids = torch.tensor([[0, 6, 10, 4, 2], [0, 8, 12, 2, pad_token]], device=self.device) + dummy_inputs = { + "attention_mask": input_ids.ne(pad_token), + "input_ids": input_ids, + } + return dummy_inputs + + +class LDMBertEncoder(LDMBertPreTrainedModel): + """ + Transformer encoder consisting of *config.encoder_layers* self attention layers. Each layer is a + [`LDMBertEncoderLayer`]. 
+ + Args: + config: LDMBertConfig + embed_tokens (nn.Embedding): output embedding + """ + + def __init__(self, config: LDMBertConfig): + super().__init__(config) + + self.dropout = config.dropout + + embed_dim = config.d_model + self.padding_idx = config.pad_token_id + self.max_source_positions = config.max_position_embeddings + + self.embed_tokens = nn.Embedding(config.vocab_size, embed_dim) + self.embed_positions = nn.Embedding(config.max_position_embeddings, embed_dim) + self.layers = nn.ModuleList([LDMBertEncoderLayer(config) for _ in range(config.encoder_layers)]) + self.layer_norm = nn.LayerNorm(embed_dim) + + self.gradient_checkpointing = False + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self): + return self.embed_tokens + + def set_input_embeddings(self, value): + self.embed_tokens = value + + def forward( + self, + input_ids: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + head_mask: Optional[torch.Tensor] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple, BaseModelOutput]: + r""" + Args: + input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`): + Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you + provide it. + + Indices can be obtained using [`BartTokenizer`]. See [`PreTrainedTokenizer.encode`] and + [`PreTrainedTokenizer.__call__`] for details. + + [What are input IDs?](../glossary#input-ids) + attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*): + Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: + + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + + [What are attention masks?](../glossary#attention-mask) + head_mask (`torch.Tensor` of shape `(encoder_layers, encoder_attention_heads)`, *optional*): + Mask to nullify selected heads of the attention modules. Mask values selected in `[0, 1]`: + + - 1 indicates the head is **not masked**, + - 0 indicates the head is **masked**. + + inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*): + Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. + This is useful if you want more control over how to convert `input_ids` indices into associated vectors + than the model's internal embedding lookup matrix. + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under + returned tensors for more detail. + output_hidden_states (`bool`, *optional*): + Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors + for more detail. + return_dict (`bool`, *optional*): + Whether or not to return a [`~utils.BaseModelOutput`] instead of a plain tuple. 
+ """ + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + # retrieve input_ids and inputs_embeds + if input_ids is not None and inputs_embeds is not None: + raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time") + elif input_ids is not None: + input_shape = input_ids.size() + input_ids = input_ids.view(-1, input_shape[-1]) + elif inputs_embeds is not None: + input_shape = inputs_embeds.size()[:-1] + else: + raise ValueError("You have to specify either input_ids or inputs_embeds") + + if inputs_embeds is None: + inputs_embeds = self.embed_tokens(input_ids) + + seq_len = input_shape[1] + if position_ids is None: + position_ids = torch.arange(seq_len, dtype=torch.long, device=inputs_embeds.device).expand((1, -1)) + embed_pos = self.embed_positions(position_ids) + + hidden_states = inputs_embeds + embed_pos + hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training) + + # expand attention_mask + if attention_mask is not None: + # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] + attention_mask = _expand_mask(attention_mask, inputs_embeds.dtype) + + encoder_states = () if output_hidden_states else None + all_attentions = () if output_attentions else None + + # check if head_mask has a correct number of layers specified if desired + if head_mask is not None: + if head_mask.size()[0] != (len(self.layers)): + raise ValueError( + f"The head_mask should be specified for {len(self.layers)} layers, but it is for" + f" {head_mask.size()[0]}." 
+ ) + + for idx, encoder_layer in enumerate(self.layers): + if output_hidden_states: + encoder_states = encoder_states + (hidden_states,) + if self.gradient_checkpointing and self.training: + + def create_custom_forward(module): + def custom_forward(*inputs): + return module(*inputs, output_attentions) + + return custom_forward + + layer_outputs = torch.utils.checkpoint.checkpoint( + create_custom_forward(encoder_layer), + hidden_states, + attention_mask, + (head_mask[idx] if head_mask is not None else None), + ) + else: + layer_outputs = encoder_layer( + hidden_states, + attention_mask, + layer_head_mask=(head_mask[idx] if head_mask is not None else None), + output_attentions=output_attentions, + ) + + hidden_states = layer_outputs[0] + + if output_attentions: + all_attentions = all_attentions + (layer_outputs[1],) + + hidden_states = self.layer_norm(hidden_states) + + if output_hidden_states: + encoder_states = encoder_states + (hidden_states,) + + if not return_dict: + return tuple(v for v in [hidden_states, encoder_states, all_attentions] if v is not None) + return BaseModelOutput( + last_hidden_state=hidden_states, hidden_states=encoder_states, attentions=all_attentions + ) + + +class LDMBertModel(LDMBertPreTrainedModel): + def __init__(self, config: LDMBertConfig): + super().__init__(config) + self.model = LDMBertEncoder(config) + self.to_logits = nn.Linear(config.hidden_size, config.vocab_size) + + def forward( + self, + input_ids=None, + attention_mask=None, + position_ids=None, + head_mask=None, + inputs_embeds=None, + output_attentions=None, + output_hidden_states=None, + return_dict=None, + ): + + outputs = self.model( + input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + head_mask=head_mask, + inputs_embeds=inputs_embeds, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + return outputs diff --git a/local_diffusers/pipelines/latent_diffusion_uncond/__init__.py b/local_diffusers/pipelines/latent_diffusion_uncond/__init__.py new file mode 100644 index 000000000..0826ca753 --- /dev/null +++ b/local_diffusers/pipelines/latent_diffusion_uncond/__init__.py @@ -0,0 +1,2 @@ +# flake8: noqa +from .pipeline_latent_diffusion_uncond import LDMPipeline diff --git a/local_diffusers/pipelines/latent_diffusion_uncond/pipeline_latent_diffusion_uncond.py b/local_diffusers/pipelines/latent_diffusion_uncond/pipeline_latent_diffusion_uncond.py new file mode 100644 index 000000000..4979d88fe --- /dev/null +++ b/local_diffusers/pipelines/latent_diffusion_uncond/pipeline_latent_diffusion_uncond.py @@ -0,0 +1,108 @@ +import inspect +import warnings +from typing import Optional, Tuple, Union + +import torch + +from ...models import UNet2DModel, VQModel +from ...pipeline_utils import DiffusionPipeline, ImagePipelineOutput +from ...schedulers import DDIMScheduler + + +class LDMPipeline(DiffusionPipeline): + r""" + This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the + library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.) + + Parameters: + vqvae ([`VQModel`]): + Vector-quantized (VQ) Model to encode and decode images to and from latent representations. + unet ([`UNet2DModel`]): U-Net architecture to denoise the encoded image latents. + scheduler ([`SchedulerMixin`]): + [`DDIMScheduler`] is to be used in combination with `unet` to denoise the encoded image latens. 
+ """ + + def __init__(self, vqvae: VQModel, unet: UNet2DModel, scheduler: DDIMScheduler): + super().__init__() + scheduler = scheduler.set_format("pt") + self.register_modules(vqvae=vqvae, unet=unet, scheduler=scheduler) + + @torch.no_grad() + def __call__( + self, + batch_size: int = 1, + generator: Optional[torch.Generator] = None, + eta: float = 0.0, + num_inference_steps: int = 50, + output_type: Optional[str] = "pil", + return_dict: bool = True, + **kwargs, + ) -> Union[Tuple, ImagePipelineOutput]: + + r""" + Args: + batch_size (`int`, *optional*, defaults to 1): + Number of images to generate. + generator (`torch.Generator`, *optional*): + A [torch generator](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make generation + deterministic. + num_inference_steps (`int`, *optional*, defaults to 50): + The number of denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. + output_type (`str`, *optional*, defaults to `"pil"`): + The output format of the generate image. Choose between + [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `nd.array`. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~pipeline_utils.ImagePipelineOutput`] instead of a plain tuple. + + Returns: + [`~pipeline_utils.ImagePipelineOutput`] or `tuple`: [`~pipelines.utils.ImagePipelineOutput`] if + `return_dict` is True, otherwise a `tuple. When returning a tuple, the first element is a list with the + generated images. + """ + + if "torch_device" in kwargs: + device = kwargs.pop("torch_device") + warnings.warn( + "`torch_device` is deprecated as an input argument to `__call__` and will be removed in v0.3.0." + " Consider using `pipe.to(torch_device)` instead." 
+ ) + + # Set device as before (to be removed in 0.3.0) + if device is None: + device = "cuda" if torch.cuda.is_available() else "cpu" + self.to(device) + + latents = torch.randn( + (batch_size, self.unet.in_channels, self.unet.sample_size, self.unet.sample_size), + generator=generator, + ) + latents = latents.to(self.device) + + self.scheduler.set_timesteps(num_inference_steps) + + # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature + accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys()) + + extra_kwargs = {} + if accepts_eta: + extra_kwargs["eta"] = eta + + for t in self.progress_bar(self.scheduler.timesteps): + # predict the noise residual + noise_prediction = self.unet(latents, t).sample + # compute the previous noisy sample x_t -> x_t-1 + latents = self.scheduler.step(noise_prediction, t, latents, **extra_kwargs).prev_sample + + # decode the image latents with the VAE + image = self.vqvae.decode(latents).sample + + image = (image / 2 + 0.5).clamp(0, 1) + image = image.cpu().permute(0, 2, 3, 1).numpy() + if output_type == "pil": + image = self.numpy_to_pil(image) + + if not return_dict: + return (image,) + + return ImagePipelineOutput(images=image) diff --git a/local_diffusers/pipelines/pndm/__init__.py b/local_diffusers/pipelines/pndm/__init__.py new file mode 100644 index 000000000..6fc46aaab --- /dev/null +++ b/local_diffusers/pipelines/pndm/__init__.py @@ -0,0 +1,2 @@ +# flake8: noqa +from .pipeline_pndm import PNDMPipeline diff --git a/local_diffusers/pipelines/pndm/pipeline_pndm.py b/local_diffusers/pipelines/pndm/pipeline_pndm.py new file mode 100644 index 000000000..f3dff1a9a --- /dev/null +++ b/local_diffusers/pipelines/pndm/pipeline_pndm.py @@ -0,0 +1,111 @@ +# Copyright 2022 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and + +# limitations under the License. + + +import warnings +from typing import Optional, Tuple, Union + +import torch + +from ...models import UNet2DModel +from ...pipeline_utils import DiffusionPipeline, ImagePipelineOutput +from ...schedulers import PNDMScheduler + + +class PNDMPipeline(DiffusionPipeline): + r""" + This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the + library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.) + + Parameters: + unet (`UNet2DModel`): U-Net architecture to denoise the encoded image latents. + scheduler ([`SchedulerMixin`]): + The `PNDMScheduler` to be used in combination with `unet` to denoise the encoded image. 
+ """ + + unet: UNet2DModel + scheduler: PNDMScheduler + + def __init__(self, unet: UNet2DModel, scheduler: PNDMScheduler): + super().__init__() + scheduler = scheduler.set_format("pt") + self.register_modules(unet=unet, scheduler=scheduler) + + @torch.no_grad() + def __call__( + self, + batch_size: int = 1, + num_inference_steps: int = 50, + generator: Optional[torch.Generator] = None, + output_type: Optional[str] = "pil", + return_dict: bool = True, + **kwargs, + ) -> Union[ImagePipelineOutput, Tuple]: + r""" + Args: + batch_size (`int`, `optional`, defaults to 1): The number of images to generate. + num_inference_steps (`int`, `optional`, defaults to 50): + The number of denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. + generator (`torch.Generator`, `optional`): A [torch + generator](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make generation + deterministic. + output_type (`str`, `optional`, defaults to `"pil"`): The output format of the generate image. Choose + between [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `nd.array`. + return_dict (`bool`, `optional`, defaults to `True`): Whether or not to return a + [`~pipeline_utils.ImagePipelineOutput`] instead of a plain tuple. + + Returns: + [`~pipeline_utils.ImagePipelineOutput`] or `tuple`: [`~pipelines.utils.ImagePipelineOutput`] if + `return_dict` is True, otherwise a `tuple. When returning a tuple, the first element is a list with the + generated images. + """ + # For more information on the sampling method you can take a look at Algorithm 2 of + # the official paper: https://arxiv.org/pdf/2202.09778.pdf + + if "torch_device" in kwargs: + device = kwargs.pop("torch_device") + warnings.warn( + "`torch_device` is deprecated as an input argument to `__call__` and will be removed in v0.3.0." + " Consider using `pipe.to(torch_device)` instead." 
+ ) + + # Set device as before (to be removed in 0.3.0) + if device is None: + device = "cuda" if torch.cuda.is_available() else "cpu" + self.to(device) + + # Sample gaussian noise to begin loop + image = torch.randn( + (batch_size, self.unet.in_channels, self.unet.sample_size, self.unet.sample_size), + generator=generator, + ) + image = image.to(self.device) + + self.scheduler.set_timesteps(num_inference_steps) + for t in self.progress_bar(self.scheduler.timesteps): + model_output = self.unet(image, t).sample + + image = self.scheduler.step(model_output, t, image).prev_sample + + image = (image / 2 + 0.5).clamp(0, 1) + image = image.cpu().permute(0, 2, 3, 1).numpy() + if output_type == "pil": + image = self.numpy_to_pil(image) + + if not return_dict: + return (image,) + + return ImagePipelineOutput(images=image) diff --git a/local_diffusers/pipelines/score_sde_ve/__init__.py b/local_diffusers/pipelines/score_sde_ve/__init__.py new file mode 100644 index 000000000..000d61f6e --- /dev/null +++ b/local_diffusers/pipelines/score_sde_ve/__init__.py @@ -0,0 +1,2 @@ +# flake8: noqa +from .pipeline_score_sde_ve import ScoreSdeVePipeline diff --git a/local_diffusers/pipelines/score_sde_ve/pipeline_score_sde_ve.py b/local_diffusers/pipelines/score_sde_ve/pipeline_score_sde_ve.py new file mode 100644 index 000000000..604e2b54c --- /dev/null +++ b/local_diffusers/pipelines/score_sde_ve/pipeline_score_sde_ve.py @@ -0,0 +1,101 @@ +#!/usr/bin/env python3 +import warnings +from typing import Optional, Tuple, Union + +import torch + +from ...models import UNet2DModel +from ...pipeline_utils import DiffusionPipeline, ImagePipelineOutput +from ...schedulers import ScoreSdeVeScheduler + + +class ScoreSdeVePipeline(DiffusionPipeline): + r""" + Parameters: + This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the + library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.) + unet ([`UNet2DModel`]): U-Net architecture to denoise the encoded image. scheduler ([`SchedulerMixin`]): + The [`ScoreSdeVeScheduler`] scheduler to be used in combination with `unet` to denoise the encoded image. + """ + unet: UNet2DModel + scheduler: ScoreSdeVeScheduler + + def __init__(self, unet: UNet2DModel, scheduler: DiffusionPipeline): + super().__init__() + self.register_modules(unet=unet, scheduler=scheduler) + + @torch.no_grad() + def __call__( + self, + batch_size: int = 1, + num_inference_steps: int = 2000, + generator: Optional[torch.Generator] = None, + output_type: Optional[str] = "pil", + return_dict: bool = True, + **kwargs, + ) -> Union[ImagePipelineOutput, Tuple]: + r""" + Args: + batch_size (`int`, *optional*, defaults to 1): + The number of images to generate. + generator (`torch.Generator`, *optional*): + A [torch generator](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make generation + deterministic. + output_type (`str`, *optional*, defaults to `"pil"`): + The output format of the generate image. Choose between + [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `nd.array`. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~pipeline_utils.ImagePipelineOutput`] instead of a plain tuple. + + Returns: + [`~pipeline_utils.ImagePipelineOutput`] or `tuple`: [`~pipelines.utils.ImagePipelineOutput`] if + `return_dict` is True, otherwise a `tuple. When returning a tuple, the first element is a list with the + generated images. 
+ """ + + if "torch_device" in kwargs: + device = kwargs.pop("torch_device") + warnings.warn( + "`torch_device` is deprecated as an input argument to `__call__` and will be removed in v0.3.0." + " Consider using `pipe.to(torch_device)` instead." + ) + + # Set device as before (to be removed in 0.3.0) + if device is None: + device = "cuda" if torch.cuda.is_available() else "cpu" + self.to(device) + + img_size = self.unet.config.sample_size + shape = (batch_size, 3, img_size, img_size) + + model = self.unet + + sample = torch.randn(*shape, generator=generator) * self.scheduler.config.sigma_max + sample = sample.to(self.device) + + self.scheduler.set_timesteps(num_inference_steps) + self.scheduler.set_sigmas(num_inference_steps) + + for i, t in enumerate(self.progress_bar(self.scheduler.timesteps)): + sigma_t = self.scheduler.sigmas[i] * torch.ones(shape[0], device=self.device) + + # correction step + for _ in range(self.scheduler.correct_steps): + model_output = self.unet(sample, sigma_t).sample + sample = self.scheduler.step_correct(model_output, sample, generator=generator).prev_sample + + # prediction step + model_output = model(sample, sigma_t).sample + output = self.scheduler.step_pred(model_output, t, sample, generator=generator) + + sample, sample_mean = output.prev_sample, output.prev_sample_mean + + sample = sample_mean.clamp(0, 1) + sample = sample.cpu().permute(0, 2, 3, 1).numpy() + if output_type == "pil": + sample = self.numpy_to_pil(sample) + + if not return_dict: + return (sample,) + + return ImagePipelineOutput(images=sample) diff --git a/local_diffusers/pipelines/stable_diffusion/__init__.py b/local_diffusers/pipelines/stable_diffusion/__init__.py new file mode 100644 index 000000000..5ffda93f1 --- /dev/null +++ b/local_diffusers/pipelines/stable_diffusion/__init__.py @@ -0,0 +1,37 @@ +from dataclasses import dataclass +from typing import List, Union + +import numpy as np + +import PIL +from PIL import Image + +from ...utils import BaseOutput, is_onnx_available, is_transformers_available + + +@dataclass +class StableDiffusionPipelineOutput(BaseOutput): + """ + Output class for Stable Diffusion pipelines. + + Args: + images (`List[PIL.Image.Image]` or `np.ndarray`) + List of denoised PIL images of length `batch_size` or numpy array of shape `(batch_size, height, width, + num_channels)`. PIL images or numpy array present the denoised images of the diffusion pipeline. + nsfw_content_detected (`List[bool]`) + List of flags denoting whether the corresponding generated image likely represents "not-safe-for-work" + (nsfw) content. + """ + + images: Union[List[PIL.Image.Image], np.ndarray] + nsfw_content_detected: List[bool] + + +if is_transformers_available(): + from .pipeline_stable_diffusion import StableDiffusionPipeline + from .pipeline_stable_diffusion_img2img import StableDiffusionImg2ImgPipeline + from .pipeline_stable_diffusion_inpaint import StableDiffusionInpaintPipeline + from .safety_checker import StableDiffusionSafetyChecker + +if is_transformers_available() and is_onnx_available(): + from .pipeline_stable_diffusion_onnx import StableDiffusionOnnxPipeline diff --git a/local_diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py b/local_diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py new file mode 100644 index 000000000..8e3199b44 --- /dev/null +++ b/local_diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py @@ -0,0 +1,397 @@ +# Modification of the original file by O. Teytaud for facilitating genetic stable diffusion. 
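+#
+# The evolutionary behaviour added below is driven entirely through environment
+# variables, so the public `StableDiffusionPipeline.__call__` signature is unchanged:
+#   - "good" / "bad": Python-literal lists of flattened latents rated by the user,
+#     used to fit a surrogate classifier.
+#   - "forcedlatent": optional flattened latent used as the starting point of the search.
+#   - "skl", "ngoptim", "budget", "epsilon", "earlystop": surrogate and nevergrad
+#     optimizer settings read in `__call__` below.
+#   - "latent_sd": written by the pipeline; it exports the latent actually used so the
+#     caller can store it next to the generated image.
+#
+# Minimal usage sketch (an assumption-laden illustration, not part of the original
+# code: it supposes this file shadows the installed diffusers pipeline and that every
+# variable read below is set by the caller):
+#
+#     import os
+#     from diffusers import StableDiffusionPipeline
+#
+#     os.environ["good"] = "[]"        # no positive feedback yet
+#     os.environ["bad"] = "[]"         # no negative feedback yet
+#     os.environ["forcedlatent"] = ""  # empty -> start from a random latent
+#     pipe = StableDiffusionPipeline.from_pretrained("CompVis/stable-diffusion-v1-4")
+#     image = pipe("a photorealistic portrait", guidance_scale=7.5).images[0]
+#     evolved_latent = os.environ["latent_sd"]  # flattened latent behind this image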
+ +import inspect +import os +import numpy as np +import random +import warnings +from typing import List, Optional, Union + +import torch + +from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer + +from ...models import AutoencoderKL, UNet2DConditionModel +from ...pipeline_utils import DiffusionPipeline +from ...schedulers import DDIMScheduler, LMSDiscreteScheduler, PNDMScheduler +from . import StableDiffusionPipelineOutput +from .safety_checker import StableDiffusionSafetyChecker + + +class StableDiffusionPipeline(DiffusionPipeline): + r""" + Pipeline for text-to-image generation using Stable Diffusion. + + This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the + library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.) + + Args: + vae ([`AutoencoderKL`]): + Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations. + text_encoder ([`CLIPTextModel`]): + Frozen text-encoder. Stable Diffusion uses the text portion of + [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel), specifically + the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant. + tokenizer (`CLIPTokenizer`): + Tokenizer of class + [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer). + unet ([`UNet2DConditionModel`]): Conditional U-Net architecture to denoise the encoded image latents. + scheduler ([`SchedulerMixin`]): + A scheduler to be used in combination with `unet` to denoise the encoded image latens. Can be one of + [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`]. + safety_checker ([`StableDiffusionSafetyChecker`]): + Classification module that estimates whether generated images could be considered offsensive or harmful. + Please, refer to the [model card](https://huggingface.co/CompVis/stable-diffusion-v1-4) for details. + feature_extractor ([`CLIPFeatureExtractor`]): + Model that extracts features from generated images to be used as inputs for the `safety_checker`. + """ + + def __init__( + self, + vae: AutoencoderKL, + text_encoder: CLIPTextModel, + tokenizer: CLIPTokenizer, + unet: UNet2DConditionModel, + scheduler: Union[DDIMScheduler, PNDMScheduler, LMSDiscreteScheduler], + safety_checker: StableDiffusionSafetyChecker, + feature_extractor: CLIPFeatureExtractor, + ): + super().__init__() + scheduler = scheduler.set_format("pt") + self.register_modules( + vae=vae, + text_encoder=text_encoder, + tokenizer=tokenizer, + unet=unet, + scheduler=scheduler, + safety_checker=safety_checker, + feature_extractor=feature_extractor, + ) + + def enable_attention_slicing(self, slice_size: Optional[Union[str, int]] = "auto"): + r""" + Enable sliced attention computation. + + When this option is enabled, the attention module will split the input tensor in slices, to compute attention + in several steps. This is useful to save some memory in exchange for a small speed decrease. + + Args: + slice_size (`str` or `int`, *optional*, defaults to `"auto"`): + When `"auto"`, halves the input to the attention heads, so attention will be computed in two steps. If + a number is provided, uses as many slices as `attention_head_dim // slice_size`. In this case, + `attention_head_dim` must be a multiple of `slice_size`. 
+ """ + if slice_size == "auto": + # half the attention head size is usually a good trade-off between + # speed and memory + slice_size = self.unet.config.attention_head_dim // 2 + self.unet.set_attention_slice(slice_size) + + def disable_attention_slicing(self): + r""" + Disable sliced attention computation. If `enable_attention_slicing` was previously invoked, this method will go + back to computing attention in one step. + """ + # set slice_size = `None` to disable `attention slicing` + self.enable_attention_slicing(None) + +# def get_latent(self, image): +# return self.vae.encode(image) + + @torch.no_grad() + def __call__( + self, + prompt: Union[str, List[str]], + height: Optional[int] = 512, + width: Optional[int] = 512, + num_inference_steps: Optional[int] = 50, + guidance_scale: Optional[float] = 7.5, + eta: Optional[float] = 0.0, + generator: Optional[torch.Generator] = None, + latents: Optional[torch.FloatTensor] = None, + output_type: Optional[str] = "pil", + return_dict: bool = True, + **kwargs, + ): + r""" + Function invoked when calling the pipeline for generation. + + Args: + prompt (`str` or `List[str]`): + The prompt or prompts to guide the image generation. + height (`int`, *optional*, defaults to 512): + The height in pixels of the generated image. + width (`int`, *optional*, defaults to 512): + The width in pixels of the generated image. + num_inference_steps (`int`, *optional*, defaults to 50): + The number of denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. + guidance_scale (`float`, *optional*, defaults to 7.5): + Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598). + `guidance_scale` is defined as `w` of equation 2. of [Imagen + Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale > + 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`, + usually at the expense of lower image quality. + eta (`float`, *optional*, defaults to 0.0): + Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to + [`schedulers.DDIMScheduler`], will be ignored for others. + generator (`torch.Generator`, *optional*): + A [torch generator](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make generation + deterministic. + latents (`torch.FloatTensor`, *optional*): + Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image + generation. Can be used to tweak the same generation with different prompts. If not provided, a latents + tensor will ge generated by sampling using the supplied random `generator`. + output_type (`str`, *optional*, defaults to `"pil"`): + The output format of the generate image. Choose between + [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `nd.array`. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a + plain tuple. + + Returns: + [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`: + [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] if `return_dict` is True, otherwise a `tuple. 
+ When returning a tuple, the first element is a list with the generated images, and the second element is a + list of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work" + (nsfw) content, according to the `safety_checker`. + """ + + if "torch_device" in kwargs: + device = kwargs.pop("torch_device") + warnings.warn( + "`torch_device` is deprecated as an input argument to `__call__` and will be removed in v0.3.0." + " Consider using `pipe.to(torch_device)` instead." + ) + + # Set device as before (to be removed in 0.3.0) + if device is None: + device = "cuda" if torch.cuda.is_available() else "cpu" + self.to(device) + + if isinstance(prompt, str): + batch_size = 1 + elif isinstance(prompt, list): + batch_size = len(prompt) + else: + raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") + + if height % 8 != 0 or width % 8 != 0: + raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.") + + # get prompt text embeddings + text_input = self.tokenizer( + prompt, + padding="max_length", + max_length=self.tokenizer.model_max_length, + truncation=True, + return_tensors="pt", + ) + text_embeddings = self.text_encoder(text_input.input_ids.to(self.device))[0] + + # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) + # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1` + # corresponds to doing no classifier free guidance. + do_classifier_free_guidance = guidance_scale > 1.0 + # get unconditional embeddings for classifier free guidance + if do_classifier_free_guidance: + max_length = text_input.input_ids.shape[-1] + uncond_input = self.tokenizer( + [""] * batch_size, padding="max_length", max_length=max_length, return_tensors="pt" + ) + uncond_embeddings = self.text_encoder(uncond_input.input_ids.to(self.device))[0] + + # For classifier free guidance, we need to do two forward passes. + # Here we concatenate the unconditional and text embeddings into a single batch + # to avoid doing two forward passes + text_embeddings = torch.cat([uncond_embeddings, text_embeddings]) + + # get the initial random noise unless the user supplied it + + # Unlike in other pipelines, latents need to be generated in the target device + # for 1-to-1 results reproducibility with the CompVis implementation. + # However this currently doesn't work in `mps`. 
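+        # Evolutionary modification (description of the block below, not upstream
+        # diffusers behaviour): when no `latents` tensor is passed in, the random draw
+        # can be overridden through environment variables:
+        #   1. "forcedlatent", when set, is parsed and used as the starting latent z0.
+        #   2. If enough feedback is available ("good" and "bad" latent lists), a
+        #      scikit-learn surrogate ("skl": decision tree, logistic regression or MLP)
+        #      is fitted to separate good from bad latents.
+        #   3. A nevergrad optimizer ("ngoptim", run for "budget" steps) perturbs z0 by
+        #      "epsilon" to minimise the predicted probability of being bad,
+        #      renormalising the latent to norm sqrt(d) after each move.
+        # The latent finally used is exported in os.environ["latent_sd"].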
+ latents_device = "cpu" if self.device.type == "mps" else self.device + latents_shape = (batch_size, self.unet.in_channels, height // 8, width // 8) + latents_intermediate_shape = (batch_size, self.unet.in_channels, height // 8, width // 8) + speedup = 1 + if latents is None: + latents = torch.randn( + latents_intermediate_shape, + generator=generator, + device=latents_device, + ) + if len(os.environ["forcedlatent"]) > 10: + stri = os.environ["forcedlatent"] + print(f"we get a forcing for the latent z: {stri[:20]}.") + if len(eval(stri)) == 1: + stri = str(eval(stri)[0]) + speedup = 1 + latents = np.array(list(eval(stri))).flatten() + #latents = latents + np.exp(0.1 * np.random.randn()) * np.random.rand(len(latents)) + #latents = np.sqrt(len(latents) / np.sum(latents ** 2)) * latents + #latents = np.sqrt(len(latents)) * latents / np.sqrt(np.sum(latents ** 2)) + print(f"As an array, this is {latents[:10]}") + print(f"immediately after loading latent ==> {sum(latents.flatten()**2) / len(latents.flatten())}") + latents = torch.from_numpy(latents.reshape((1,4,64,64))).float().to(latents_device) + os.environ["forcedlatent"] = "" + good = eval(os.environ["good"]) + bad = eval(os.environ["bad"]) + print(f"{len(good)} good and {len(bad)} bad") + i_believe_in_evolution = len(good) > 0 and len(bad) > 10 + print(f"I believe in evolution = {i_believe_in_evolution}") + if i_believe_in_evolution: + from sklearn import tree + from sklearn.neural_network import MLPClassifier + #from sklearn.neighbors import NearestCentroid + from sklearn.linear_model import LogisticRegression + #z = (np.random.randn(4*64*64)) + z = latents.cpu().numpy().flatten() + if os.environ.get("skl", "tree") == "tree": + clf = tree.DecisionTreeClassifier()#min_samples_split=0.1) + elif os.environ.get("skl", "tree") == "logit": + clf = LogisticRegression() + else: + clf = MLPClassifier(solver='lbfgs', alpha=1e-5, hidden_layer_sizes=(5, 2), random_state=1) + #clf = NearestCentroid() + + + + X=good + bad + Y = [1] * len(good) + [0] * len(bad) + clf = clf.fit(X,Y) + epsilon = 0.0001 # for astronauts + epsilon = 1.0 + + def loss(x): + return clf.predict_proba([x])[0][0] # for astronauts + #return clf.predict_proba([(1-epsilon)*z+epsilon*x])[0][0] # for astronauts + #return clf.predict_proba([z+epsilon*x])[0][0] + + + budget = int(os.environ.get("budget", "300")) + if i_believe_in_evolution and budget > 20: + import nevergrad as ng + #nevergrad_optimizer = ng.optimizers.RandomSearch(len(z), budget) + #nevergrad_optimizer = ng.optimizers.RandomSearch(len(z), budget) + optim_class = ng.optimizers.registry[os.environ.get("ngoptim", "DiscreteLenglerOnePlusOne")] + #nevergrad_optimizer = ng.optimizers.DiscreteLenglerOnePlusOne(len(z), budget) + nevergrad_optimizer = optim_class(len(z), budget) + #nevergrad_optimizer = ng.optimizers.DiscreteOnePlusOne(len(z), budget) +# for k in range(5): +# z1 = np.array(random.choice(good)) +# z2 = np.array(random.choice(good)) +# z3 = np.array(random.choice(good)) +# z4 = np.array(random.choice(good)) +# z5 = np.array(random.choice(good)) +# #z = 0.99 * z1 + 0.01 * (z2+z3+z4+z5)/4. +# z = 0.2 * (z1 + z2 + z3 + z4 + z5) +# mu = int(os.environ.get("mu", "5")) +# parents = [z1, z2, z3, z4, z5] +# weights = [np.exp(np.random.randn() - i * float(os.environ.get("decay", "1."))) for i in range(5)] +# z = weights[0] * z1 +# for u in range(mu): +# if u > 0: +# z += weights[u] * parents[u] +# z = (1. 
/ sum(weights[:mu])) * z +# z = np.sqrt(len(z)) * z / np.linalg.norm(z) +# +# #for u in range(len(z)): +# # z[u] = random.choice([z1[u],z2[u],z3[u],z4[u],z5[u]]) +# nevergrad_optimizer.suggest + if len(os.environ["forcedlatent"]) > 0: + print("we get a forcing for the latent z.") + z0 = eval(os.environ["forcedlatent"]) + #nevergrad_optimizer.suggest(eval(os.environ["forcedlatent"])) + else: + z0 = z + for i in range(budget): + x = nevergrad_optimizer.ask() + z = z0 + float(os.environ.get("epsilon", "0.001")) * x.value + z = np.sqrt(len(z)) * z / np.linalg.norm(z) + l = loss(z) + nevergrad_optimizer.tell(x, l) + if np.log2(i+1) == int(np.log2(i+1)): + print(f"iteration {i} --> {l}") + print("var/variable = ", sum(z**2)/len(z)) + #z = (1.-epsilon) * z + epsilon * x / np.sqrt(np.sum(x ** 2)) + if l < 0.0000001 and os.environ.get("earlystop", "False") in ["true", "True"]: + print(f"we find proba(bad)={l}") + break + x = nevergrad_optimizer.recommend().value + z = z0 + float(os.environ.get("epsilon", "0.001")) * x + z = np.sqrt(len(z)) * z / np.linalg.norm(z) + latents = torch.from_numpy(z.reshape(latents_intermediate_shape)).float() #.half() + else: + if latents.shape != latents_intermediate_shape: + raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {latents_intermediate_shape}") + print(f"latent ==> {sum(latents.flatten()**2) / len(latents.flatten())}") + print(f"latent ==> {torch.max(latents)}") + print(f"latent ==> {torch.min(latents)}") + os.environ["latent_sd"] = str(list(latents.flatten().cpu().numpy())) + for i in [2, 3]: + latents = torch.repeat_interleave(latents, repeats=latents_shape[i] // latents_intermediate_shape[i], dim=i) #/ np.sqrt(np.sqrt(latents_shape[i] // latents_intermediate_shape[i])) + latents = latents.float().to(self.device) + + # set timesteps + accepts_offset = "offset" in set(inspect.signature(self.scheduler.set_timesteps).parameters.keys()) + extra_set_kwargs = {} + if accepts_offset: + extra_set_kwargs["offset"] = 1 + + self.scheduler.set_timesteps(num_inference_steps // speedup, **extra_set_kwargs) + + # if we use LMSDiscreteScheduler, let's make sure latents are mulitplied by sigmas + if isinstance(self.scheduler, LMSDiscreteScheduler): + latents = latents * self.scheduler.sigmas[0] + + # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature + # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers. 
+ # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502 + # and should be between [0, 1] + accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys()) + extra_step_kwargs = {} + if accepts_eta: + extra_step_kwargs["eta"] = eta + + for i, t in enumerate(self.progress_bar(self.scheduler.timesteps)): + # expand the latents if we are doing classifier free guidance + latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents + if isinstance(self.scheduler, LMSDiscreteScheduler): + sigma = self.scheduler.sigmas[i] + # the model input needs to be scaled to match the continuous ODE formulation in K-LMS + latent_model_input = latent_model_input / ((sigma**2 + 1) ** 0.5) + + # predict the noise residual + noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=text_embeddings).sample + + # perform guidance + if do_classifier_free_guidance: + noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) + noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) + + # compute the previous noisy sample x_t -> x_t-1 + if isinstance(self.scheduler, LMSDiscreteScheduler): + latents = self.scheduler.step(noise_pred, i, latents, **extra_step_kwargs).prev_sample + else: + latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample + + # scale and decode the image latents with vae + #os.environ["latent_sd"] = str(list(latents.flatten().cpu().detach().numpy())) + latents = 1 / 0.18215 * latents + image = self.vae.decode(latents).sample + + image = (image / 2 + 0.5).clamp(0, 1) + image = image.cpu().permute(0, 2, 3, 1).numpy() + + # run safety checker + safety_cheker_input = self.feature_extractor(self.numpy_to_pil(image), return_tensors="pt").to(self.device) + image, has_nsfw_concept = self.safety_checker(images=image, clip_input=safety_cheker_input.pixel_values) + + if output_type == "pil": + image = self.numpy_to_pil(image) + + if not return_dict: + return (image, has_nsfw_concept) + + return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept) diff --git a/local_diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py b/local_diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py new file mode 100644 index 000000000..475ceef4f --- /dev/null +++ b/local_diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py @@ -0,0 +1,291 @@ +import inspect +from typing import List, Optional, Union + +import numpy as np +import torch + +import PIL +from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer + +from ...models import AutoencoderKL, UNet2DConditionModel +from ...pipeline_utils import DiffusionPipeline +from ...schedulers import DDIMScheduler, LMSDiscreteScheduler, PNDMScheduler +from . import StableDiffusionPipelineOutput +from .safety_checker import StableDiffusionSafetyChecker + + +def preprocess(image): + w, h = image.size + w, h = map(lambda x: x - x % 32, (w, h)) # resize to integer multiple of 32 + image = image.resize((w, h), resample=PIL.Image.LANCZOS) + image = np.array(image).astype(np.float32) / 255.0 + image = image[None].transpose(0, 3, 1, 2) + image = torch.from_numpy(image) + return 2.0 * image - 1.0 + + +class StableDiffusionImg2ImgPipeline(DiffusionPipeline): + r""" + Pipeline for text-guided image to image generation using Stable Diffusion. + + This model inherits from [`DiffusionPipeline`]. 
Check the superclass documentation for the generic methods the + library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.) + + Args: + vae ([`AutoencoderKL`]): + Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations. + text_encoder ([`CLIPTextModel`]): + Frozen text-encoder. Stable Diffusion uses the text portion of + [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel), specifically + the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant. + tokenizer (`CLIPTokenizer`): + Tokenizer of class + [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer). + unet ([`UNet2DConditionModel`]): Conditional U-Net architecture to denoise the encoded image latents. + scheduler ([`SchedulerMixin`]): + A scheduler to be used in combination with `unet` to denoise the encoded image latens. Can be one of + [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`]. + safety_checker ([`StableDiffusionSafetyChecker`]): + Classification module that estimates whether generated images could be considered offsensive or harmful. + Please, refer to the [model card](https://huggingface.co/CompVis/stable-diffusion-v1-4) for details. + feature_extractor ([`CLIPFeatureExtractor`]): + Model that extracts features from generated images to be used as inputs for the `safety_checker`. + """ + + def __init__( + self, + vae: AutoencoderKL, + text_encoder: CLIPTextModel, + tokenizer: CLIPTokenizer, + unet: UNet2DConditionModel, + scheduler: Union[DDIMScheduler, PNDMScheduler, LMSDiscreteScheduler], + safety_checker: StableDiffusionSafetyChecker, + feature_extractor: CLIPFeatureExtractor, + ): + super().__init__() + scheduler = scheduler.set_format("pt") + self.register_modules( + vae=vae, + text_encoder=text_encoder, + tokenizer=tokenizer, + unet=unet, + scheduler=scheduler, + safety_checker=safety_checker, + feature_extractor=feature_extractor, + ) + + def enable_attention_slicing(self, slice_size: Optional[Union[str, int]] = "auto"): + r""" + Enable sliced attention computation. + + When this option is enabled, the attention module will split the input tensor in slices, to compute attention + in several steps. This is useful to save some memory in exchange for a small speed decrease. + + Args: + slice_size (`str` or `int`, *optional*, defaults to `"auto"`): + When `"auto"`, halves the input to the attention heads, so attention will be computed in two steps. If + a number is provided, uses as many slices as `attention_head_dim // slice_size`. In this case, + `attention_head_dim` must be a multiple of `slice_size`. + """ + if slice_size == "auto": + # half the attention head size is usually a good trade-off between + # speed and memory + slice_size = self.unet.config.attention_head_dim // 2 + self.unet.set_attention_slice(slice_size) + + def disable_attention_slicing(self): + r""" + Disable sliced attention computation. If `enable_attention_slicing` was previously invoked, this method will go + back to computing attention in one step. 
+        """
+        # set slice_size = `None` to disable attention slicing
+        self.enable_attention_slicing(None)
+
+    @torch.no_grad()
+    def __call__(
+        self,
+        prompt: Union[str, List[str]],
+        init_image: Union[torch.FloatTensor, PIL.Image.Image],
+        strength: float = 0.8,
+        num_inference_steps: Optional[int] = 50,
+        guidance_scale: Optional[float] = 7.5,
+        eta: Optional[float] = 0.0,
+        generator: Optional[torch.Generator] = None,
+        output_type: Optional[str] = "pil",
+        return_dict: bool = True,
+    ):
+        r"""
+        Function invoked when calling the pipeline for generation.
+
+        Args:
+            prompt (`str` or `List[str]`):
+                The prompt or prompts to guide the image generation.
+            init_image (`torch.FloatTensor` or `PIL.Image.Image`):
+                `Image`, or tensor representing an image batch, that will be used as the starting point for the
+                process.
+            strength (`float`, *optional*, defaults to 0.8):
+                Conceptually, indicates how much to transform the reference `init_image`. Must be between 0 and 1.
+                `init_image` will be used as a starting point, adding more noise to it the larger the `strength`. The
+                number of denoising steps depends on the amount of noise initially added. When `strength` is 1, added
+                noise will be maximum and the denoising process will run for the full number of iterations specified in
+                `num_inference_steps`. A value of 1, therefore, essentially ignores `init_image`.
+            num_inference_steps (`int`, *optional*, defaults to 50):
+                The number of denoising steps. More denoising steps usually lead to a higher quality image at the
+                expense of slower inference. This parameter will be modulated by `strength`.
+            guidance_scale (`float`, *optional*, defaults to 7.5):
+                Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
+                `guidance_scale` is defined as `w` of equation 2. of [Imagen
+                Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
+                1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
+                usually at the expense of lower image quality.
+            eta (`float`, *optional*, defaults to 0.0):
+                Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to
+                [`schedulers.DDIMScheduler`], will be ignored for others.
+            generator (`torch.Generator`, *optional*):
+                A [torch generator](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make generation
+                deterministic.
+            output_type (`str`, *optional*, defaults to `"pil"`):
+                The output format of the generated image. Choose between
+                [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `nd.array`.
+            return_dict (`bool`, *optional*, defaults to `True`):
+                Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a
+                plain tuple.
+
+        Returns:
+            [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`:
+            [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] if `return_dict` is True, otherwise a `tuple`.
+            When returning a tuple, the first element is a list with the generated images, and the second element is a
+            list of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work"
+            (nsfw) content, according to the `safety_checker`.
+ """ + if isinstance(prompt, str): + batch_size = 1 + elif isinstance(prompt, list): + batch_size = len(prompt) + else: + raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") + + if strength < 0 or strength > 1: + raise ValueError(f"The value of strength should in [0.0, 1.0] but is {strength}") + + # set timesteps + accepts_offset = "offset" in set(inspect.signature(self.scheduler.set_timesteps).parameters.keys()) + extra_set_kwargs = {} + offset = 0 + if accepts_offset: + offset = 1 + extra_set_kwargs["offset"] = 1 + + self.scheduler.set_timesteps(num_inference_steps, **extra_set_kwargs) + + if not isinstance(init_image, torch.FloatTensor): + init_image = preprocess(init_image) + + # encode the init image into latents and scale the latents + init_latent_dist = self.vae.encode(init_image.to(self.device)).latent_dist + init_latents = init_latent_dist.sample(generator=generator) + init_latents = 0.18215 * init_latents + + # expand init_latents for batch_size + init_latents = torch.cat([init_latents] * batch_size) + + # get the original timestep using init_timestep + init_timestep = int(num_inference_steps * strength) + offset + init_timestep = min(init_timestep, num_inference_steps) + if isinstance(self.scheduler, LMSDiscreteScheduler): + timesteps = torch.tensor( + [num_inference_steps - init_timestep] * batch_size, dtype=torch.long, device=self.device + ) + else: + timesteps = self.scheduler.timesteps[-init_timestep] + timesteps = torch.tensor([timesteps] * batch_size, dtype=torch.long, device=self.device) + + # add noise to latents using the timesteps + noise = torch.randn(init_latents.shape, generator=generator, device=self.device) + init_latents = self.scheduler.add_noise(init_latents, noise, timesteps).to(self.device) + + # get prompt text embeddings + text_input = self.tokenizer( + prompt, + padding="max_length", + max_length=self.tokenizer.model_max_length, + truncation=True, + return_tensors="pt", + ) + text_embeddings = self.text_encoder(text_input.input_ids.to(self.device))[0] + + # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) + # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1` + # corresponds to doing no classifier free guidance. + do_classifier_free_guidance = guidance_scale > 1.0 + # get unconditional embeddings for classifier free guidance + if do_classifier_free_guidance: + max_length = text_input.input_ids.shape[-1] + uncond_input = self.tokenizer( + [""] * batch_size, padding="max_length", max_length=max_length, return_tensors="pt" + ) + uncond_embeddings = self.text_encoder(uncond_input.input_ids.to(self.device))[0] + + # For classifier free guidance, we need to do two forward passes. + # Here we concatenate the unconditional and text embeddings into a single batch + # to avoid doing two forward passes + text_embeddings = torch.cat([uncond_embeddings, text_embeddings]) + + # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature + # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers. 
+ # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502 + # and should be between [0, 1] + accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys()) + extra_step_kwargs = {} + if accepts_eta: + extra_step_kwargs["eta"] = eta + + latents = init_latents + + t_start = max(num_inference_steps - init_timestep + offset, 0) + for i, t in enumerate(self.progress_bar(self.scheduler.timesteps[t_start:])): + t_index = t_start + i + + # expand the latents if we are doing classifier free guidance + latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents + + # if we use LMSDiscreteScheduler, let's make sure latents are mulitplied by sigmas + if isinstance(self.scheduler, LMSDiscreteScheduler): + sigma = self.scheduler.sigmas[t_index] + # the model input needs to be scaled to match the continuous ODE formulation in K-LMS + latent_model_input = latent_model_input / ((sigma**2 + 1) ** 0.5) + latent_model_input = latent_model_input.to(self.unet.dtype) + t = t.to(self.unet.dtype) + + # predict the noise residual + noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=text_embeddings).sample + + # perform guidance + if do_classifier_free_guidance: + noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) + noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) + + # compute the previous noisy sample x_t -> x_t-1 + if isinstance(self.scheduler, LMSDiscreteScheduler): + latents = self.scheduler.step(noise_pred, t_index, latents, **extra_step_kwargs).prev_sample + else: + latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample + + # scale and decode the image latents with vae + latents = 1 / 0.18215 * latents + image = self.vae.decode(latents.to(self.vae.dtype)).sample + + image = (image / 2 + 0.5).clamp(0, 1) + image = image.cpu().permute(0, 2, 3, 1).numpy() + + # run safety checker + safety_cheker_input = self.feature_extractor(self.numpy_to_pil(image), return_tensors="pt").to(self.device) + image, has_nsfw_concept = self.safety_checker(images=image, clip_input=safety_cheker_input.pixel_values) + + if output_type == "pil": + image = self.numpy_to_pil(image) + + if not return_dict: + return (image, has_nsfw_concept) + + return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept) diff --git a/local_diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py b/local_diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py new file mode 100644 index 000000000..05ea84ae0 --- /dev/null +++ b/local_diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py @@ -0,0 +1,309 @@ +import inspect +from typing import List, Optional, Union + +import numpy as np +import torch + +import PIL +from tqdm.auto import tqdm +from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer + +from ...models import AutoencoderKL, UNet2DConditionModel +from ...pipeline_utils import DiffusionPipeline +from ...schedulers import DDIMScheduler, PNDMScheduler +from ...utils import logging +from . 
import StableDiffusionPipelineOutput +from .safety_checker import StableDiffusionSafetyChecker + + +logger = logging.get_logger(__name__) + + +def preprocess_image(image): + w, h = image.size + w, h = map(lambda x: x - x % 32, (w, h)) # resize to integer multiple of 32 + image = image.resize((w, h), resample=PIL.Image.LANCZOS) + image = np.array(image).astype(np.float32) / 255.0 + image = image[None].transpose(0, 3, 1, 2) + image = torch.from_numpy(image) + return 2.0 * image - 1.0 + + +def preprocess_mask(mask): + mask = mask.convert("L") + w, h = mask.size + w, h = map(lambda x: x - x % 32, (w, h)) # resize to integer multiple of 32 + mask = mask.resize((w // 8, h // 8), resample=PIL.Image.NEAREST) + mask = np.array(mask).astype(np.float32) / 255.0 + mask = np.tile(mask, (4, 1, 1)) + mask = mask[None].transpose(0, 1, 2, 3) # what does this step do? + mask = 1 - mask # repaint white, keep black + mask = torch.from_numpy(mask) + return mask + + +class StableDiffusionInpaintPipeline(DiffusionPipeline): + r""" + Pipeline for text-guided image inpainting using Stable Diffusion. *This is an experimental feature*. + + This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the + library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.) + + Args: + vae ([`AutoencoderKL`]): + Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations. + text_encoder ([`CLIPTextModel`]): + Frozen text-encoder. Stable Diffusion uses the text portion of + [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel), specifically + the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant. + tokenizer (`CLIPTokenizer`): + Tokenizer of class + [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer). + unet ([`UNet2DConditionModel`]): Conditional U-Net architecture to denoise the encoded image latents. + scheduler ([`SchedulerMixin`]): + A scheduler to be used in combination with `unet` to denoise the encoded image latens. Can be one of + [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`]. + safety_checker ([`StableDiffusionSafetyChecker`]): + Classification module that estimates whether generated images could be considered offsensive or harmful. + Please, refer to the [model card](https://huggingface.co/CompVis/stable-diffusion-v1-4) for details. + feature_extractor ([`CLIPFeatureExtractor`]): + Model that extracts features from generated images to be used as inputs for the `safety_checker`. + """ + + def __init__( + self, + vae: AutoencoderKL, + text_encoder: CLIPTextModel, + tokenizer: CLIPTokenizer, + unet: UNet2DConditionModel, + scheduler: Union[DDIMScheduler, PNDMScheduler], + safety_checker: StableDiffusionSafetyChecker, + feature_extractor: CLIPFeatureExtractor, + ): + super().__init__() + scheduler = scheduler.set_format("pt") + logger.info("`StableDiffusionInpaintPipeline` is experimental and will very likely change in the future.") + self.register_modules( + vae=vae, + text_encoder=text_encoder, + tokenizer=tokenizer, + unet=unet, + scheduler=scheduler, + safety_checker=safety_checker, + feature_extractor=feature_extractor, + ) + + def enable_attention_slicing(self, slice_size: Optional[Union[str, int]] = "auto"): + r""" + Enable sliced attention computation. 
+ + When this option is enabled, the attention module will split the input tensor in slices, to compute attention + in several steps. This is useful to save some memory in exchange for a small speed decrease. + + Args: + slice_size (`str` or `int`, *optional*, defaults to `"auto"`): + When `"auto"`, halves the input to the attention heads, so attention will be computed in two steps. If + a number is provided, uses as many slices as `attention_head_dim // slice_size`. In this case, + `attention_head_dim` must be a multiple of `slice_size`. + """ + if slice_size == "auto": + # half the attention head size is usually a good trade-off between + # speed and memory + slice_size = self.unet.config.attention_head_dim // 2 + self.unet.set_attention_slice(slice_size) + + def disable_attention_slicing(self): + r""" + Disable sliced attention computation. If `enable_attention_slicing` was previously invoked, this method will go + back to computing attention in one step. + """ + # set slice_size = `None` to disable `set_attention_slice` + self.enable_attention_slice(None) + + @torch.no_grad() + def __call__( + self, + prompt: Union[str, List[str]], + init_image: Union[torch.FloatTensor, PIL.Image.Image], + mask_image: Union[torch.FloatTensor, PIL.Image.Image], + strength: float = 0.8, + num_inference_steps: Optional[int] = 50, + guidance_scale: Optional[float] = 7.5, + eta: Optional[float] = 0.0, + generator: Optional[torch.Generator] = None, + output_type: Optional[str] = "pil", + return_dict: bool = True, + ): + r""" + Function invoked when calling the pipeline for generation. + + Args: + prompt (`str` or `List[str]`): + The prompt or prompts to guide the image generation. + init_image (`torch.FloatTensor` or `PIL.Image.Image`): + `Image`, or tensor representing an image batch, that will be used as the starting point for the + process. This is the image whose masked region will be inpainted. + mask_image (`torch.FloatTensor` or `PIL.Image.Image`): + `Image`, or tensor representing an image batch, to mask `init_image`. White pixels in the mask will be + replaced by noise and therefore repainted, while black pixels will be preserved. The mask image will be + converted to a single channel (luminance) before use. + strength (`float`, *optional*, defaults to 0.8): + Conceptually, indicates how much to inpaint the masked area. Must be between 0 and 1. When `strength` + is 1, the denoising process will be run on the masked area for the full number of iterations specified + in `num_inference_steps`. `init_image` will be used as a reference for the masked area, adding more + noise to that region the larger the `strength`. If `strength` is 0, no inpainting will occur. + num_inference_steps (`int`, *optional*, defaults to 50): + The reference number of denoising steps. More denoising steps usually lead to a higher quality image at + the expense of slower inference. This parameter will be modulated by `strength`, as explained above. + guidance_scale (`float`, *optional*, defaults to 7.5): + Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598). + `guidance_scale` is defined as `w` of equation 2. of [Imagen + Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale > + 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`, + usually at the expense of lower image quality. 
+ eta (`float`, *optional*, defaults to 0.0): + Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to + [`schedulers.DDIMScheduler`], will be ignored for others. + generator (`torch.Generator`, *optional*): + A [torch generator](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make generation + deterministic. + output_type (`str`, *optional*, defaults to `"pil"`): + The output format of the generate image. Choose between + [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `nd.array`. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a + plain tuple. + + Returns: + [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`: + [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] if `return_dict` is True, otherwise a `tuple. + When returning a tuple, the first element is a list with the generated images, and the second element is a + list of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work" + (nsfw) content, according to the `safety_checker`. + """ + if isinstance(prompt, str): + batch_size = 1 + elif isinstance(prompt, list): + batch_size = len(prompt) + else: + raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") + + if strength < 0 or strength > 1: + raise ValueError(f"The value of strength should in [0.0, 1.0] but is {strength}") + + # set timesteps + accepts_offset = "offset" in set(inspect.signature(self.scheduler.set_timesteps).parameters.keys()) + extra_set_kwargs = {} + offset = 0 + if accepts_offset: + offset = 1 + extra_set_kwargs["offset"] = 1 + + self.scheduler.set_timesteps(num_inference_steps, **extra_set_kwargs) + + # preprocess image + init_image = preprocess_image(init_image).to(self.device) + + # encode the init image into latents and scale the latents + init_latent_dist = self.vae.encode(init_image.to(self.device)).latent_dist + init_latents = init_latent_dist.sample(generator=generator) + + init_latents = 0.18215 * init_latents + + # Expand init_latents for batch_size + init_latents = torch.cat([init_latents] * batch_size) + init_latents_orig = init_latents + + # preprocess mask + mask = preprocess_mask(mask_image).to(self.device) + mask = torch.cat([mask] * batch_size) + + # check sizes + if not mask.shape == init_latents.shape: + raise ValueError("The mask and init_image should be the same size!") + + # get the original timestep using init_timestep + init_timestep = int(num_inference_steps * strength) + offset + init_timestep = min(init_timestep, num_inference_steps) + timesteps = self.scheduler.timesteps[-init_timestep] + timesteps = torch.tensor([timesteps] * batch_size, dtype=torch.long, device=self.device) + + # add noise to latents using the timesteps + noise = torch.randn(init_latents.shape, generator=generator, device=self.device) + init_latents = self.scheduler.add_noise(init_latents, noise, timesteps) + + # get prompt text embeddings + text_input = self.tokenizer( + prompt, + padding="max_length", + max_length=self.tokenizer.model_max_length, + truncation=True, + return_tensors="pt", + ) + text_embeddings = self.text_encoder(text_input.input_ids.to(self.device))[0] + + # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) + # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . 
`guidance_scale = 1` + # corresponds to doing no classifier free guidance. + do_classifier_free_guidance = guidance_scale > 1.0 + # get unconditional embeddings for classifier free guidance + if do_classifier_free_guidance: + max_length = text_input.input_ids.shape[-1] + uncond_input = self.tokenizer( + [""] * batch_size, padding="max_length", max_length=max_length, return_tensors="pt" + ) + uncond_embeddings = self.text_encoder(uncond_input.input_ids.to(self.device))[0] + + # For classifier free guidance, we need to do two forward passes. + # Here we concatenate the unconditional and text embeddings into a single batch + # to avoid doing two forward passes + text_embeddings = torch.cat([uncond_embeddings, text_embeddings]) + + # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature + # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers. + # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502 + # and should be between [0, 1] + accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys()) + extra_step_kwargs = {} + if accepts_eta: + extra_step_kwargs["eta"] = eta + + latents = init_latents + t_start = max(num_inference_steps - init_timestep + offset, 0) + for i, t in tqdm(enumerate(self.scheduler.timesteps[t_start:])): + # expand the latents if we are doing classifier free guidance + latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents + + # predict the noise residual + noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=text_embeddings).sample + + # perform guidance + if do_classifier_free_guidance: + noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) + noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) + + # compute the previous noisy sample x_t -> x_t-1 + latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample + + # masking + init_latents_proper = self.scheduler.add_noise(init_latents_orig, noise, t) + latents = (init_latents_proper * mask) + (latents * (1 - mask)) + + # scale and decode the image latents with vae + latents = 1 / 0.18215 * latents + image = self.vae.decode(latents).sample + + image = (image / 2 + 0.5).clamp(0, 1) + image = image.cpu().permute(0, 2, 3, 1).numpy() + + # run safety checker + safety_cheker_input = self.feature_extractor(self.numpy_to_pil(image), return_tensors="pt").to(self.device) + image, has_nsfw_concept = self.safety_checker(images=image, clip_input=safety_cheker_input.pixel_values) + + if output_type == "pil": + image = self.numpy_to_pil(image) + + if not return_dict: + return (image, has_nsfw_concept) + + return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept) diff --git a/local_diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_onnx.py b/local_diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_onnx.py new file mode 100644 index 000000000..7ff3ff22f --- /dev/null +++ b/local_diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_onnx.py @@ -0,0 +1,165 @@ +import inspect +from typing import List, Optional, Union + +import numpy as np + +from transformers import CLIPFeatureExtractor, CLIPTokenizer + +from ...onnx_utils import OnnxRuntimeModel +from ...pipeline_utils import DiffusionPipeline +from ...schedulers import DDIMScheduler, LMSDiscreteScheduler, PNDMScheduler +from . 
import StableDiffusionPipelineOutput + + +class StableDiffusionOnnxPipeline(DiffusionPipeline): + vae_decoder: OnnxRuntimeModel + text_encoder: OnnxRuntimeModel + tokenizer: CLIPTokenizer + unet: OnnxRuntimeModel + scheduler: Union[DDIMScheduler, PNDMScheduler, LMSDiscreteScheduler] + safety_checker: OnnxRuntimeModel + feature_extractor: CLIPFeatureExtractor + + def __init__( + self, + vae_decoder: OnnxRuntimeModel, + text_encoder: OnnxRuntimeModel, + tokenizer: CLIPTokenizer, + unet: OnnxRuntimeModel, + scheduler: Union[DDIMScheduler, PNDMScheduler, LMSDiscreteScheduler], + safety_checker: OnnxRuntimeModel, + feature_extractor: CLIPFeatureExtractor, + ): + super().__init__() + scheduler = scheduler.set_format("np") + self.register_modules( + vae_decoder=vae_decoder, + text_encoder=text_encoder, + tokenizer=tokenizer, + unet=unet, + scheduler=scheduler, + safety_checker=safety_checker, + feature_extractor=feature_extractor, + ) + + def __call__( + self, + prompt: Union[str, List[str]], + height: Optional[int] = 512, + width: Optional[int] = 512, + num_inference_steps: Optional[int] = 50, + guidance_scale: Optional[float] = 7.5, + eta: Optional[float] = 0.0, + latents: Optional[np.ndarray] = None, + output_type: Optional[str] = "pil", + return_dict: bool = True, + **kwargs, + ): + if isinstance(prompt, str): + batch_size = 1 + elif isinstance(prompt, list): + batch_size = len(prompt) + else: + raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") + + if height % 8 != 0 or width % 8 != 0: + raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.") + + # get prompt text embeddings + text_input = self.tokenizer( + prompt, + padding="max_length", + max_length=self.tokenizer.model_max_length, + truncation=True, + return_tensors="np", + ) + text_embeddings = self.text_encoder(input_ids=text_input.input_ids.astype(np.int32))[0] + + # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) + # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1` + # corresponds to doing no classifier free guidance. + do_classifier_free_guidance = guidance_scale > 1.0 + # get unconditional embeddings for classifier free guidance + if do_classifier_free_guidance: + max_length = text_input.input_ids.shape[-1] + uncond_input = self.tokenizer( + [""] * batch_size, padding="max_length", max_length=max_length, return_tensors="np" + ) + uncond_embeddings = self.text_encoder(input_ids=uncond_input.input_ids.astype(np.int32))[0] + + # For classifier free guidance, we need to do two forward passes. 
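+            # (For reference, the two batched predictions are recombined further below as
+            #  noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond).)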
+ # Here we concatenate the unconditional and text embeddings into a single batch + # to avoid doing two forward passes + text_embeddings = np.concatenate([uncond_embeddings, text_embeddings]) + + # get the initial random noise unless the user supplied it + latents_shape = (batch_size, 4, height // 8, width // 8) + if latents is None: + latents = np.random.randn(*latents_shape).astype(np.float32) + elif latents.shape != latents_shape: + raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {latents_shape}") + + # set timesteps + accepts_offset = "offset" in set(inspect.signature(self.scheduler.set_timesteps).parameters.keys()) + extra_set_kwargs = {} + if accepts_offset: + extra_set_kwargs["offset"] = 1 + + self.scheduler.set_timesteps(num_inference_steps, **extra_set_kwargs) + + # if we use LMSDiscreteScheduler, let's make sure latents are mulitplied by sigmas + if isinstance(self.scheduler, LMSDiscreteScheduler): + latents = latents * self.scheduler.sigmas[0] + + # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature + # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers. + # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502 + # and should be between [0, 1] + accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys()) + extra_step_kwargs = {} + if accepts_eta: + extra_step_kwargs["eta"] = eta + + for i, t in enumerate(self.progress_bar(self.scheduler.timesteps)): + # expand the latents if we are doing classifier free guidance + latent_model_input = np.concatenate([latents] * 2) if do_classifier_free_guidance else latents + if isinstance(self.scheduler, LMSDiscreteScheduler): + sigma = self.scheduler.sigmas[i] + # the model input needs to be scaled to match the continuous ODE formulation in K-LMS + latent_model_input = latent_model_input / ((sigma**2 + 1) ** 0.5) + + # predict the noise residual + noise_pred = self.unet( + sample=latent_model_input, timestep=np.array([t]), encoder_hidden_states=text_embeddings + ) + noise_pred = noise_pred[0] + + # perform guidance + if do_classifier_free_guidance: + noise_pred_uncond, noise_pred_text = np.split(noise_pred, 2) + noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) + + # compute the previous noisy sample x_t -> x_t-1 + if isinstance(self.scheduler, LMSDiscreteScheduler): + latents = self.scheduler.step(noise_pred, i, latents, **extra_step_kwargs).prev_sample + else: + latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample + + # scale and decode the image latents with vae + latents = 1 / 0.18215 * latents + image = self.vae_decoder(latent_sample=latents)[0] + + image = np.clip(image / 2 + 0.5, 0, 1) + image = image.transpose((0, 2, 3, 1)) + + # run safety checker + safety_checker_input = self.feature_extractor(self.numpy_to_pil(image), return_tensors="np") + image, has_nsfw_concept = self.safety_checker(clip_input=safety_checker_input.pixel_values, images=image) + + if output_type == "pil": + image = self.numpy_to_pil(image) + + if not return_dict: + return (image, has_nsfw_concept) + + return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept) diff --git a/local_diffusers/pipelines/stable_diffusion/safety_checker.py b/local_diffusers/pipelines/stable_diffusion/safety_checker.py new file mode 100644 index 000000000..3ebc05c91 --- /dev/null +++ b/local_diffusers/pipelines/stable_diffusion/safety_checker.py 
@@ -0,0 +1,106 @@ +import numpy as np +import torch +import torch.nn as nn + +from transformers import CLIPConfig, CLIPVisionModel, PreTrainedModel + +from ...utils import logging + + +logger = logging.get_logger(__name__) + + +def cosine_distance(image_embeds, text_embeds): + normalized_image_embeds = nn.functional.normalize(image_embeds) + normalized_text_embeds = nn.functional.normalize(text_embeds) + return torch.mm(normalized_image_embeds, normalized_text_embeds.t()) + + +class StableDiffusionSafetyChecker(PreTrainedModel): + config_class = CLIPConfig + + def __init__(self, config: CLIPConfig): + super().__init__(config) + + self.vision_model = CLIPVisionModel(config.vision_config) + self.visual_projection = nn.Linear(config.vision_config.hidden_size, config.projection_dim, bias=False) + + self.concept_embeds = nn.Parameter(torch.ones(17, config.projection_dim), requires_grad=False) + self.special_care_embeds = nn.Parameter(torch.ones(3, config.projection_dim), requires_grad=False) + + self.register_buffer("concept_embeds_weights", torch.ones(17)) + self.register_buffer("special_care_embeds_weights", torch.ones(3)) + + @torch.no_grad() + def forward(self, clip_input, images): + pooled_output = self.vision_model(clip_input)[1] # pooled_output + image_embeds = self.visual_projection(pooled_output) + + special_cos_dist = cosine_distance(image_embeds, self.special_care_embeds).cpu().numpy() + cos_dist = cosine_distance(image_embeds, self.concept_embeds).cpu().numpy() + + result = [] + batch_size = image_embeds.shape[0] + for i in range(batch_size): + result_img = {"special_scores": {}, "special_care": [], "concept_scores": {}, "bad_concepts": []} + + # increase this value to create a stronger `nfsw` filter + # at the cost of increasing the possibility of filtering benign images + adjustment = 0.0 + + for concet_idx in range(len(special_cos_dist[0])): + concept_cos = special_cos_dist[i][concet_idx] + concept_threshold = self.special_care_embeds_weights[concet_idx].item() + result_img["special_scores"][concet_idx] = round(concept_cos - concept_threshold + adjustment, 3) + if result_img["special_scores"][concet_idx] > 0: + result_img["special_care"].append({concet_idx, result_img["special_scores"][concet_idx]}) + adjustment = 0.01 + + for concet_idx in range(len(cos_dist[0])): + concept_cos = cos_dist[i][concet_idx] + concept_threshold = self.concept_embeds_weights[concet_idx].item() + result_img["concept_scores"][concet_idx] = round(concept_cos - concept_threshold + adjustment, 3) + if result_img["concept_scores"][concet_idx] > 0: + result_img["bad_concepts"].append(concet_idx) + + result.append(result_img) + + has_nsfw_concepts = [len(res["bad_concepts"]) > 0 for res in result] + + #for idx, has_nsfw_concept in enumerate(has_nsfw_concepts): + # if has_nsfw_concept: + # images[idx] = np.zeros(images[idx].shape) # black image +# +# if any(has_nsfw_concepts): +# logger.warning( +# "Potential NSFW content was detected in one or more images. A black image will be returned instead." +# " Try again with a different prompt and/or seed." 
+# ) + + return images, has_nsfw_concepts + + @torch.inference_mode() + def forward_onnx(self, clip_input: torch.FloatTensor, images: torch.FloatTensor): + pooled_output = self.vision_model(clip_input)[1] # pooled_output + image_embeds = self.visual_projection(pooled_output) + + special_cos_dist = cosine_distance(image_embeds, self.special_care_embeds) + cos_dist = cosine_distance(image_embeds, self.concept_embeds) + + # increase this value to create a stronger `nsfw` filter + # at the cost of increasing the possibility of filtering benign images + adjustment = 0.0 + + special_scores = special_cos_dist - self.special_care_embeds_weights + adjustment + # special_scores = special_scores.round(decimals=3) + special_care = torch.any(special_scores > 0, dim=1) + special_adjustment = special_care * 0.01 + special_adjustment = special_adjustment.unsqueeze(1).expand(-1, cos_dist.shape[1]) + + concept_scores = (cos_dist - self.concept_embeds_weights) + special_adjustment + # concept_scores = concept_scores.round(decimals=3) + has_nsfw_concepts = torch.any(concept_scores > 0, dim=1) + + images[has_nsfw_concepts] = 0.0 # black image + + return images, has_nsfw_concepts diff --git a/local_diffusers/pipelines/stochastic_karras_ve/__init__.py b/local_diffusers/pipelines/stochastic_karras_ve/__init__.py new file mode 100644 index 000000000..db2582043 --- /dev/null +++ b/local_diffusers/pipelines/stochastic_karras_ve/__init__.py @@ -0,0 +1,2 @@ +# flake8: noqa +from .pipeline_stochastic_karras_ve import KarrasVePipeline diff --git a/local_diffusers/pipelines/stochastic_karras_ve/pipeline_stochastic_karras_ve.py b/local_diffusers/pipelines/stochastic_karras_ve/pipeline_stochastic_karras_ve.py new file mode 100644 index 000000000..15266544d --- /dev/null +++ b/local_diffusers/pipelines/stochastic_karras_ve/pipeline_stochastic_karras_ve.py @@ -0,0 +1,129 @@ +#!/usr/bin/env python3 +import warnings +from typing import Optional, Tuple, Union + +import torch + +from ...models import UNet2DModel +from ...pipeline_utils import DiffusionPipeline, ImagePipelineOutput +from ...schedulers import KarrasVeScheduler + + +class KarrasVePipeline(DiffusionPipeline): + r""" + Stochastic sampling from Karras et al. [1] tailored to the Variance-Expanding (VE) models [2]. Use Algorithm 2 and + the VE column of Table 1 from [1] for reference. + + [1] Karras, Tero, et al. "Elucidating the Design Space of Diffusion-Based Generative Models." + https://arxiv.org/abs/2206.00364 [2] Song, Yang, et al. "Score-based generative modeling through stochastic + differential equations." https://arxiv.org/abs/2011.13456 + + Parameters: + unet ([`UNet2DModel`]): U-Net architecture to denoise the encoded image. + scheduler ([`KarrasVeScheduler`]): + Scheduler for the diffusion process to be used in combination with `unet` to denoise the encoded image. + """ + + # add type hints for linting + unet: UNet2DModel + scheduler: KarrasVeScheduler + + def __init__(self, unet: UNet2DModel, scheduler: KarrasVeScheduler): + super().__init__() + scheduler = scheduler.set_format("pt") + self.register_modules(unet=unet, scheduler=scheduler) + + @torch.no_grad() + def __call__( + self, + batch_size: int = 1, + num_inference_steps: int = 50, + generator: Optional[torch.Generator] = None, + output_type: Optional[str] = "pil", + return_dict: bool = True, + **kwargs, + ) -> Union[Tuple, ImagePipelineOutput]: + r""" + Args: + batch_size (`int`, *optional*, defaults to 1): + The number of images to generate. 
+ generator (`torch.Generator`, *optional*): + A [torch generator](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make generation + deterministic. + num_inference_steps (`int`, *optional*, defaults to 50): + The number of denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. + output_type (`str`, *optional*, defaults to `"pil"`): + The output format of the generate image. Choose between + [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `nd.array`. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~pipeline_utils.ImagePipelineOutput`] instead of a plain tuple. + + Returns: + [`~pipeline_utils.ImagePipelineOutput`] or `tuple`: [`~pipelines.utils.ImagePipelineOutput`] if + `return_dict` is True, otherwise a `tuple. When returning a tuple, the first element is a list with the + generated images. + """ + if "torch_device" in kwargs: + device = kwargs.pop("torch_device") + warnings.warn( + "`torch_device` is deprecated as an input argument to `__call__` and will be removed in v0.3.0." + " Consider using `pipe.to(torch_device)` instead." + ) + + # Set device as before (to be removed in 0.3.0) + if device is None: + device = "cuda" if torch.cuda.is_available() else "cpu" + self.to(device) + + img_size = self.unet.config.sample_size + shape = (batch_size, 3, img_size, img_size) + + model = self.unet + + # sample x_0 ~ N(0, sigma_0^2 * I) + sample = torch.randn(*shape) * self.scheduler.config.sigma_max + sample = sample.to(self.device) + + self.scheduler.set_timesteps(num_inference_steps) + + for t in self.progress_bar(self.scheduler.timesteps): + # here sigma_t == t_i from the paper + sigma = self.scheduler.schedule[t] + sigma_prev = self.scheduler.schedule[t - 1] if t > 0 else 0 + + # 1. Select temporarily increased noise level sigma_hat + # 2. Add new noise to move from sample_i to sample_hat + sample_hat, sigma_hat = self.scheduler.add_noise_to_input(sample, sigma, generator=generator) + + # 3. Predict the noise residual given the noise magnitude `sigma_hat` + # The model inputs and output are adjusted by following eq. (213) in [1]. + model_output = (sigma_hat / 2) * model((sample_hat + 1) / 2, sigma_hat / 2).sample + + # 4. Evaluate dx/dt at sigma_hat + # 5. Take Euler step from sigma to sigma_prev + step_output = self.scheduler.step(model_output, sigma_hat, sigma_prev, sample_hat) + + if sigma_prev != 0: + # 6. Apply 2nd order correction + # The model inputs and output are adjusted by following eq. (213) in [1]. + model_output = (sigma_prev / 2) * model((step_output.prev_sample + 1) / 2, sigma_prev / 2).sample + step_output = self.scheduler.step_correct( + model_output, + sigma_hat, + sigma_prev, + sample_hat, + step_output.prev_sample, + step_output["derivative"], + ) + sample = step_output.prev_sample + + sample = (sample / 2 + 0.5).clamp(0, 1) + image = sample.cpu().permute(0, 2, 3, 1).numpy() + if output_type == "pil": + image = self.numpy_to_pil(sample) + + if not return_dict: + return (image,) + + return ImagePipelineOutput(images=image) diff --git a/local_diffusers/schedulers/__init__.py b/local_diffusers/schedulers/__init__.py new file mode 100644 index 000000000..20c25f351 --- /dev/null +++ b/local_diffusers/schedulers/__init__.py @@ -0,0 +1,28 @@ +# Copyright 2022 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from ..utils import is_scipy_available +from .scheduling_ddim import DDIMScheduler +from .scheduling_ddpm import DDPMScheduler +from .scheduling_karras_ve import KarrasVeScheduler +from .scheduling_pndm import PNDMScheduler +from .scheduling_sde_ve import ScoreSdeVeScheduler +from .scheduling_sde_vp import ScoreSdeVpScheduler +from .scheduling_utils import SchedulerMixin + + +if is_scipy_available(): + from .scheduling_lms_discrete import LMSDiscreteScheduler +else: + from ..utils.dummy_scipy_objects import * # noqa F403 diff --git a/local_diffusers/schedulers/scheduling_ddim.py b/local_diffusers/schedulers/scheduling_ddim.py new file mode 100644 index 000000000..894d63bf2 --- /dev/null +++ b/local_diffusers/schedulers/scheduling_ddim.py @@ -0,0 +1,261 @@ +# Copyright 2022 Stanford University Team and The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# DISCLAIMER: This code is strongly influenced by https://github.com/pesser/pytorch_diffusion +# and https://github.com/hojonathanho/diffusion + +import math +from typing import Optional, Tuple, Union + +import numpy as np +import torch + +from ..configuration_utils import ConfigMixin, register_to_config +from .scheduling_utils import SchedulerMixin, SchedulerOutput + + +def betas_for_alpha_bar(num_diffusion_timesteps, max_beta=0.999): + """ + Create a beta schedule that discretizes the given alpha_t_bar function, which defines the cumulative product of + (1-beta) over time from t = [0,1]. + + Contains a function alpha_bar that takes an argument t and transforms it to the cumulative product of (1-beta) up + to that part of the diffusion process. + + + Args: + num_diffusion_timesteps (`int`): the number of betas to produce. + max_beta (`float`): the maximum beta to use; use values lower than 1 to + prevent singularities. + + Returns: + betas (`np.ndarray`): the betas used by the scheduler to step the model outputs + """ + + def alpha_bar(time_step): + return math.cos((time_step + 0.008) / 1.008 * math.pi / 2) ** 2 + + betas = [] + for i in range(num_diffusion_timesteps): + t1 = i / num_diffusion_timesteps + t2 = (i + 1) / num_diffusion_timesteps + betas.append(min(1 - alpha_bar(t2) / alpha_bar(t1), max_beta)) + return np.array(betas, dtype=np.float32) + + +class DDIMScheduler(SchedulerMixin, ConfigMixin): + """ + Denoising diffusion implicit models is a scheduler that extends the denoising procedure introduced in denoising + diffusion probabilistic models (DDPMs) with non-Markovian guidance. 
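+
+    A minimal usage sketch (illustrative only; the constructor arguments shown are the defaults declared
+    below, and `model_output`, `t` and `sample` are placeholders for real tensors):
+
+        >>> scheduler = DDIMScheduler(num_train_timesteps=1000, beta_schedule="linear")
+        >>> scheduler.set_timesteps(50)
+        >>> # scheduler.step(model_output, t, sample, eta=0.0).prev_sample is then x_{t-1}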
+
+    [`~ConfigMixin`] takes care of storing all config attributes that are passed in the scheduler's `__init__`
+    function, such as `num_train_timesteps`. They can be accessed via `scheduler.config.num_train_timesteps`.
+    [`~ConfigMixin`] also provides general loading and saving functionality via the [`~ConfigMixin.save_config`] and
+    [`~ConfigMixin.from_config`] functions.
+
+    For more details, see the original paper: https://arxiv.org/abs/2010.02502
+
+    Args:
+        num_train_timesteps (`int`): number of diffusion steps used to train the model.
+        beta_start (`float`): the starting `beta` value of inference.
+        beta_end (`float`): the final `beta` value.
+        beta_schedule (`str`):
+            the beta schedule, a mapping from a beta range to a sequence of betas for stepping the model. Choose from
+            `linear`, `scaled_linear`, or `squaredcos_cap_v2`.
+        trained_betas (`np.ndarray`, optional): TODO
+        timestep_values (`np.ndarray`, optional): TODO
+        clip_sample (`bool`, default `True`):
+            option to clip predicted sample between -1 and 1 for numerical stability.
+        set_alpha_to_one (`bool`, default `True`):
+            whether the alpha for the final step is 1 or the final alpha of the "non-previous" one.
+        tensor_format (`str`): whether the scheduler expects pytorch or numpy arrays.
+
+    """
+
+    @register_to_config
+    def __init__(
+        self,
+        num_train_timesteps: int = 1000,
+        beta_start: float = 0.0001,
+        beta_end: float = 0.02,
+        beta_schedule: str = "linear",
+        trained_betas: Optional[np.ndarray] = None,
+        timestep_values: Optional[np.ndarray] = None,
+        clip_sample: bool = True,
+        set_alpha_to_one: bool = True,
+        tensor_format: str = "pt",
+    ):
+        if trained_betas is not None:
+            self.betas = np.asarray(trained_betas)
+        elif beta_schedule == "linear":
+            self.betas = np.linspace(beta_start, beta_end, num_train_timesteps, dtype=np.float32)
+        elif beta_schedule == "scaled_linear":
+            # this schedule is very specific to the latent diffusion model.
+            self.betas = np.linspace(beta_start**0.5, beta_end**0.5, num_train_timesteps, dtype=np.float32) ** 2
+        elif beta_schedule == "squaredcos_cap_v2":
+            # Glide cosine schedule
+            self.betas = betas_for_alpha_bar(num_train_timesteps)
+        else:
+            raise NotImplementedError(f"{beta_schedule} is not implemented for {self.__class__}")
+
+        self.alphas = 1.0 - self.betas
+        self.alphas_cumprod = np.cumprod(self.alphas, axis=0)
+
+        # At every step in ddim, we are looking into the previous alphas_cumprod
+        # For the final step, there is no previous alphas_cumprod because we are already at 0
+        # `set_alpha_to_one` decides whether we set this parameter simply to one or
+        # whether we use the final alpha of the "non-previous" one.
+        self.final_alpha_cumprod = np.array(1.0) if set_alpha_to_one else self.alphas_cumprod[0]
+
+        # setable values
+        self.num_inference_steps = None
+        self.timesteps = np.arange(0, num_train_timesteps)[::-1].copy()
+
+        self.tensor_format = tensor_format
+        self.set_format(tensor_format=tensor_format)
+
+    def _get_variance(self, timestep, prev_timestep):
+        alpha_prod_t = self.alphas_cumprod[timestep]
+        alpha_prod_t_prev = self.alphas_cumprod[prev_timestep] if prev_timestep >= 0 else self.final_alpha_cumprod
+        beta_prod_t = 1 - alpha_prod_t
+        beta_prod_t_prev = 1 - alpha_prod_t_prev
+
+        variance = (beta_prod_t_prev / beta_prod_t) * (1 - alpha_prod_t / alpha_prod_t_prev)
+
+        return variance
+
+    def set_timesteps(self, num_inference_steps: int, offset: int = 0):
+        """
+        Sets the discrete timesteps used for the diffusion chain. Supporting function to be run before inference.
+ + Args: + num_inference_steps (`int`): + the number of diffusion steps used when generating samples with a pre-trained model. + offset (`int`): TODO + """ + self.num_inference_steps = num_inference_steps + self.timesteps = np.arange( + 0, self.config.num_train_timesteps, self.config.num_train_timesteps // self.num_inference_steps + )[::-1].copy() + self.timesteps += offset + self.set_format(tensor_format=self.tensor_format) + + def step( + self, + model_output: Union[torch.FloatTensor, np.ndarray], + timestep: int, + sample: Union[torch.FloatTensor, np.ndarray], + eta: float = 0.0, + use_clipped_model_output: bool = False, + generator=None, + return_dict: bool = True, + ) -> Union[SchedulerOutput, Tuple]: + """ + Predict the sample at the previous timestep by reversing the SDE. Core function to propagate the diffusion + process from the learned model outputs (most often the predicted noise). + + Args: + model_output (`torch.FloatTensor` or `np.ndarray`): direct output from learned diffusion model. + timestep (`int`): current discrete timestep in the diffusion chain. + sample (`torch.FloatTensor` or `np.ndarray`): + current instance of sample being created by diffusion process. + eta (`float`): weight of noise for added noise in diffusion step. + use_clipped_model_output (`bool`): TODO + generator: random number generator. + return_dict (`bool`): option for returning tuple rather than SchedulerOutput class + + Returns: + [`~schedulers.scheduling_utils.SchedulerOutput`] or `tuple`: + [`~schedulers.scheduling_utils.SchedulerOutput`] if `return_dict` is True, otherwise a `tuple`. When + returning a tuple, the first element is the sample tensor. + + """ + if self.num_inference_steps is None: + raise ValueError( + "Number of inference steps is 'None', you need to run 'set_timesteps' after creating the scheduler" + ) + + # See formulas (12) and (16) of DDIM paper https://arxiv.org/pdf/2010.02502.pdf + # Ideally, read DDIM paper in-detail understanding + + # Notation ( -> + # - pred_noise_t -> e_theta(x_t, t) + # - pred_original_sample -> f_theta(x_t, t) or x_0 + # - std_dev_t -> sigma_t + # - eta -> η + # - pred_sample_direction -> "direction pointingc to x_t" + # - pred_prev_sample -> "x_t-1" + + # 1. get previous step value (=t-1) + prev_timestep = timestep - self.config.num_train_timesteps // self.num_inference_steps + + # 2. compute alphas, betas + alpha_prod_t = self.alphas_cumprod[timestep] + alpha_prod_t_prev = self.alphas_cumprod[prev_timestep] if prev_timestep >= 0 else self.final_alpha_cumprod + beta_prod_t = 1 - alpha_prod_t + + # 3. compute predicted original sample from predicted noise also called + # "predicted x_0" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf + pred_original_sample = (sample - beta_prod_t ** (0.5) * model_output) / alpha_prod_t ** (0.5) + + # 4. Clip "predicted x_0" + if self.config.clip_sample: + pred_original_sample = self.clip(pred_original_sample, -1, 1) + + # 5. compute variance: "sigma_t(η)" -> see formula (16) + # σ_t = sqrt((1 − α_t−1)/(1 − α_t)) * sqrt(1 − α_t/α_t−1) + variance = self._get_variance(timestep, prev_timestep) + std_dev_t = eta * variance ** (0.5) + + if use_clipped_model_output: + # the model_output is always re-derived from the clipped x_0 in Glide + model_output = (sample - alpha_prod_t ** (0.5) * pred_original_sample) / beta_prod_t ** (0.5) + + # 6. 
compute "direction pointing to x_t" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf + pred_sample_direction = (1 - alpha_prod_t_prev - std_dev_t**2) ** (0.5) * model_output + + # 7. compute x_t without "random noise" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf + prev_sample = alpha_prod_t_prev ** (0.5) * pred_original_sample + pred_sample_direction + + if eta > 0: + device = model_output.device if torch.is_tensor(model_output) else "cpu" + noise = torch.randn(model_output.shape, generator=generator).to(device) + variance = self._get_variance(timestep, prev_timestep) ** (0.5) * eta * noise + + if not torch.is_tensor(model_output): + variance = variance.numpy() + + prev_sample = prev_sample + variance + + if not return_dict: + return (prev_sample,) + + return SchedulerOutput(prev_sample=prev_sample) + + def add_noise( + self, + original_samples: Union[torch.FloatTensor, np.ndarray], + noise: Union[torch.FloatTensor, np.ndarray], + timesteps: Union[torch.IntTensor, np.ndarray], + ) -> Union[torch.FloatTensor, np.ndarray]: + sqrt_alpha_prod = self.alphas_cumprod[timesteps] ** 0.5 + sqrt_alpha_prod = self.match_shape(sqrt_alpha_prod, original_samples) + sqrt_one_minus_alpha_prod = (1 - self.alphas_cumprod[timesteps]) ** 0.5 + sqrt_one_minus_alpha_prod = self.match_shape(sqrt_one_minus_alpha_prod, original_samples) + + noisy_samples = sqrt_alpha_prod * original_samples + sqrt_one_minus_alpha_prod * noise + return noisy_samples + + def __len__(self): + return self.config.num_train_timesteps diff --git a/local_diffusers/schedulers/scheduling_ddpm.py b/local_diffusers/schedulers/scheduling_ddpm.py new file mode 100644 index 000000000..4fbfb9038 --- /dev/null +++ b/local_diffusers/schedulers/scheduling_ddpm.py @@ -0,0 +1,264 @@ +# Copyright 2022 UC Berkely Team and The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# DISCLAIMER: This file is strongly influenced by https://github.com/ermongroup/ddim + +import math +from typing import Optional, Tuple, Union + +import numpy as np +import torch + +from ..configuration_utils import ConfigMixin, register_to_config +from .scheduling_utils import SchedulerMixin, SchedulerOutput + + +def betas_for_alpha_bar(num_diffusion_timesteps, max_beta=0.999): + """ + Create a beta schedule that discretizes the given alpha_t_bar function, which defines the cumulative product of + (1-beta) over time from t = [0,1]. + + Contains a function alpha_bar that takes an argument t and transforms it to the cumulative product of (1-beta) up + to that part of the diffusion process. + + + Args: + num_diffusion_timesteps (`int`): the number of betas to produce. + max_beta (`float`): the maximum beta to use; use values lower than 1 to + prevent singularities. 
+ + Returns: + betas (`np.ndarray`): the betas used by the scheduler to step the model outputs + """ + + def alpha_bar(time_step): + return math.cos((time_step + 0.008) / 1.008 * math.pi / 2) ** 2 + + betas = [] + for i in range(num_diffusion_timesteps): + t1 = i / num_diffusion_timesteps + t2 = (i + 1) / num_diffusion_timesteps + betas.append(min(1 - alpha_bar(t2) / alpha_bar(t1), max_beta)) + return np.array(betas, dtype=np.float32) + + +class DDPMScheduler(SchedulerMixin, ConfigMixin): + """ + Denoising diffusion probabilistic models (DDPMs) explores the connections between denoising score matching and + Langevin dynamics sampling. + + [`~ConfigMixin`] takes care of storing all config attributes that are passed in the scheduler's `__init__` + function, such as `num_train_timesteps`. They can be accessed via `scheduler.config.num_train_timesteps`. + [`~ConfigMixin`] also provides general loading and saving functionality via the [`~ConfigMixin.save_config`] and + [`~ConfigMixin.from_config`] functios. + + For more details, see the original paper: https://arxiv.org/abs/2006.11239 + + Args: + num_train_timesteps (`int`): number of diffusion steps used to train the model. + beta_start (`float`): the starting `beta` value of inference. + beta_end (`float`): the final `beta` value. + beta_schedule (`str`): + the beta schedule, a mapping from a beta range to a sequence of betas for stepping the model. Choose from + `linear`, `scaled_linear`, or `squaredcos_cap_v2`. + trained_betas (`np.ndarray`, optional): TODO + variance_type (`str`): + options to clip the variance used when adding noise to the denoised sample. Choose from `fixed_small`, + `fixed_small_log`, `fixed_large`, `fixed_large_log`, `learned` or `learned_range`. + clip_sample (`bool`, default `True`): + option to clip predicted sample between -1 and 1 for numerical stability. + tensor_format (`str`): whether the scheduler expects pytorch or numpy arrays. + + """ + + @register_to_config + def __init__( + self, + num_train_timesteps: int = 1000, + beta_start: float = 0.0001, + beta_end: float = 0.02, + beta_schedule: str = "linear", + trained_betas: Optional[np.ndarray] = None, + variance_type: str = "fixed_small", + clip_sample: bool = True, + tensor_format: str = "pt", + ): + + if trained_betas is not None: + self.betas = np.asarray(trained_betas) + elif beta_schedule == "linear": + self.betas = np.linspace(beta_start, beta_end, num_train_timesteps, dtype=np.float32) + elif beta_schedule == "scaled_linear": + # this schedule is very specific to the latent diffusion model. + self.betas = np.linspace(beta_start**0.5, beta_end**0.5, num_train_timesteps, dtype=np.float32) ** 2 + elif beta_schedule == "squaredcos_cap_v2": + # Glide cosine schedule + self.betas = betas_for_alpha_bar(num_train_timesteps) + else: + raise NotImplementedError(f"{beta_schedule} does is not implemented for {self.__class__}") + + self.alphas = 1.0 - self.betas + self.alphas_cumprod = np.cumprod(self.alphas, axis=0) + self.one = np.array(1.0) + + # setable values + self.num_inference_steps = None + self.timesteps = np.arange(0, num_train_timesteps)[::-1].copy() + + self.tensor_format = tensor_format + self.set_format(tensor_format=tensor_format) + + self.variance_type = variance_type + + def set_timesteps(self, num_inference_steps: int): + """ + Sets the discrete timesteps used for the diffusion chain. Supporting function to be run before inference. 
+ + Args: + num_inference_steps (`int`): + the number of diffusion steps used when generating samples with a pre-trained model. + """ + num_inference_steps = min(self.config.num_train_timesteps, num_inference_steps) + self.num_inference_steps = num_inference_steps + self.timesteps = np.arange( + 0, self.config.num_train_timesteps, self.config.num_train_timesteps // self.num_inference_steps + )[::-1].copy() + self.set_format(tensor_format=self.tensor_format) + + def _get_variance(self, t, predicted_variance=None, variance_type=None): + alpha_prod_t = self.alphas_cumprod[t] + alpha_prod_t_prev = self.alphas_cumprod[t - 1] if t > 0 else self.one + + # For t > 0, compute predicted variance βt (see formula (6) and (7) from https://arxiv.org/pdf/2006.11239.pdf) + # and sample from it to get previous sample + # x_{t-1} ~ N(pred_prev_sample, variance) == add variance to pred_sample + variance = (1 - alpha_prod_t_prev) / (1 - alpha_prod_t) * self.betas[t] + + if variance_type is None: + variance_type = self.config.variance_type + + # hacks - were probs added for training stability + if variance_type == "fixed_small": + variance = self.clip(variance, min_value=1e-20) + # for rl-diffuser https://arxiv.org/abs/2205.09991 + elif variance_type == "fixed_small_log": + variance = self.log(self.clip(variance, min_value=1e-20)) + elif variance_type == "fixed_large": + variance = self.betas[t] + elif variance_type == "fixed_large_log": + # Glide max_log + variance = self.log(self.betas[t]) + elif variance_type == "learned": + return predicted_variance + elif variance_type == "learned_range": + min_log = variance + max_log = self.betas[t] + frac = (predicted_variance + 1) / 2 + variance = frac * max_log + (1 - frac) * min_log + + return variance + + def step( + self, + model_output: Union[torch.FloatTensor, np.ndarray], + timestep: int, + sample: Union[torch.FloatTensor, np.ndarray], + predict_epsilon=True, + generator=None, + return_dict: bool = True, + ) -> Union[SchedulerOutput, Tuple]: + """ + Predict the sample at the previous timestep by reversing the SDE. Core function to propagate the diffusion + process from the learned model outputs (most often the predicted noise). + + Args: + model_output (`torch.FloatTensor` or `np.ndarray`): direct output from learned diffusion model. + timestep (`int`): current discrete timestep in the diffusion chain. + sample (`torch.FloatTensor` or `np.ndarray`): + current instance of sample being created by diffusion process. + eta (`float`): weight of noise for added noise in diffusion step. + predict_epsilon (`bool`): + optional flag to use when model predicts the samples directly instead of the noise, epsilon. + generator: random number generator. + return_dict (`bool`): option for returning tuple rather than SchedulerOutput class + + Returns: + [`~schedulers.scheduling_utils.SchedulerOutput`] or `tuple`: + [`~schedulers.scheduling_utils.SchedulerOutput`] if `return_dict` is True, otherwise a `tuple`. When + returning a tuple, the first element is the sample tensor. + + """ + t = timestep + + if model_output.shape[1] == sample.shape[1] * 2 and self.variance_type in ["learned", "learned_range"]: + model_output, predicted_variance = torch.split(model_output, sample.shape[1], dim=1) + else: + predicted_variance = None + + # 1. compute alphas, betas + alpha_prod_t = self.alphas_cumprod[t] + alpha_prod_t_prev = self.alphas_cumprod[t - 1] if t > 0 else self.one + beta_prod_t = 1 - alpha_prod_t + beta_prod_t_prev = 1 - alpha_prod_t_prev + + # 2. 
compute predicted original sample from predicted noise also called + # "predicted x_0" of formula (15) from https://arxiv.org/pdf/2006.11239.pdf + if predict_epsilon: + pred_original_sample = (sample - beta_prod_t ** (0.5) * model_output) / alpha_prod_t ** (0.5) + else: + pred_original_sample = model_output + + # 3. Clip "predicted x_0" + if self.config.clip_sample: + pred_original_sample = self.clip(pred_original_sample, -1, 1) + + # 4. Compute coefficients for pred_original_sample x_0 and current sample x_t + # See formula (7) from https://arxiv.org/pdf/2006.11239.pdf + pred_original_sample_coeff = (alpha_prod_t_prev ** (0.5) * self.betas[t]) / beta_prod_t + current_sample_coeff = self.alphas[t] ** (0.5) * beta_prod_t_prev / beta_prod_t + + # 5. Compute predicted previous sample µ_t + # See formula (7) from https://arxiv.org/pdf/2006.11239.pdf + pred_prev_sample = pred_original_sample_coeff * pred_original_sample + current_sample_coeff * sample + + # 6. Add noise + variance = 0 + if t > 0: + noise = self.randn_like(model_output, generator=generator) + variance = (self._get_variance(t, predicted_variance=predicted_variance) ** 0.5) * noise + + pred_prev_sample = pred_prev_sample + variance + + if not return_dict: + return (pred_prev_sample,) + + return SchedulerOutput(prev_sample=pred_prev_sample) + + def add_noise( + self, + original_samples: Union[torch.FloatTensor, np.ndarray], + noise: Union[torch.FloatTensor, np.ndarray], + timesteps: Union[torch.IntTensor, np.ndarray], + ) -> Union[torch.FloatTensor, np.ndarray]: + + sqrt_alpha_prod = self.alphas_cumprod[timesteps] ** 0.5 + sqrt_alpha_prod = self.match_shape(sqrt_alpha_prod, original_samples) + sqrt_one_minus_alpha_prod = (1 - self.alphas_cumprod[timesteps]) ** 0.5 + sqrt_one_minus_alpha_prod = self.match_shape(sqrt_one_minus_alpha_prod, original_samples) + + noisy_samples = sqrt_alpha_prod * original_samples + sqrt_one_minus_alpha_prod * noise + return noisy_samples + + def __len__(self): + return self.config.num_train_timesteps diff --git a/local_diffusers/schedulers/scheduling_karras_ve.py b/local_diffusers/schedulers/scheduling_karras_ve.py new file mode 100644 index 000000000..3a2370cfc --- /dev/null +++ b/local_diffusers/schedulers/scheduling_karras_ve.py @@ -0,0 +1,208 @@ +# Copyright 2022 NVIDIA and The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +from dataclasses import dataclass +from typing import Optional, Tuple, Union + +import numpy as np +import torch + +from ..configuration_utils import ConfigMixin, register_to_config +from ..utils import BaseOutput +from .scheduling_utils import SchedulerMixin + + +@dataclass +class KarrasVeOutput(BaseOutput): + """ + Output class for the scheduler's step function output. + + Args: + prev_sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` for images): + Computed sample (x_{t-1}) of previous timestep. `prev_sample` should be used as next model input in the + denoising loop. 
+        derivative (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` for images):
+            Derivative of predicted original image sample (x_0).
+    """
+
+    prev_sample: torch.FloatTensor
+    derivative: torch.FloatTensor
+
+
+class KarrasVeScheduler(SchedulerMixin, ConfigMixin):
+    """
+    Stochastic sampling from Karras et al. [1] tailored to the Variance-Expanding (VE) models [2]. Use Algorithm 2 and
+    the VE column of Table 1 from [1] for reference.
+
+    [1] Karras, Tero, et al. "Elucidating the Design Space of Diffusion-Based Generative Models."
+    https://arxiv.org/abs/2206.00364 [2] Song, Yang, et al. "Score-based generative modeling through stochastic
+    differential equations." https://arxiv.org/abs/2011.13456
+
+    [`~ConfigMixin`] takes care of storing all config attributes that are passed in the scheduler's `__init__`
+    function, such as `num_train_timesteps`. They can be accessed via `scheduler.config.num_train_timesteps`.
+    [`~ConfigMixin`] also provides general loading and saving functionality via the [`~ConfigMixin.save_config`] and
+    [`~ConfigMixin.from_config`] functions.
+
+    For more details on the parameters, see the original paper's Appendix E.: "Elucidating the Design Space of
+    Diffusion-Based Generative Models." https://arxiv.org/abs/2206.00364. The grid search values used to find the
+    optimal {s_noise, s_churn, s_min, s_max} for a specific model are described in Table 5 of the paper.
+
+    Args:
+        sigma_min (`float`): minimum noise magnitude
+        sigma_max (`float`): maximum noise magnitude
+        s_noise (`float`): the amount of additional noise to counteract loss of detail during sampling.
+            A reasonable range is [1.000, 1.011].
+        s_churn (`float`): the parameter controlling the overall amount of stochasticity.
+            A reasonable range is [0, 100].
+        s_min (`float`): the start value of the sigma range where we add noise (enable stochasticity).
+            A reasonable range is [0, 10].
+        s_max (`float`): the end value of the sigma range where we add noise.
+            A reasonable range is [0.2, 80].
+        tensor_format (`str`): whether the scheduler expects pytorch or numpy arrays.
+
+    """
+
+    @register_to_config
+    def __init__(
+        self,
+        sigma_min: float = 0.02,
+        sigma_max: float = 100,
+        s_noise: float = 1.007,
+        s_churn: float = 80,
+        s_min: float = 0.05,
+        s_max: float = 50,
+        tensor_format: str = "pt",
+    ):
+        # setable values
+        self.num_inference_steps = None
+        self.timesteps = None
+        self.schedule = None  # sigma(t_i)
+
+        self.tensor_format = tensor_format
+        self.set_format(tensor_format=tensor_format)
+
+    def set_timesteps(self, num_inference_steps: int):
+        """
+        Sets the continuous timesteps used for the diffusion chain. Supporting function to be run before inference.
+
+        Args:
+            num_inference_steps (`int`):
+                the number of diffusion steps used when generating samples with a pre-trained model.
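+
+        A minimal sketch of the effect (illustrative only; the concrete sigma values depend on the
+        configured `sigma_min` / `sigma_max`):
+
+            >>> scheduler.set_timesteps(num_inference_steps=50)
+            >>> len(scheduler.schedule)  # one sigma(t_i) per inference step
+            50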
+ + """ + self.num_inference_steps = num_inference_steps + self.timesteps = np.arange(0, self.num_inference_steps)[::-1].copy() + self.schedule = [ + (self.sigma_max * (self.sigma_min**2 / self.sigma_max**2) ** (i / (num_inference_steps - 1))) + for i in self.timesteps + ] + self.schedule = np.array(self.schedule, dtype=np.float32) + + self.set_format(tensor_format=self.tensor_format) + + def add_noise_to_input( + self, sample: Union[torch.FloatTensor, np.ndarray], sigma: float, generator: Optional[torch.Generator] = None + ) -> Tuple[Union[torch.FloatTensor, np.ndarray], float]: + """ + Explicit Langevin-like "churn" step of adding noise to the sample according to a factor gamma_i ≥ 0 to reach a + higher noise level sigma_hat = sigma_i + gamma_i*sigma_i. + + TODO Args: + """ + if self.s_min <= sigma <= self.s_max: + gamma = min(self.s_churn / self.num_inference_steps, 2**0.5 - 1) + else: + gamma = 0 + + # sample eps ~ N(0, S_noise^2 * I) + eps = self.s_noise * torch.randn(sample.shape, generator=generator).to(sample.device) + sigma_hat = sigma + gamma * sigma + sample_hat = sample + ((sigma_hat**2 - sigma**2) ** 0.5 * eps) + + return sample_hat, sigma_hat + + def step( + self, + model_output: Union[torch.FloatTensor, np.ndarray], + sigma_hat: float, + sigma_prev: float, + sample_hat: Union[torch.FloatTensor, np.ndarray], + return_dict: bool = True, + ) -> Union[KarrasVeOutput, Tuple]: + """ + Predict the sample at the previous timestep by reversing the SDE. Core function to propagate the diffusion + process from the learned model outputs (most often the predicted noise). + + Args: + model_output (`torch.FloatTensor` or `np.ndarray`): direct output from learned diffusion model. + sigma_hat (`float`): TODO + sigma_prev (`float`): TODO + sample_hat (`torch.FloatTensor` or `np.ndarray`): TODO + return_dict (`bool`): option for returning tuple rather than SchedulerOutput class + + KarrasVeOutput: updated sample in the diffusion chain and derivative (TODO double check). + Returns: + [`~schedulers.scheduling_karras_ve.KarrasVeOutput`] or `tuple`: + [`~schedulers.scheduling_karras_ve.KarrasVeOutput`] if `return_dict` is True, otherwise a `tuple`. When + returning a tuple, the first element is the sample tensor. + + """ + + pred_original_sample = sample_hat + sigma_hat * model_output + derivative = (sample_hat - pred_original_sample) / sigma_hat + sample_prev = sample_hat + (sigma_prev - sigma_hat) * derivative + + if not return_dict: + return (sample_prev, derivative) + + return KarrasVeOutput(prev_sample=sample_prev, derivative=derivative) + + def step_correct( + self, + model_output: Union[torch.FloatTensor, np.ndarray], + sigma_hat: float, + sigma_prev: float, + sample_hat: Union[torch.FloatTensor, np.ndarray], + sample_prev: Union[torch.FloatTensor, np.ndarray], + derivative: Union[torch.FloatTensor, np.ndarray], + return_dict: bool = True, + ) -> Union[KarrasVeOutput, Tuple]: + """ + Correct the predicted sample based on the output model_output of the network. TODO complete description + + Args: + model_output (`torch.FloatTensor` or `np.ndarray`): direct output from learned diffusion model. + sigma_hat (`float`): TODO + sigma_prev (`float`): TODO + sample_hat (`torch.FloatTensor` or `np.ndarray`): TODO + sample_prev (`torch.FloatTensor` or `np.ndarray`): TODO + derivative (`torch.FloatTensor` or `np.ndarray`): TODO + return_dict (`bool`): option for returning tuple rather than SchedulerOutput class + + Returns: + prev_sample (TODO): updated sample in the diffusion chain. 
derivative (TODO): TODO
+
+        """
+        pred_original_sample = sample_prev + sigma_prev * model_output
+        derivative_corr = (sample_prev - pred_original_sample) / sigma_prev
+        sample_prev = sample_hat + (sigma_prev - sigma_hat) * (0.5 * derivative + 0.5 * derivative_corr)
+
+        if not return_dict:
+            return (sample_prev, derivative)
+
+        return KarrasVeOutput(prev_sample=sample_prev, derivative=derivative)
+
+    def add_noise(self, original_samples, noise, timesteps):
+        raise NotImplementedError()
diff --git a/local_diffusers/schedulers/scheduling_lms_discrete.py b/local_diffusers/schedulers/scheduling_lms_discrete.py
new file mode 100644
index 000000000..1381587fe
--- /dev/null
+++ b/local_diffusers/schedulers/scheduling_lms_discrete.py
@@ -0,0 +1,193 @@
+# Copyright 2022 Katherine Crowson and The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from typing import Optional, Tuple, Union
+
+import numpy as np
+import torch
+
+from scipy import integrate
+
+from ..configuration_utils import ConfigMixin, register_to_config
+from .scheduling_utils import SchedulerMixin, SchedulerOutput
+
+
+class LMSDiscreteScheduler(SchedulerMixin, ConfigMixin):
+    """
+    Linear Multistep Scheduler for discrete beta schedules. Based on the original k-diffusion implementation by
+    Katherine Crowson:
+    https://github.com/crowsonkb/k-diffusion/blob/481677d114f6ea445aa009cf5bd7a9cdee909e47/k_diffusion/sampling.py#L181
+
+    [`~ConfigMixin`] takes care of storing all config attributes that are passed in the scheduler's `__init__`
+    function, such as `num_train_timesteps`. They can be accessed via `scheduler.config.num_train_timesteps`.
+    [`~ConfigMixin`] also provides general loading and saving functionality via the [`~ConfigMixin.save_config`] and
+    [`~ConfigMixin.from_config`] functions.
+
+    Args:
+        num_train_timesteps (`int`): number of diffusion steps used to train the model.
+        beta_start (`float`): the starting `beta` value of inference.
+        beta_end (`float`): the final `beta` value.
+        beta_schedule (`str`):
+            the beta schedule, a mapping from a beta range to a sequence of betas for stepping the model. Choose from
+            `linear` or `scaled_linear`.
+        trained_betas (`np.ndarray`, optional):
+            an array of betas to use directly instead of computing them from `beta_start`/`beta_end`.
+        timestep_values (`np.ndarray`, optional): TODO
+        tensor_format (`str`): whether the scheduler expects pytorch or numpy arrays.
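+
+    A minimal usage sketch (illustrative only; `unet` and `latents` are assumed to be provided by the surrounding
+    pipeline, and the config values are just examples):
+
+        scheduler = LMSDiscreteScheduler(beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear")
+        scheduler.set_timesteps(50)
+        for i, t in enumerate(scheduler.timesteps):
+            noise_pred = unet(latents, t)  # assumed noise-prediction model
+            latents = scheduler.step(noise_pred, i, latents).prev_sample  # `i` indexes `self.sigmas` here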
+ + """ + + @register_to_config + def __init__( + self, + num_train_timesteps: int = 1000, + beta_start: float = 0.0001, + beta_end: float = 0.02, + beta_schedule: str = "linear", + trained_betas: Optional[np.ndarray] = None, + timestep_values: Optional[np.ndarray] = None, + tensor_format: str = "pt", + ): + if trained_betas is not None: + self.betas = np.asarray(trained_betas) + if beta_schedule == "linear": + self.betas = np.linspace(beta_start, beta_end, num_train_timesteps, dtype=np.float32) + elif beta_schedule == "scaled_linear": + # this schedule is very specific to the latent diffusion model. + self.betas = np.linspace(beta_start**0.5, beta_end**0.5, num_train_timesteps, dtype=np.float32) ** 2 + else: + raise NotImplementedError(f"{beta_schedule} does is not implemented for {self.__class__}") + + self.alphas = 1.0 - self.betas + self.alphas_cumprod = np.cumprod(self.alphas, axis=0) + + self.sigmas = ((1 - self.alphas_cumprod) / self.alphas_cumprod) ** 0.5 + + # setable values + self.num_inference_steps = None + self.timesteps = np.arange(0, num_train_timesteps)[::-1].copy() + self.derivatives = [] + + self.tensor_format = tensor_format + self.set_format(tensor_format=tensor_format) + + def get_lms_coefficient(self, order, t, current_order): + """ + Compute a linear multistep coefficient. + + Args: + order (TODO): + t (TODO): + current_order (TODO): + """ + + def lms_derivative(tau): + prod = 1.0 + for k in range(order): + if current_order == k: + continue + prod *= (tau - self.sigmas[t - k]) / (self.sigmas[t - current_order] - self.sigmas[t - k]) + return prod + + integrated_coeff = integrate.quad(lms_derivative, self.sigmas[t], self.sigmas[t + 1], epsrel=1e-4)[0] + + return integrated_coeff + + def set_timesteps(self, num_inference_steps: int): + """ + Sets the timesteps used for the diffusion chain. Supporting function to be run before inference. + + Args: + num_inference_steps (`int`): + the number of diffusion steps used when generating samples with a pre-trained model. + """ + self.num_inference_steps = num_inference_steps + self.timesteps = np.linspace(self.num_train_timesteps - 1, 0, num_inference_steps, dtype=float) + + low_idx = np.floor(self.timesteps).astype(int) + high_idx = np.ceil(self.timesteps).astype(int) + frac = np.mod(self.timesteps, 1.0) + sigmas = np.array(((1 - self.alphas_cumprod) / self.alphas_cumprod) ** 0.5) + sigmas = (1 - frac) * sigmas[low_idx] + frac * sigmas[high_idx] + self.sigmas = np.concatenate([sigmas, [0.0]]) + + self.derivatives = [] + + self.set_format(tensor_format=self.tensor_format) + + def step( + self, + model_output: Union[torch.FloatTensor, np.ndarray], + timestep: int, + sample: Union[torch.FloatTensor, np.ndarray], + order: int = 4, + return_dict: bool = True, + ) -> Union[SchedulerOutput, Tuple]: + """ + Predict the sample at the previous timestep by reversing the SDE. Core function to propagate the diffusion + process from the learned model outputs (most often the predicted noise). + + Args: + model_output (`torch.FloatTensor` or `np.ndarray`): direct output from learned diffusion model. + timestep (`int`): current discrete timestep in the diffusion chain. + sample (`torch.FloatTensor` or `np.ndarray`): + current instance of sample being created by diffusion process. + order: coefficient for multi-step inference. 
+ return_dict (`bool`): option for returning tuple rather than SchedulerOutput class + + Returns: + [`~schedulers.scheduling_utils.SchedulerOutput`] or `tuple`: + [`~schedulers.scheduling_utils.SchedulerOutput`] if `return_dict` is True, otherwise a `tuple`. When + returning a tuple, the first element is the sample tensor. + + """ + sigma = self.sigmas[timestep] + + # 1. compute predicted original sample (x_0) from sigma-scaled predicted noise + pred_original_sample = sample - sigma * model_output + + # 2. Convert to an ODE derivative + derivative = (sample - pred_original_sample) / sigma + self.derivatives.append(derivative) + if len(self.derivatives) > order: + self.derivatives.pop(0) + + # 3. Compute linear multistep coefficients + order = min(timestep + 1, order) + lms_coeffs = [self.get_lms_coefficient(order, timestep, curr_order) for curr_order in range(order)] + + # 4. Compute previous sample based on the derivatives path + prev_sample = sample + sum( + coeff * derivative for coeff, derivative in zip(lms_coeffs, reversed(self.derivatives)) + ) + + if not return_dict: + return (prev_sample,) + + return SchedulerOutput(prev_sample=prev_sample) + + def add_noise( + self, + original_samples: Union[torch.FloatTensor, np.ndarray], + noise: Union[torch.FloatTensor, np.ndarray], + timesteps: Union[torch.IntTensor, np.ndarray], + ) -> Union[torch.FloatTensor, np.ndarray]: + sigmas = self.match_shape(self.sigmas[timesteps], noise) + noisy_samples = original_samples + noise * sigmas + + return noisy_samples + + def __len__(self): + return self.config.num_train_timesteps diff --git a/local_diffusers/schedulers/scheduling_pndm.py b/local_diffusers/schedulers/scheduling_pndm.py new file mode 100644 index 000000000..b43d88bba --- /dev/null +++ b/local_diffusers/schedulers/scheduling_pndm.py @@ -0,0 +1,378 @@ +# Copyright 2022 Zhejiang University Team and The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# DISCLAIMER: This file is strongly influenced by https://github.com/ermongroup/ddim + +import math +from typing import Optional, Tuple, Union + +import numpy as np +import torch + +from ..configuration_utils import ConfigMixin, register_to_config +from .scheduling_utils import SchedulerMixin, SchedulerOutput + + +def betas_for_alpha_bar(num_diffusion_timesteps, max_beta=0.999): + """ + Create a beta schedule that discretizes the given alpha_t_bar function, which defines the cumulative product of + (1-beta) over time from t = [0,1]. + + Contains a function alpha_bar that takes an argument t and transforms it to the cumulative product of (1-beta) up + to that part of the diffusion process. + + + Args: + num_diffusion_timesteps (`int`): the number of betas to produce. + max_beta (`float`): the maximum beta to use; use values lower than 1 to + prevent singularities. 
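+
+    The default `alpha_bar` defined below is the squared-cosine schedule,
+    alpha_bar(t) = cos((t + 0.008) / 1.008 * pi / 2) ** 2.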
+
+    Returns:
+        betas (`np.ndarray`): the betas used by the scheduler to step the model outputs
+    """
+
+    def alpha_bar(time_step):
+        return math.cos((time_step + 0.008) / 1.008 * math.pi / 2) ** 2
+
+    betas = []
+    for i in range(num_diffusion_timesteps):
+        t1 = i / num_diffusion_timesteps
+        t2 = (i + 1) / num_diffusion_timesteps
+        betas.append(min(1 - alpha_bar(t2) / alpha_bar(t1), max_beta))
+    return np.array(betas, dtype=np.float32)
+
+
+class PNDMScheduler(SchedulerMixin, ConfigMixin):
+    """
+    Pseudo numerical methods for diffusion models (PNDM) propose using more advanced ODE integration techniques,
+    namely the Runge-Kutta method and a linear multi-step method.
+
+    [`~ConfigMixin`] takes care of storing all config attributes that are passed in the scheduler's `__init__`
+    function, such as `num_train_timesteps`. They can be accessed via `scheduler.config.num_train_timesteps`.
+    [`~ConfigMixin`] also provides general loading and saving functionality via the [`~ConfigMixin.save_config`] and
+    [`~ConfigMixin.from_config`] functions.
+
+    For more details, see the original paper: https://arxiv.org/abs/2202.09778
+
+    Args:
+        num_train_timesteps (`int`): number of diffusion steps used to train the model.
+        beta_start (`float`): the starting `beta` value of inference.
+        beta_end (`float`): the final `beta` value.
+        beta_schedule (`str`):
+            the beta schedule, a mapping from a beta range to a sequence of betas for stepping the model. Choose from
+            `linear`, `scaled_linear`, or `squaredcos_cap_v2`.
+        trained_betas (`np.ndarray`, optional):
+            an array of betas to use directly instead of computing them from `beta_start`/`beta_end`.
+        tensor_format (`str`): whether the scheduler expects pytorch or numpy arrays.
+        skip_prk_steps (`bool`):
+            allows the scheduler to skip the Runge-Kutta steps that are defined in the original paper as being required
+            before plms steps; defaults to `False`.
+
+    """
+
+    @register_to_config
+    def __init__(
+        self,
+        num_train_timesteps: int = 1000,
+        beta_start: float = 0.0001,
+        beta_end: float = 0.02,
+        beta_schedule: str = "linear",
+        trained_betas: Optional[np.ndarray] = None,
+        tensor_format: str = "pt",
+        skip_prk_steps: bool = False,
+    ):
+        if trained_betas is not None:
+            self.betas = np.asarray(trained_betas)
+        if beta_schedule == "linear":
+            self.betas = np.linspace(beta_start, beta_end, num_train_timesteps, dtype=np.float32)
+        elif beta_schedule == "scaled_linear":
+            # this schedule is very specific to the latent diffusion model.
+            self.betas = np.linspace(beta_start**0.5, beta_end**0.5, num_train_timesteps, dtype=np.float32) ** 2
+        elif beta_schedule == "squaredcos_cap_v2":
+            # Glide cosine schedule
+            self.betas = betas_for_alpha_bar(num_train_timesteps)
+        else:
+            raise NotImplementedError(f"{beta_schedule} is not implemented for {self.__class__}")
+
+        self.alphas = 1.0 - self.betas
+        self.alphas_cumprod = np.cumprod(self.alphas, axis=0)
+
+        self.one = np.array(1.0)
+
+        # For now we only support F-PNDM, i.e. the runge-kutta method
+        # For more information on the algorithm please take a look at the paper: https://arxiv.org/pdf/2202.09778.pdf
+        # mainly at formula (9), (12), (13) and the Algorithm 2.
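+        # The linear multistep branch (`step_plms`) keeps the last `pndm_order` noise predictions in `self.ets`
+        # and combines them with Adams-Bashforth-style coefficients, e.g. (55, -59, 37, -9) / 24 at full order.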
+ self.pndm_order = 4 + + # running values + self.cur_model_output = 0 + self.counter = 0 + self.cur_sample = None + self.ets = [] + + # setable values + self.num_inference_steps = None + self._timesteps = np.arange(0, num_train_timesteps)[::-1].copy() + self._offset = 0 + self.prk_timesteps = None + self.plms_timesteps = None + self.timesteps = None + + self.tensor_format = tensor_format + self.set_format(tensor_format=tensor_format) + + def set_timesteps(self, num_inference_steps: int, offset: int = 0) -> torch.FloatTensor: + """ + Sets the discrete timesteps used for the diffusion chain. Supporting function to be run before inference. + + Args: + num_inference_steps (`int`): + the number of diffusion steps used when generating samples with a pre-trained model. + offset (`int`): TODO + """ + self.num_inference_steps = num_inference_steps + self._timesteps = list( + range(0, self.config.num_train_timesteps, self.config.num_train_timesteps // num_inference_steps) + ) + self._offset = offset + self._timesteps = np.array([t + self._offset for t in self._timesteps]) + + if self.config.skip_prk_steps: + # for some models like stable diffusion the prk steps can/should be skipped to + # produce better results. When using PNDM with `self.config.skip_prk_steps` the implementation + # is based on crowsonkb's PLMS sampler implementation: https://github.com/CompVis/latent-diffusion/pull/51 + self.prk_timesteps = np.array([]) + self.plms_timesteps = np.concatenate([self._timesteps[:-1], self._timesteps[-2:-1], self._timesteps[-1:]])[ + ::-1 + ].copy() + else: + prk_timesteps = np.array(self._timesteps[-self.pndm_order :]).repeat(2) + np.tile( + np.array([0, self.config.num_train_timesteps // num_inference_steps // 2]), self.pndm_order + ) + self.prk_timesteps = (prk_timesteps[:-1].repeat(2)[1:-1])[::-1].copy() + self.plms_timesteps = self._timesteps[:-3][ + ::-1 + ].copy() # we copy to avoid having negative strides which are not supported by torch.from_numpy + + self.timesteps = np.concatenate([self.prk_timesteps, self.plms_timesteps]).astype(np.int64) + + self.ets = [] + self.counter = 0 + self.set_format(tensor_format=self.tensor_format) + + def step( + self, + model_output: Union[torch.FloatTensor, np.ndarray], + timestep: int, + sample: Union[torch.FloatTensor, np.ndarray], + return_dict: bool = True, + ) -> Union[SchedulerOutput, Tuple]: + """ + Predict the sample at the previous timestep by reversing the SDE. Core function to propagate the diffusion + process from the learned model outputs (most often the predicted noise). + + This function calls `step_prk()` or `step_plms()` depending on the internal variable `counter`. + + Args: + model_output (`torch.FloatTensor` or `np.ndarray`): direct output from learned diffusion model. + timestep (`int`): current discrete timestep in the diffusion chain. + sample (`torch.FloatTensor` or `np.ndarray`): + current instance of sample being created by diffusion process. + return_dict (`bool`): option for returning tuple rather than SchedulerOutput class + + Returns: + [`~schedulers.scheduling_utils.SchedulerOutput`] or `tuple`: + [`~schedulers.scheduling_utils.SchedulerOutput`] if `return_dict` is True, otherwise a `tuple`. When + returning a tuple, the first element is the sample tensor. 
+ + """ + if self.counter < len(self.prk_timesteps) and not self.config.skip_prk_steps: + return self.step_prk(model_output=model_output, timestep=timestep, sample=sample, return_dict=return_dict) + else: + return self.step_plms(model_output=model_output, timestep=timestep, sample=sample, return_dict=return_dict) + + def step_prk( + self, + model_output: Union[torch.FloatTensor, np.ndarray], + timestep: int, + sample: Union[torch.FloatTensor, np.ndarray], + return_dict: bool = True, + ) -> Union[SchedulerOutput, Tuple]: + """ + Step function propagating the sample with the Runge-Kutta method. RK takes 4 forward passes to approximate the + solution to the differential equation. + + Args: + model_output (`torch.FloatTensor` or `np.ndarray`): direct output from learned diffusion model. + timestep (`int`): current discrete timestep in the diffusion chain. + sample (`torch.FloatTensor` or `np.ndarray`): + current instance of sample being created by diffusion process. + return_dict (`bool`): option for returning tuple rather than SchedulerOutput class + + Returns: + [`~scheduling_utils.SchedulerOutput`] or `tuple`: [`~scheduling_utils.SchedulerOutput`] if `return_dict` is + True, otherwise a `tuple`. When returning a tuple, the first element is the sample tensor. + + """ + if self.num_inference_steps is None: + raise ValueError( + "Number of inference steps is 'None', you need to run 'set_timesteps' after creating the scheduler" + ) + + diff_to_prev = 0 if self.counter % 2 else self.config.num_train_timesteps // self.num_inference_steps // 2 + prev_timestep = max(timestep - diff_to_prev, self.prk_timesteps[-1]) + timestep = self.prk_timesteps[self.counter // 4 * 4] + + if self.counter % 4 == 0: + self.cur_model_output += 1 / 6 * model_output + self.ets.append(model_output) + self.cur_sample = sample + elif (self.counter - 1) % 4 == 0: + self.cur_model_output += 1 / 3 * model_output + elif (self.counter - 2) % 4 == 0: + self.cur_model_output += 1 / 3 * model_output + elif (self.counter - 3) % 4 == 0: + model_output = self.cur_model_output + 1 / 6 * model_output + self.cur_model_output = 0 + + # cur_sample should not be `None` + cur_sample = self.cur_sample if self.cur_sample is not None else sample + + prev_sample = self._get_prev_sample(cur_sample, timestep, prev_timestep, model_output) + self.counter += 1 + + if not return_dict: + return (prev_sample,) + + return SchedulerOutput(prev_sample=prev_sample) + + def step_plms( + self, + model_output: Union[torch.FloatTensor, np.ndarray], + timestep: int, + sample: Union[torch.FloatTensor, np.ndarray], + return_dict: bool = True, + ) -> Union[SchedulerOutput, Tuple]: + """ + Step function propagating the sample with the linear multi-step method. This has one forward pass with multiple + times to approximate the solution. + + Args: + model_output (`torch.FloatTensor` or `np.ndarray`): direct output from learned diffusion model. + timestep (`int`): current discrete timestep in the diffusion chain. + sample (`torch.FloatTensor` or `np.ndarray`): + current instance of sample being created by diffusion process. + return_dict (`bool`): option for returning tuple rather than SchedulerOutput class + + Returns: + [`~scheduling_utils.SchedulerOutput`] or `tuple`: [`~scheduling_utils.SchedulerOutput`] if `return_dict` is + True, otherwise a `tuple`. When returning a tuple, the first element is the sample tensor. 
+ + """ + if self.num_inference_steps is None: + raise ValueError( + "Number of inference steps is 'None', you need to run 'set_timesteps' after creating the scheduler" + ) + + if not self.config.skip_prk_steps and len(self.ets) < 3: + raise ValueError( + f"{self.__class__} can only be run AFTER scheduler has been run " + "in 'prk' mode for at least 12 iterations " + "See: https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/pipeline_pndm.py " + "for more information." + ) + + prev_timestep = max(timestep - self.config.num_train_timesteps // self.num_inference_steps, 0) + + if self.counter != 1: + self.ets.append(model_output) + else: + prev_timestep = timestep + timestep = timestep + self.config.num_train_timesteps // self.num_inference_steps + + if len(self.ets) == 1 and self.counter == 0: + model_output = model_output + self.cur_sample = sample + elif len(self.ets) == 1 and self.counter == 1: + model_output = (model_output + self.ets[-1]) / 2 + sample = self.cur_sample + self.cur_sample = None + elif len(self.ets) == 2: + model_output = (3 * self.ets[-1] - self.ets[-2]) / 2 + elif len(self.ets) == 3: + model_output = (23 * self.ets[-1] - 16 * self.ets[-2] + 5 * self.ets[-3]) / 12 + else: + model_output = (1 / 24) * (55 * self.ets[-1] - 59 * self.ets[-2] + 37 * self.ets[-3] - 9 * self.ets[-4]) + + prev_sample = self._get_prev_sample(sample, timestep, prev_timestep, model_output) + self.counter += 1 + + if not return_dict: + return (prev_sample,) + + return SchedulerOutput(prev_sample=prev_sample) + + def _get_prev_sample(self, sample, timestep, timestep_prev, model_output): + # See formula (9) of PNDM paper https://arxiv.org/pdf/2202.09778.pdf + # this function computes x_(t−δ) using the formula of (9) + # Note that x_t needs to be added to both sides of the equation + + # Notation ( -> + # alpha_prod_t -> α_t + # alpha_prod_t_prev -> α_(t−δ) + # beta_prod_t -> (1 - α_t) + # beta_prod_t_prev -> (1 - α_(t−δ)) + # sample -> x_t + # model_output -> e_θ(x_t, t) + # prev_sample -> x_(t−δ) + alpha_prod_t = self.alphas_cumprod[timestep + 1 - self._offset] + alpha_prod_t_prev = self.alphas_cumprod[timestep_prev + 1 - self._offset] + beta_prod_t = 1 - alpha_prod_t + beta_prod_t_prev = 1 - alpha_prod_t_prev + + # corresponds to (α_(t−δ) - α_t) divided by + # denominator of x_t in formula (9) and plus 1 + # Note: (α_(t−δ) - α_t) / (sqrt(α_t) * (sqrt(α_(t−δ)) + sqr(α_t))) = + # sqrt(α_(t−δ)) / sqrt(α_t)) + sample_coeff = (alpha_prod_t_prev / alpha_prod_t) ** (0.5) + + # corresponds to denominator of e_θ(x_t, t) in formula (9) + model_output_denom_coeff = alpha_prod_t * beta_prod_t_prev ** (0.5) + ( + alpha_prod_t * beta_prod_t * alpha_prod_t_prev + ) ** (0.5) + + # full formula (9) + prev_sample = ( + sample_coeff * sample - (alpha_prod_t_prev - alpha_prod_t) * model_output / model_output_denom_coeff + ) + + return prev_sample + + def add_noise( + self, + original_samples: Union[torch.FloatTensor, np.ndarray], + noise: Union[torch.FloatTensor, np.ndarray], + timesteps: Union[torch.IntTensor, np.ndarray], + ) -> torch.Tensor: + # mps requires indices to be in the same device, so we use cpu as is the default with cuda + timesteps = timesteps.to(self.alphas_cumprod.device) + sqrt_alpha_prod = self.alphas_cumprod[timesteps] ** 0.5 + sqrt_alpha_prod = self.match_shape(sqrt_alpha_prod, original_samples) + sqrt_one_minus_alpha_prod = (1 - self.alphas_cumprod[timesteps]) ** 0.5 + sqrt_one_minus_alpha_prod = self.match_shape(sqrt_one_minus_alpha_prod, original_samples) + + noisy_samples 
= sqrt_alpha_prod * original_samples + sqrt_one_minus_alpha_prod * noise + return noisy_samples + + def __len__(self): + return self.config.num_train_timesteps diff --git a/local_diffusers/schedulers/scheduling_sde_ve.py b/local_diffusers/schedulers/scheduling_sde_ve.py new file mode 100644 index 000000000..e187f0796 --- /dev/null +++ b/local_diffusers/schedulers/scheduling_sde_ve.py @@ -0,0 +1,283 @@ +# Copyright 2022 Google Brain and The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# DISCLAIMER: This file is strongly influenced by https://github.com/yang-song/score_sde_pytorch + +import warnings +from dataclasses import dataclass +from typing import Optional, Tuple, Union + +import numpy as np +import torch + +from ..configuration_utils import ConfigMixin, register_to_config +from ..utils import BaseOutput +from .scheduling_utils import SchedulerMixin, SchedulerOutput + + +@dataclass +class SdeVeOutput(BaseOutput): + """ + Output class for the ScoreSdeVeScheduler's step function output. + + Args: + prev_sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` for images): + Computed sample (x_{t-1}) of previous timestep. `prev_sample` should be used as next model input in the + denoising loop. + prev_sample_mean (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` for images): + Mean averaged `prev_sample`. Same as `prev_sample`, only mean-averaged over previous timesteps. + """ + + prev_sample: torch.FloatTensor + prev_sample_mean: torch.FloatTensor + + +class ScoreSdeVeScheduler(SchedulerMixin, ConfigMixin): + """ + The variance exploding stochastic differential equation (SDE) scheduler. + + For more information, see the original paper: https://arxiv.org/abs/2011.13456 + + [`~ConfigMixin`] takes care of storing all config attributes that are passed in the scheduler's `__init__` + function, such as `num_train_timesteps`. They can be accessed via `scheduler.config.num_train_timesteps`. + [`~ConfigMixin`] also provides general loading and saving functionality via the [`~ConfigMixin.save_config`] and + [`~ConfigMixin.from_config`] functios. + + Args: + snr (`float`): + coefficient weighting the step from the model_output sample (from the network) to the random noise. + sigma_min (`float`): + initial noise scale for sigma sequence in sampling procedure. The minimum sigma should mirror the + distribution of the data. + sigma_max (`float`): maximum value used for the range of continuous timesteps passed into the model. + sampling_eps (`float`): the end value of sampling, where timesteps decrease progessively from 1 to + epsilon. + correct_steps (`int`): number of correction steps performed on a produced sample. + tensor_format (`str`): "np" or "pt" for the expected format of samples passed to the Scheduler. 
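+
+    A rough predictor-corrector sketch (illustrative only; `model` is an assumed score network and `sample` an
+    initial noisy batch, neither of which is defined in this module):
+
+        scheduler.set_timesteps(num_inference_steps)
+        scheduler.set_sigmas(num_inference_steps)
+        for t in scheduler.timesteps:
+            for _ in range(scheduler.config.correct_steps):
+                sample = scheduler.step_correct(model(sample, t), sample).prev_sample
+            sample = scheduler.step_pred(model(sample, t), t, sample).prev_sample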
+ """ + + @register_to_config + def __init__( + self, + num_train_timesteps: int = 2000, + snr: float = 0.15, + sigma_min: float = 0.01, + sigma_max: float = 1348.0, + sampling_eps: float = 1e-5, + correct_steps: int = 1, + tensor_format: str = "pt", + ): + # setable values + self.timesteps = None + + self.set_sigmas(num_train_timesteps, sigma_min, sigma_max, sampling_eps) + + self.tensor_format = tensor_format + self.set_format(tensor_format=tensor_format) + + def set_timesteps(self, num_inference_steps: int, sampling_eps: float = None): + """ + Sets the continuous timesteps used for the diffusion chain. Supporting function to be run before inference. + + Args: + num_inference_steps (`int`): + the number of diffusion steps used when generating samples with a pre-trained model. + sampling_eps (`float`, optional): final timestep value (overrides value given at Scheduler instantiation). + + """ + sampling_eps = sampling_eps if sampling_eps is not None else self.config.sampling_eps + tensor_format = getattr(self, "tensor_format", "pt") + if tensor_format == "np": + self.timesteps = np.linspace(1, sampling_eps, num_inference_steps) + elif tensor_format == "pt": + self.timesteps = torch.linspace(1, sampling_eps, num_inference_steps) + else: + raise ValueError(f"`self.tensor_format`: {self.tensor_format} is not valid.") + + def set_sigmas( + self, num_inference_steps: int, sigma_min: float = None, sigma_max: float = None, sampling_eps: float = None + ): + """ + Sets the noise scales used for the diffusion chain. Supporting function to be run before inference. + + The sigmas control the weight of the `drift` and `diffusion` components of sample update. + + Args: + num_inference_steps (`int`): + the number of diffusion steps used when generating samples with a pre-trained model. + sigma_min (`float`, optional): + initial noise scale value (overrides value given at Scheduler instantiation). + sigma_max (`float`, optional): final noise scale value (overrides value given at Scheduler instantiation). + sampling_eps (`float`, optional): final timestep value (overrides value given at Scheduler instantiation). 
+ + """ + sigma_min = sigma_min if sigma_min is not None else self.config.sigma_min + sigma_max = sigma_max if sigma_max is not None else self.config.sigma_max + sampling_eps = sampling_eps if sampling_eps is not None else self.config.sampling_eps + if self.timesteps is None: + self.set_timesteps(num_inference_steps, sampling_eps) + + tensor_format = getattr(self, "tensor_format", "pt") + if tensor_format == "np": + self.discrete_sigmas = np.exp(np.linspace(np.log(sigma_min), np.log(sigma_max), num_inference_steps)) + self.sigmas = np.array([sigma_min * (sigma_max / sigma_min) ** t for t in self.timesteps]) + elif tensor_format == "pt": + self.discrete_sigmas = torch.exp(torch.linspace(np.log(sigma_min), np.log(sigma_max), num_inference_steps)) + self.sigmas = torch.tensor([sigma_min * (sigma_max / sigma_min) ** t for t in self.timesteps]) + else: + raise ValueError(f"`self.tensor_format`: {self.tensor_format} is not valid.") + + def get_adjacent_sigma(self, timesteps, t): + tensor_format = getattr(self, "tensor_format", "pt") + if tensor_format == "np": + return np.where(timesteps == 0, np.zeros_like(t), self.discrete_sigmas[timesteps - 1]) + elif tensor_format == "pt": + return torch.where( + timesteps == 0, + torch.zeros_like(t.to(timesteps.device)), + self.discrete_sigmas[timesteps - 1].to(timesteps.device), + ) + + raise ValueError(f"`self.tensor_format`: {self.tensor_format} is not valid.") + + def set_seed(self, seed): + warnings.warn( + "The method `set_seed` is deprecated and will be removed in version `0.4.0`. Please consider passing a" + " generator instead.", + DeprecationWarning, + ) + tensor_format = getattr(self, "tensor_format", "pt") + if tensor_format == "np": + np.random.seed(seed) + elif tensor_format == "pt": + torch.manual_seed(seed) + else: + raise ValueError(f"`self.tensor_format`: {self.tensor_format} is not valid.") + + def step_pred( + self, + model_output: Union[torch.FloatTensor, np.ndarray], + timestep: int, + sample: Union[torch.FloatTensor, np.ndarray], + generator: Optional[torch.Generator] = None, + return_dict: bool = True, + **kwargs, + ) -> Union[SdeVeOutput, Tuple]: + """ + Predict the sample at the previous timestep by reversing the SDE. Core function to propagate the diffusion + process from the learned model outputs (most often the predicted noise). + + Args: + model_output (`torch.FloatTensor` or `np.ndarray`): direct output from learned diffusion model. + timestep (`int`): current discrete timestep in the diffusion chain. + sample (`torch.FloatTensor` or `np.ndarray`): + current instance of sample being created by diffusion process. + generator: random number generator. + return_dict (`bool`): option for returning tuple rather than SchedulerOutput class + + Returns: + [`~schedulers.scheduling_sde_ve.SdeVeOutput`] or `tuple`: [`~schedulers.scheduling_sde_ve.SdeVeOutput`] if + `return_dict` is True, otherwise a `tuple`. When returning a tuple, the first element is the sample tensor. 
+ + """ + if "seed" in kwargs and kwargs["seed"] is not None: + self.set_seed(kwargs["seed"]) + + if self.timesteps is None: + raise ValueError( + "`self.timesteps` is not set, you need to run 'set_timesteps' after creating the scheduler" + ) + + timestep = timestep * torch.ones( + sample.shape[0], device=sample.device + ) # torch.repeat_interleave(timestep, sample.shape[0]) + timesteps = (timestep * (len(self.timesteps) - 1)).long() + + # mps requires indices to be in the same device, so we use cpu as is the default with cuda + timesteps = timesteps.to(self.discrete_sigmas.device) + + sigma = self.discrete_sigmas[timesteps].to(sample.device) + adjacent_sigma = self.get_adjacent_sigma(timesteps, timestep).to(sample.device) + drift = self.zeros_like(sample) + diffusion = (sigma**2 - adjacent_sigma**2) ** 0.5 + + # equation 6 in the paper: the model_output modeled by the network is grad_x log pt(x) + # also equation 47 shows the analog from SDE models to ancestral sampling methods + drift = drift - diffusion[:, None, None, None] ** 2 * model_output + + # equation 6: sample noise for the diffusion term of + noise = self.randn_like(sample, generator=generator) + prev_sample_mean = sample - drift # subtract because `dt` is a small negative timestep + # TODO is the variable diffusion the correct scaling term for the noise? + prev_sample = prev_sample_mean + diffusion[:, None, None, None] * noise # add impact of diffusion field g + + if not return_dict: + return (prev_sample, prev_sample_mean) + + return SdeVeOutput(prev_sample=prev_sample, prev_sample_mean=prev_sample_mean) + + def step_correct( + self, + model_output: Union[torch.FloatTensor, np.ndarray], + sample: Union[torch.FloatTensor, np.ndarray], + generator: Optional[torch.Generator] = None, + return_dict: bool = True, + **kwargs, + ) -> Union[SchedulerOutput, Tuple]: + """ + Correct the predicted sample based on the output model_output of the network. This is often run repeatedly + after making the prediction for the previous timestep. + + Args: + model_output (`torch.FloatTensor` or `np.ndarray`): direct output from learned diffusion model. + sample (`torch.FloatTensor` or `np.ndarray`): + current instance of sample being created by diffusion process. + generator: random number generator. + return_dict (`bool`): option for returning tuple rather than SchedulerOutput class + + Returns: + [`~schedulers.scheduling_sde_ve.SdeVeOutput`] or `tuple`: [`~schedulers.scheduling_sde_ve.SdeVeOutput`] if + `return_dict` is True, otherwise a `tuple`. When returning a tuple, the first element is the sample tensor. + + """ + if "seed" in kwargs and kwargs["seed"] is not None: + self.set_seed(kwargs["seed"]) + + if self.timesteps is None: + raise ValueError( + "`self.timesteps` is not set, you need to run 'set_timesteps' after creating the scheduler" + ) + + # For small batch sizes, the paper "suggest replacing norm(z) with sqrt(d), where d is the dim. 
of z" + # sample noise for correction + noise = self.randn_like(sample, generator=generator) + + # compute step size from the model_output, the noise, and the snr + grad_norm = self.norm(model_output) + noise_norm = self.norm(noise) + step_size = (self.config.snr * noise_norm / grad_norm) ** 2 * 2 + step_size = step_size * torch.ones(sample.shape[0]).to(sample.device) + # self.repeat_scalar(step_size, sample.shape[0]) + + # compute corrected sample: model_output term and noise term + prev_sample_mean = sample + step_size[:, None, None, None] * model_output + prev_sample = prev_sample_mean + ((step_size * 2) ** 0.5)[:, None, None, None] * noise + + if not return_dict: + return (prev_sample,) + + return SchedulerOutput(prev_sample=prev_sample) + + def __len__(self): + return self.config.num_train_timesteps diff --git a/local_diffusers/schedulers/scheduling_sde_vp.py b/local_diffusers/schedulers/scheduling_sde_vp.py new file mode 100644 index 000000000..66e6ec661 --- /dev/null +++ b/local_diffusers/schedulers/scheduling_sde_vp.py @@ -0,0 +1,81 @@ +# Copyright 2022 Google Brain and The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# DISCLAIMER: This file is strongly influenced by https://github.com/yang-song/score_sde_pytorch + +# TODO(Patrick, Anton, Suraj) - make scheduler framework indepedent and clean-up a bit + +import numpy as np +import torch + +from ..configuration_utils import ConfigMixin, register_to_config +from .scheduling_utils import SchedulerMixin + + +class ScoreSdeVpScheduler(SchedulerMixin, ConfigMixin): + """ + The variance preserving stochastic differential equation (SDE) scheduler. + + [`~ConfigMixin`] takes care of storing all config attributes that are passed in the scheduler's `__init__` + function, such as `num_train_timesteps`. They can be accessed via `scheduler.config.num_train_timesteps`. + [`~ConfigMixin`] also provides general loading and saving functionality via the [`~ConfigMixin.save_config`] and + [`~ConfigMixin.from_config`] functios. 
+ + For more information, see the original paper: https://arxiv.org/abs/2011.13456 + + UNDER CONSTRUCTION + + """ + + @register_to_config + def __init__(self, num_train_timesteps=2000, beta_min=0.1, beta_max=20, sampling_eps=1e-3, tensor_format="np"): + + self.sigmas = None + self.discrete_sigmas = None + self.timesteps = None + + def set_timesteps(self, num_inference_steps): + self.timesteps = torch.linspace(1, self.config.sampling_eps, num_inference_steps) + + def step_pred(self, score, x, t): + if self.timesteps is None: + raise ValueError( + "`self.timesteps` is not set, you need to run 'set_timesteps' after creating the scheduler" + ) + + # TODO(Patrick) better comments + non-PyTorch + # postprocess model score + log_mean_coeff = ( + -0.25 * t**2 * (self.config.beta_max - self.config.beta_min) - 0.5 * t * self.config.beta_min + ) + std = torch.sqrt(1.0 - torch.exp(2.0 * log_mean_coeff)) + score = -score / std[:, None, None, None] + + # compute + dt = -1.0 / len(self.timesteps) + + beta_t = self.config.beta_min + t * (self.config.beta_max - self.config.beta_min) + drift = -0.5 * beta_t[:, None, None, None] * x + diffusion = torch.sqrt(beta_t) + drift = drift - diffusion[:, None, None, None] ** 2 * score + x_mean = x + drift * dt + + # add noise + noise = torch.randn_like(x) + x = x_mean + diffusion[:, None, None, None] * np.sqrt(-dt) * noise + + return x, x_mean + + def __len__(self): + return self.config.num_train_timesteps diff --git a/local_diffusers/schedulers/scheduling_utils.py b/local_diffusers/schedulers/scheduling_utils.py new file mode 100644 index 000000000..f2bcd73ac --- /dev/null +++ b/local_diffusers/schedulers/scheduling_utils.py @@ -0,0 +1,125 @@ +# Copyright 2022 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from dataclasses import dataclass +from typing import Union + +import numpy as np +import torch + +from ..utils import BaseOutput + + +SCHEDULER_CONFIG_NAME = "scheduler_config.json" + + +@dataclass +class SchedulerOutput(BaseOutput): + """ + Base class for the scheduler's step function output. + + Args: + prev_sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` for images): + Computed sample (x_{t-1}) of previous timestep. `prev_sample` should be used as next model input in the + denoising loop. + """ + + prev_sample: torch.FloatTensor + + +class SchedulerMixin: + """ + Mixin containing common functions for the schedulers. 
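+    It provides tensor-format-aware helpers (`set_format`, `clip`, `log`, `match_shape`, `norm`, `randn_like` and
+    `zeros_like`) that dispatch to NumPy or PyTorch depending on `tensor_format`.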
+ """ + + config_name = SCHEDULER_CONFIG_NAME + ignore_for_config = ["tensor_format"] + + def set_format(self, tensor_format="pt"): + self.tensor_format = tensor_format + if tensor_format == "pt": + for key, value in vars(self).items(): + if isinstance(value, np.ndarray): + setattr(self, key, torch.from_numpy(value)) + + return self + + def clip(self, tensor, min_value=None, max_value=None): + tensor_format = getattr(self, "tensor_format", "pt") + + if tensor_format == "np": + return np.clip(tensor, min_value, max_value) + elif tensor_format == "pt": + return torch.clamp(tensor, min_value, max_value) + + raise ValueError(f"`self.tensor_format`: {self.tensor_format} is not valid.") + + def log(self, tensor): + tensor_format = getattr(self, "tensor_format", "pt") + + if tensor_format == "np": + return np.log(tensor) + elif tensor_format == "pt": + return torch.log(tensor) + + raise ValueError(f"`self.tensor_format`: {self.tensor_format} is not valid.") + + def match_shape(self, values: Union[np.ndarray, torch.Tensor], broadcast_array: Union[np.ndarray, torch.Tensor]): + """ + Turns a 1-D array into an array or tensor with len(broadcast_array.shape) dims. + + Args: + values: an array or tensor of values to extract. + broadcast_array: an array with a larger shape of K dimensions with the batch + dimension equal to the length of timesteps. + Returns: + a tensor of shape [batch_size, 1, ...] where the shape has K dims. + """ + + tensor_format = getattr(self, "tensor_format", "pt") + values = values.flatten() + + while len(values.shape) < len(broadcast_array.shape): + values = values[..., None] + if tensor_format == "pt": + values = values.to(broadcast_array.device) + + return values + + def norm(self, tensor): + tensor_format = getattr(self, "tensor_format", "pt") + if tensor_format == "np": + return np.linalg.norm(tensor) + elif tensor_format == "pt": + return torch.norm(tensor.reshape(tensor.shape[0], -1), dim=-1).mean() + + raise ValueError(f"`self.tensor_format`: {self.tensor_format} is not valid.") + + def randn_like(self, tensor, generator=None): + tensor_format = getattr(self, "tensor_format", "pt") + if tensor_format == "np": + return np.random.randn(*np.shape(tensor)) + elif tensor_format == "pt": + # return torch.randn_like(tensor) + return torch.randn(tensor.shape, layout=tensor.layout, generator=generator).to(tensor.device) + + raise ValueError(f"`self.tensor_format`: {self.tensor_format} is not valid.") + + def zeros_like(self, tensor): + tensor_format = getattr(self, "tensor_format", "pt") + if tensor_format == "np": + return np.zeros_like(tensor) + elif tensor_format == "pt": + return torch.zeros_like(tensor) + + raise ValueError(f"`self.tensor_format`: {self.tensor_format} is not valid.") diff --git a/local_diffusers/testing_utils.py b/local_diffusers/testing_utils.py new file mode 100644 index 000000000..ff8b6aa9b --- /dev/null +++ b/local_diffusers/testing_utils.py @@ -0,0 +1,61 @@ +import os +import random +import unittest +from distutils.util import strtobool + +import torch + +from packaging import version + + +global_rng = random.Random() +torch_device = "cuda" if torch.cuda.is_available() else "cpu" +is_torch_higher_equal_than_1_12 = version.parse(version.parse(torch.__version__).base_version) >= version.parse("1.12") + +if is_torch_higher_equal_than_1_12: + torch_device = "mps" if torch.backends.mps.is_available() else torch_device + + +def parse_flag_from_env(key, default=False): + try: + value = os.environ[key] + except KeyError: + # KEY isn't set, default to `default`. 
+ _value = default + else: + # KEY is set, convert it to True or False. + try: + _value = strtobool(value) + except ValueError: + # More values are supported, but let's keep the message simple. + raise ValueError(f"If set, {key} must be yes or no.") + return _value + + +_run_slow_tests = parse_flag_from_env("RUN_SLOW", default=False) + + +def floats_tensor(shape, scale=1.0, rng=None, name=None): + """Creates a random float32 tensor""" + if rng is None: + rng = global_rng + + total_dims = 1 + for dim in shape: + total_dims *= dim + + values = [] + for _ in range(total_dims): + values.append(rng.random() * scale) + + return torch.tensor(data=values, dtype=torch.float).view(shape).contiguous() + + +def slow(test_case): + """ + Decorator marking a test as slow. + + Slow tests are skipped by default. Set the RUN_SLOW environment variable to a truthy value to run them. + + """ + return unittest.skipUnless(_run_slow_tests, "test is slow")(test_case) diff --git a/local_diffusers/training_utils.py b/local_diffusers/training_utils.py new file mode 100644 index 000000000..fa1694161 --- /dev/null +++ b/local_diffusers/training_utils.py @@ -0,0 +1,125 @@ +import copy +import os +import random + +import numpy as np +import torch + + +def enable_full_determinism(seed: int): + """ + Helper function for reproducible behavior during distributed training. See + - https://pytorch.org/docs/stable/notes/randomness.html for pytorch + """ + # set seed first + set_seed(seed) + + # Enable PyTorch deterministic mode. This potentially requires either the environment + # variable 'CUDA_LAUNCH_BLOCKING' or 'CUBLAS_WORKSPACE_CONFIG' to be set, + # depending on the CUDA version, so we set them both here + os.environ["CUDA_LAUNCH_BLOCKING"] = "1" + os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":16:8" + torch.use_deterministic_algorithms(True) + + # Enable CUDNN deterministic mode + torch.backends.cudnn.deterministic = True + torch.backends.cudnn.benchmark = False + + +def set_seed(seed: int): + """ + Args: + Helper function for reproducible behavior to set the seed in `random`, `numpy`, `torch`. + seed (`int`): The seed to set. + """ + random.seed(seed) + np.random.seed(seed) + torch.manual_seed(seed) + torch.cuda.manual_seed_all(seed) + # ^^ safe to call this function even if cuda is not available + + +class EMAModel: + """ + Exponential Moving Average of models weights + """ + + def __init__( + self, + model, + update_after_step=0, + inv_gamma=1.0, + power=2 / 3, + min_value=0.0, + max_value=0.9999, + device=None, + ): + """ + @crowsonkb's notes on EMA Warmup: + If gamma=1 and power=1, implements a simple average. gamma=1, power=2/3 are good values for models you plan + to train for a million or more steps (reaches decay factor 0.999 at 31.6K steps, 0.9999 at 1M steps), + gamma=1, power=3/4 for models you plan to train for less (reaches decay factor 0.999 at 10K steps, 0.9999 + at 215.4k steps). + Args: + inv_gamma (float): Inverse multiplicative factor of EMA warmup. Default: 1. + power (float): Exponential factor of EMA warmup. Default: 2/3. + min_value (float): The minimum EMA decay rate. Default: 0. 
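+            max_value (float): The maximum EMA decay rate. Default: 0.9999.
+
+        The decay applied at each step is 1 - (1 + step / inv_gamma) ** -power, clipped to [min_value, max_value]
+        (see `get_decay` below).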
+ """ + + self.averaged_model = copy.deepcopy(model).eval() + self.averaged_model.requires_grad_(False) + + self.update_after_step = update_after_step + self.inv_gamma = inv_gamma + self.power = power + self.min_value = min_value + self.max_value = max_value + + if device is not None: + self.averaged_model = self.averaged_model.to(device=device) + + self.decay = 0.0 + self.optimization_step = 0 + + def get_decay(self, optimization_step): + """ + Compute the decay factor for the exponential moving average. + """ + step = max(0, optimization_step - self.update_after_step - 1) + value = 1 - (1 + step / self.inv_gamma) ** -self.power + + if step <= 0: + return 0.0 + + return max(self.min_value, min(value, self.max_value)) + + @torch.no_grad() + def step(self, new_model): + ema_state_dict = {} + ema_params = self.averaged_model.state_dict() + + self.decay = self.get_decay(self.optimization_step) + + for key, param in new_model.named_parameters(): + if isinstance(param, dict): + continue + try: + ema_param = ema_params[key] + except KeyError: + ema_param = param.float().clone() if param.ndim == 1 else copy.deepcopy(param) + ema_params[key] = ema_param + + if not param.requires_grad: + ema_params[key].copy_(param.to(dtype=ema_param.dtype).data) + ema_param = ema_params[key] + else: + ema_param.mul_(self.decay) + ema_param.add_(param.data.to(dtype=ema_param.dtype), alpha=1 - self.decay) + + ema_state_dict[key] = ema_param + + for key, param in new_model.named_buffers(): + ema_state_dict[key] = param + + self.averaged_model.load_state_dict(ema_state_dict, strict=False) + self.optimization_step += 1 diff --git a/local_diffusers/utils/__init__.py b/local_diffusers/utils/__init__.py new file mode 100644 index 000000000..c00a28e10 --- /dev/null +++ b/local_diffusers/utils/__init__.py @@ -0,0 +1,53 @@ +# Copyright 2022 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ + +import os + +from .import_utils import ( + ENV_VARS_TRUE_AND_AUTO_VALUES, + ENV_VARS_TRUE_VALUES, + USE_JAX, + USE_TF, + USE_TORCH, + DummyObject, + is_flax_available, + is_inflect_available, + is_modelcards_available, + is_onnx_available, + is_scipy_available, + is_tf_available, + is_torch_available, + is_transformers_available, + is_unidecode_available, + requires_backends, +) +from .logging import get_logger +from .outputs import BaseOutput + + +logger = get_logger(__name__) + + +hf_cache_home = os.path.expanduser( + os.getenv("HF_HOME", os.path.join(os.getenv("XDG_CACHE_HOME", "~/.cache"), "huggingface")) +) +default_cache_path = os.path.join(hf_cache_home, "diffusers") + + +CONFIG_NAME = "config.json" +HUGGINGFACE_CO_RESOLVE_ENDPOINT = "https://huggingface.co" +DIFFUSERS_CACHE = default_cache_path +DIFFUSERS_DYNAMIC_MODULE_NAME = "diffusers_modules" +HF_MODULES_CACHE = os.getenv("HF_MODULES_CACHE", os.path.join(hf_cache_home, "modules")) diff --git a/local_diffusers/utils/dummy_scipy_objects.py b/local_diffusers/utils/dummy_scipy_objects.py new file mode 100644 index 000000000..3706c5754 --- /dev/null +++ b/local_diffusers/utils/dummy_scipy_objects.py @@ -0,0 +1,11 @@ +# This file is autogenerated by the command `make fix-copies`, do not edit. +# flake8: noqa + +from ..utils import DummyObject, requires_backends + + +class LMSDiscreteScheduler(metaclass=DummyObject): + _backends = ["scipy"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["scipy"]) diff --git a/local_diffusers/utils/dummy_transformers_and_inflect_and_unidecode_objects.py b/local_diffusers/utils/dummy_transformers_and_inflect_and_unidecode_objects.py new file mode 100644 index 000000000..8c2aec218 --- /dev/null +++ b/local_diffusers/utils/dummy_transformers_and_inflect_and_unidecode_objects.py @@ -0,0 +1,10 @@ +# This file is autogenerated by the command `make fix-copies`, do not edit. +# flake8: noqa +from ..utils import DummyObject, requires_backends + + +class GradTTSPipeline(metaclass=DummyObject): + _backends = ["transformers", "inflect", "unidecode"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["transformers", "inflect", "unidecode"]) diff --git a/local_diffusers/utils/dummy_transformers_and_onnx_objects.py b/local_diffusers/utils/dummy_transformers_and_onnx_objects.py new file mode 100644 index 000000000..2e34b5ce0 --- /dev/null +++ b/local_diffusers/utils/dummy_transformers_and_onnx_objects.py @@ -0,0 +1,11 @@ +# This file is autogenerated by the command `make fix-copies`, do not edit. +# flake8: noqa + +from ..utils import DummyObject, requires_backends + + +class StableDiffusionOnnxPipeline(metaclass=DummyObject): + _backends = ["transformers", "onnx"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["transformers", "onnx"]) diff --git a/local_diffusers/utils/dummy_transformers_objects.py b/local_diffusers/utils/dummy_transformers_objects.py new file mode 100644 index 000000000..e05eb814d --- /dev/null +++ b/local_diffusers/utils/dummy_transformers_objects.py @@ -0,0 +1,32 @@ +# This file is autogenerated by the command `make fix-copies`, do not edit. 
+# flake8: noqa + +from ..utils import DummyObject, requires_backends + + +class LDMTextToImagePipeline(metaclass=DummyObject): + _backends = ["transformers"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["transformers"]) + + +class StableDiffusionImg2ImgPipeline(metaclass=DummyObject): + _backends = ["transformers"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["transformers"]) + + +class StableDiffusionInpaintPipeline(metaclass=DummyObject): + _backends = ["transformers"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["transformers"]) + + +class StableDiffusionPipeline(metaclass=DummyObject): + _backends = ["transformers"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["transformers"]) diff --git a/local_diffusers/utils/import_utils.py b/local_diffusers/utils/import_utils.py new file mode 100644 index 000000000..1f5e95ada --- /dev/null +++ b/local_diffusers/utils/import_utils.py @@ -0,0 +1,274 @@ +# Copyright 2022 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +Import utilities: Utilities related to imports and our lazy inits. +""" +import importlib.util +import os +import sys +from collections import OrderedDict + +from packaging import version + +from . import logging + + +# The package importlib_metadata is in a different place, depending on the python version. 
+if sys.version_info < (3, 8): + import importlib_metadata +else: + import importlib.metadata as importlib_metadata + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + +ENV_VARS_TRUE_VALUES = {"1", "ON", "YES", "TRUE"} +ENV_VARS_TRUE_AND_AUTO_VALUES = ENV_VARS_TRUE_VALUES.union({"AUTO"}) + +USE_TF = os.environ.get("USE_TF", "AUTO").upper() +USE_TORCH = os.environ.get("USE_TORCH", "AUTO").upper() +USE_JAX = os.environ.get("USE_FLAX", "AUTO").upper() + +_torch_version = "N/A" +if USE_TORCH in ENV_VARS_TRUE_AND_AUTO_VALUES and USE_TF not in ENV_VARS_TRUE_VALUES: + _torch_available = importlib.util.find_spec("torch") is not None + if _torch_available: + try: + _torch_version = importlib_metadata.version("torch") + logger.info(f"PyTorch version {_torch_version} available.") + except importlib_metadata.PackageNotFoundError: + _torch_available = False +else: + logger.info("Disabling PyTorch because USE_TF is set") + _torch_available = False + + +_tf_version = "N/A" +if USE_TF in ENV_VARS_TRUE_AND_AUTO_VALUES and USE_TORCH not in ENV_VARS_TRUE_VALUES: + _tf_available = importlib.util.find_spec("tensorflow") is not None + if _tf_available: + candidates = ( + "tensorflow", + "tensorflow-cpu", + "tensorflow-gpu", + "tf-nightly", + "tf-nightly-cpu", + "tf-nightly-gpu", + "intel-tensorflow", + "intel-tensorflow-avx512", + "tensorflow-rocm", + "tensorflow-macos", + "tensorflow-aarch64", + ) + _tf_version = None + # For the metadata, we have to look for both tensorflow and tensorflow-cpu + for pkg in candidates: + try: + _tf_version = importlib_metadata.version(pkg) + break + except importlib_metadata.PackageNotFoundError: + pass + _tf_available = _tf_version is not None + if _tf_available: + if version.parse(_tf_version) < version.parse("2"): + logger.info(f"TensorFlow found but with version {_tf_version}. 
Diffusers requires version 2 minimum.") + _tf_available = False + else: + logger.info(f"TensorFlow version {_tf_version} available.") +else: + logger.info("Disabling Tensorflow because USE_TORCH is set") + _tf_available = False + + +if USE_JAX in ENV_VARS_TRUE_AND_AUTO_VALUES: + _flax_available = importlib.util.find_spec("jax") is not None and importlib.util.find_spec("flax") is not None + if _flax_available: + try: + _jax_version = importlib_metadata.version("jax") + _flax_version = importlib_metadata.version("flax") + logger.info(f"JAX version {_jax_version}, Flax version {_flax_version} available.") + except importlib_metadata.PackageNotFoundError: + _flax_available = False +else: + _flax_available = False + + +_transformers_available = importlib.util.find_spec("transformers") is not None +try: + _transformers_version = importlib_metadata.version("transformers") + logger.debug(f"Successfully imported transformers version {_transformers_version}") +except importlib_metadata.PackageNotFoundError: + _transformers_available = False + + +_inflect_available = importlib.util.find_spec("inflect") is not None +try: + _inflect_version = importlib_metadata.version("inflect") + logger.debug(f"Successfully imported inflect version {_inflect_version}") +except importlib_metadata.PackageNotFoundError: + _inflect_available = False + + +_unidecode_available = importlib.util.find_spec("unidecode") is not None +try: + _unidecode_version = importlib_metadata.version("unidecode") + logger.debug(f"Successfully imported unidecode version {_unidecode_version}") +except importlib_metadata.PackageNotFoundError: + _unidecode_available = False + + +_modelcards_available = importlib.util.find_spec("modelcards") is not None +try: + _modelcards_version = importlib_metadata.version("modelcards") + logger.debug(f"Successfully imported modelcards version {_modelcards_version}") +except importlib_metadata.PackageNotFoundError: + _modelcards_available = False + + +_onnx_available = importlib.util.find_spec("onnxruntime") is not None +try: + _onnxruntime_version = importlib_metadata.version("onnxruntime") + logger.debug(f"Successfully imported onnxruntime version {_onnxruntime_version}") +except importlib_metadata.PackageNotFoundError: + _onnx_available = False + + +_scipy_available = importlib.util.find_spec("scipy") is not None +try: + _scipy_version = importlib_metadata.version("scipy") + logger.debug(f"Successfully imported transformers version {_scipy_version}") +except importlib_metadata.PackageNotFoundError: + _scipy_available = False + + +def is_torch_available(): + return _torch_available + + +def is_tf_available(): + return _tf_available + + +def is_flax_available(): + return _flax_available + + +def is_transformers_available(): + return _transformers_available + + +def is_inflect_available(): + return _inflect_available + + +def is_unidecode_available(): + return _unidecode_available + + +def is_modelcards_available(): + return _modelcards_available + + +def is_onnx_available(): + return _onnx_available + + +def is_scipy_available(): + return _scipy_available + + +# docstyle-ignore +FLAX_IMPORT_ERROR = """ +{0} requires the FLAX library but it was not found in your environment. Checkout the instructions on the +installation page: https://github.com/google/flax and follow the ones that match your environment. +""" + +# docstyle-ignore +INFLECT_IMPORT_ERROR = """ +{0} requires the inflect library but it was not found in your environment. 
You can install it with pip: `pip install +inflect` +""" + +# docstyle-ignore +PYTORCH_IMPORT_ERROR = """ +{0} requires the PyTorch library but it was not found in your environment. Checkout the instructions on the +installation page: https://pytorch.org/get-started/locally/ and follow the ones that match your environment. +""" + +# docstyle-ignore +ONNX_IMPORT_ERROR = """ +{0} requires the onnxruntime library but it was not found in your environment. You can install it with pip: `pip +install onnxruntime` +""" + +# docstyle-ignore +SCIPY_IMPORT_ERROR = """ +{0} requires the scipy library but it was not found in your environment. You can install it with pip: `pip install +scipy` +""" + +# docstyle-ignore +TENSORFLOW_IMPORT_ERROR = """ +{0} requires the TensorFlow library but it was not found in your environment. Checkout the instructions on the +installation page: https://www.tensorflow.org/install and follow the ones that match your environment. +""" + +# docstyle-ignore +TRANSFORMERS_IMPORT_ERROR = """ +{0} requires the transformers library but it was not found in your environment. You can install it with pip: `pip +install transformers` +""" + +# docstyle-ignore +UNIDECODE_IMPORT_ERROR = """ +{0} requires the unidecode library but it was not found in your environment. You can install it with pip: `pip install +Unidecode` +""" + + +BACKENDS_MAPPING = OrderedDict( + [ + ("flax", (is_flax_available, FLAX_IMPORT_ERROR)), + ("inflect", (is_inflect_available, INFLECT_IMPORT_ERROR)), + ("onnx", (is_onnx_available, ONNX_IMPORT_ERROR)), + ("scipy", (is_scipy_available, SCIPY_IMPORT_ERROR)), + ("tf", (is_tf_available, TENSORFLOW_IMPORT_ERROR)), + ("torch", (is_torch_available, PYTORCH_IMPORT_ERROR)), + ("transformers", (is_transformers_available, TRANSFORMERS_IMPORT_ERROR)), + ("unidecode", (is_unidecode_available, UNIDECODE_IMPORT_ERROR)), + ] +) + + +def requires_backends(obj, backends): + if not isinstance(backends, (list, tuple)): + backends = [backends] + + name = obj.__name__ if hasattr(obj, "__name__") else obj.__class__.__name__ + checks = (BACKENDS_MAPPING[backend] for backend in backends) + failed = [msg.format(name) for available, msg in checks if not available()] + if failed: + raise ImportError("".join(failed)) + + +class DummyObject(type): + """ + Metaclass for the dummy objects. Any class inheriting from it will return the ImportError generated by + `requires_backend` each time a user tries to access any method of that class. + """ + + def __getattr__(cls, key): + if key.startswith("_"): + return super().__getattr__(cls, key) + requires_backends(cls, cls._backends) diff --git a/local_diffusers/utils/logging.py b/local_diffusers/utils/logging.py new file mode 100644 index 000000000..1f2d0227b --- /dev/null +++ b/local_diffusers/utils/logging.py @@ -0,0 +1,344 @@ +# coding=utf-8 +# Copyright 2020 Optuna, Hugging Face +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
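
Before the logging utilities, a quick aside: a minimal sketch of how the backend helpers in `import_utils.py` above (`requires_backends`, `DummyObject`, the `is_*_available` functions) are meant to be used. This is not part of the patch; the module path assumes the vendored `local_diffusers` package is importable, and `FakeScipyScheduler` is a made-up placeholder in the style of the `dummy_*_objects.py` files.

```
from local_diffusers.utils.import_utils import DummyObject, is_scipy_available, requires_backends


def run_lms_sampling(*args, **kwargs):
    # Raises ImportError with SCIPY_IMPORT_ERROR if scipy is missing from the environment.
    requires_backends(run_lms_sampling, ["scipy"])


class FakeScipyScheduler(metaclass=DummyObject):
    # Placeholder class: any attribute lookup not found on the class goes through
    # DummyObject.__getattr__, which re-raises the scipy import error.
    _backends = ["scipy"]


if not is_scipy_available():
    # e.g. FakeScipyScheduler.from_config(...) would raise the scipy ImportError here.
    print("scipy is not installed; only the dummy placeholder is available")
```
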
+""" Logging utilities.""" + +import logging +import os +import sys +import threading +from logging import CRITICAL # NOQA +from logging import DEBUG # NOQA +from logging import ERROR # NOQA +from logging import FATAL # NOQA +from logging import INFO # NOQA +from logging import NOTSET # NOQA +from logging import WARN # NOQA +from logging import WARNING # NOQA +from typing import Optional + +from tqdm import auto as tqdm_lib + + +_lock = threading.Lock() +_default_handler: Optional[logging.Handler] = None + +log_levels = { + "debug": logging.DEBUG, + "info": logging.INFO, + "warning": logging.WARNING, + "error": logging.ERROR, + "critical": logging.CRITICAL, +} + +_default_log_level = logging.WARNING + +_tqdm_active = True + + +def _get_default_logging_level(): + """ + If DIFFUSERS_VERBOSITY env var is set to one of the valid choices return that as the new default level. If it is + not - fall back to `_default_log_level` + """ + env_level_str = os.getenv("DIFFUSERS_VERBOSITY", None) + if env_level_str: + if env_level_str in log_levels: + return log_levels[env_level_str] + else: + logging.getLogger().warning( + f"Unknown option DIFFUSERS_VERBOSITY={env_level_str}, " + f"has to be one of: { ', '.join(log_levels.keys()) }" + ) + return _default_log_level + + +def _get_library_name() -> str: + + return __name__.split(".")[0] + + +def _get_library_root_logger() -> logging.Logger: + + return logging.getLogger(_get_library_name()) + + +def _configure_library_root_logger() -> None: + + global _default_handler + + with _lock: + if _default_handler: + # This library has already configured the library root logger. + return + _default_handler = logging.StreamHandler() # Set sys.stderr as stream. + _default_handler.flush = sys.stderr.flush + + # Apply our default configuration to the library root logger. + library_root_logger = _get_library_root_logger() + library_root_logger.addHandler(_default_handler) + library_root_logger.setLevel(_get_default_logging_level()) + library_root_logger.propagate = False + + +def _reset_library_root_logger() -> None: + + global _default_handler + + with _lock: + if not _default_handler: + return + + library_root_logger = _get_library_root_logger() + library_root_logger.removeHandler(_default_handler) + library_root_logger.setLevel(logging.NOTSET) + _default_handler = None + + +def get_log_levels_dict(): + return log_levels + + +def get_logger(name: Optional[str] = None) -> logging.Logger: + """ + Return a logger with the specified name. + + This function is not supposed to be directly accessed unless you are writing a custom diffusers module. + """ + + if name is None: + name = _get_library_name() + + _configure_library_root_logger() + return logging.getLogger(name) + + +def get_verbosity() -> int: + """ + Return the current level for the 🤗 Diffusers' root logger as an int. + + Returns: + `int`: The logging level. + + + + 🤗 Diffusers has following logging levels: + + - 50: `diffusers.logging.CRITICAL` or `diffusers.logging.FATAL` + - 40: `diffusers.logging.ERROR` + - 30: `diffusers.logging.WARNING` or `diffusers.logging.WARN` + - 20: `diffusers.logging.INFO` + - 10: `diffusers.logging.DEBUG` + + """ + + _configure_library_root_logger() + return _get_library_root_logger().getEffectiveLevel() + + +def set_verbosity(verbosity: int) -> None: + """ + Set the verbosity level for the 🤗 Diffusers' root logger. 
+ + Args: + verbosity (`int`): + Logging level, e.g., one of: + + - `diffusers.logging.CRITICAL` or `diffusers.logging.FATAL` + - `diffusers.logging.ERROR` + - `diffusers.logging.WARNING` or `diffusers.logging.WARN` + - `diffusers.logging.INFO` + - `diffusers.logging.DEBUG` + """ + + _configure_library_root_logger() + _get_library_root_logger().setLevel(verbosity) + + +def set_verbosity_info(): + """Set the verbosity to the `INFO` level.""" + return set_verbosity(INFO) + + +def set_verbosity_warning(): + """Set the verbosity to the `WARNING` level.""" + return set_verbosity(WARNING) + + +def set_verbosity_debug(): + """Set the verbosity to the `DEBUG` level.""" + return set_verbosity(DEBUG) + + +def set_verbosity_error(): + """Set the verbosity to the `ERROR` level.""" + return set_verbosity(ERROR) + + +def disable_default_handler() -> None: + """Disable the default handler of the HuggingFace Diffusers' root logger.""" + + _configure_library_root_logger() + + assert _default_handler is not None + _get_library_root_logger().removeHandler(_default_handler) + + +def enable_default_handler() -> None: + """Enable the default handler of the HuggingFace Diffusers' root logger.""" + + _configure_library_root_logger() + + assert _default_handler is not None + _get_library_root_logger().addHandler(_default_handler) + + +def add_handler(handler: logging.Handler) -> None: + """adds a handler to the HuggingFace Diffusers' root logger.""" + + _configure_library_root_logger() + + assert handler is not None + _get_library_root_logger().addHandler(handler) + + +def remove_handler(handler: logging.Handler) -> None: + """removes given handler from the HuggingFace Diffusers' root logger.""" + + _configure_library_root_logger() + + assert handler is not None and handler not in _get_library_root_logger().handlers + _get_library_root_logger().removeHandler(handler) + + +def disable_propagation() -> None: + """ + Disable propagation of the library log outputs. Note that log propagation is disabled by default. + """ + + _configure_library_root_logger() + _get_library_root_logger().propagate = False + + +def enable_propagation() -> None: + """ + Enable propagation of the library log outputs. Please disable the HuggingFace Diffusers' default handler to prevent + double logging if the root logger has been configured. + """ + + _configure_library_root_logger() + _get_library_root_logger().propagate = True + + +def enable_explicit_format() -> None: + """ + Enable explicit formatting for every HuggingFace Diffusers' logger. The explicit formatter is as follows: + ``` + [LEVELNAME|FILENAME|LINE NUMBER] TIME >> MESSAGE + ``` + All handlers currently bound to the root logger are affected by this method. + """ + handlers = _get_library_root_logger().handlers + + for handler in handlers: + formatter = logging.Formatter("[%(levelname)s|%(filename)s:%(lineno)s] %(asctime)s >> %(message)s") + handler.setFormatter(formatter) + + +def reset_format() -> None: + """ + Resets the formatting for HuggingFace Diffusers' loggers. + + All handlers currently bound to the root logger are affected by this method. 
+ """ + handlers = _get_library_root_logger().handlers + + for handler in handlers: + handler.setFormatter(None) + + +def warning_advice(self, *args, **kwargs): + """ + This method is identical to `logger.warninging()`, but if env var DIFFUSERS_NO_ADVISORY_WARNINGS=1 is set, this + warning will not be printed + """ + no_advisory_warnings = os.getenv("DIFFUSERS_NO_ADVISORY_WARNINGS", False) + if no_advisory_warnings: + return + self.warning(*args, **kwargs) + + +logging.Logger.warning_advice = warning_advice + + +class EmptyTqdm: + """Dummy tqdm which doesn't do anything.""" + + def __init__(self, *args, **kwargs): # pylint: disable=unused-argument + self._iterator = args[0] if args else None + + def __iter__(self): + return iter(self._iterator) + + def __getattr__(self, _): + """Return empty function.""" + + def empty_fn(*args, **kwargs): # pylint: disable=unused-argument + return + + return empty_fn + + def __enter__(self): + return self + + def __exit__(self, type_, value, traceback): + return + + +class _tqdm_cls: + def __call__(self, *args, **kwargs): + if _tqdm_active: + return tqdm_lib.tqdm(*args, **kwargs) + else: + return EmptyTqdm(*args, **kwargs) + + def set_lock(self, *args, **kwargs): + self._lock = None + if _tqdm_active: + return tqdm_lib.tqdm.set_lock(*args, **kwargs) + + def get_lock(self): + if _tqdm_active: + return tqdm_lib.tqdm.get_lock() + + +tqdm = _tqdm_cls() + + +def is_progress_bar_enabled() -> bool: + """Return a boolean indicating whether tqdm progress bars are enabled.""" + global _tqdm_active + return bool(_tqdm_active) + + +def enable_progress_bar(): + """Enable tqdm progress bar.""" + global _tqdm_active + _tqdm_active = True + + +def disable_progress_bar(): + """Disable tqdm progress bar.""" + global _tqdm_active + _tqdm_active = False diff --git a/local_diffusers/utils/outputs.py b/local_diffusers/utils/outputs.py new file mode 100644 index 000000000..b02f62d02 --- /dev/null +++ b/local_diffusers/utils/outputs.py @@ -0,0 +1,109 @@ +# Copyright 2022 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +Generic utilities +""" + +import warnings +from collections import OrderedDict +from dataclasses import fields +from typing import Any, Tuple + +import numpy as np + +from .import_utils import is_torch_available + + +def is_tensor(x): + """ + Tests if `x` is a `torch.Tensor` or `np.ndarray`. + """ + if is_torch_available(): + import torch + + if isinstance(x, torch.Tensor): + return True + + return isinstance(x, np.ndarray) + + +class BaseOutput(OrderedDict): + """ + Base class for all model outputs as dataclass. Has a `__getitem__` that allows indexing by integer or slice (like a + tuple) or strings (like a dictionary) that will ignore the `None` attributes. Otherwise behaves like a regular + python dictionary. + + + + You can't unpack a `BaseOutput` directly. Use the [`~utils.BaseOutput.to_tuple`] method to convert it to a tuple + before. 
+ + + """ + + def __post_init__(self): + class_fields = fields(self) + + # Safety and consistency checks + if not len(class_fields): + raise ValueError(f"{self.__class__.__name__} has no fields.") + + for field in class_fields: + v = getattr(self, field.name) + if v is not None: + self[field.name] = v + + def __delitem__(self, *args, **kwargs): + raise Exception(f"You cannot use ``__delitem__`` on a {self.__class__.__name__} instance.") + + def setdefault(self, *args, **kwargs): + raise Exception(f"You cannot use ``setdefault`` on a {self.__class__.__name__} instance.") + + def pop(self, *args, **kwargs): + raise Exception(f"You cannot use ``pop`` on a {self.__class__.__name__} instance.") + + def update(self, *args, **kwargs): + raise Exception(f"You cannot use ``update`` on a {self.__class__.__name__} instance.") + + def __getitem__(self, k): + if isinstance(k, str): + inner_dict = {k: v for (k, v) in self.items()} + if self.__class__.__name__ in ["StableDiffusionPipelineOutput", "ImagePipelineOutput"] and k == "sample": + warnings.warn( + "The keyword 'samples' is deprecated and will be removed in version 0.4.0. Please use `.images` or" + " `'images'` instead.", + DeprecationWarning, + ) + return inner_dict["images"] + return inner_dict[k] + else: + return self.to_tuple()[k] + + def __setattr__(self, name, value): + if name in self.keys() and value is not None: + # Don't call self.__setitem__ to avoid recursion errors + super().__setitem__(name, value) + super().__setattr__(name, value) + + def __setitem__(self, key, value): + # Will raise a KeyException if needed + super().__setitem__(key, value) + # Don't call self.__setattr__ to avoid recursion errors + super().__setattr__(key, value) + + def to_tuple(self) -> Tuple[Any]: + """ + Convert self to a tuple containing all the attributes/keys that are not `None`. + """ + return tuple(self[k] for k in self.keys()) From bfebf01b6b66a91d219e34c9cb93df5e01022e38 Mon Sep 17 00:00:00 2001 From: Olivier Teytaud Date: Sun, 2 Oct 2022 08:57:31 +0200 Subject: [PATCH 58/76] fix --- README.md | 27 --------------------------- 1 file changed, 27 deletions(-) diff --git a/README.md b/README.md index 7525218ff..343fed800 100644 --- a/README.md +++ b/README.md @@ -37,33 +37,6 @@ pip install git+https://github.com/sberbank-ai/Real-ESRGAN.git wget https://github.com/xinntao/Real-ESRGAN/releases/download/v0.1.0/RealESRGAN_x4plus.pth -P weights ``` -## Hack diffusers (yes I should do that differently... only solution for now). -Copy the file "pipeline_stable_diffusion.py" in lieu of the original pipeline_stable_diffusion.py. - -How to do this ? - First, find where ``diffusers'' is: -``` - python -c "import diffusers ; print(diffusers.__file__)" -``` -and pipeline_stable_diffusion should be copied at this location + "/pipelines/stable_diffusion/pipeline_stable_diffusion.py" (overwrite that file). - -Or inside python -``` -import diffusers -print(diffusers.__file__) -``` - - Then copy the local file there as follows: -``` -cp pipeline_stable_diffusion.py <>/pipeline_stable_diffusion.py -``` -You can also do a symbolic link: -``` -pushd <> -mv pipeline_stable_diffusion.py backup_pipeline_stable_diffusion.py -ln -s <>/pipeline_stable_diffusion.py . -``` - ## Then run << python minisd.py >>. You should be asked for a prompt (just <> if you like the proposed hardcoded prompt), and then a window should be opened. 
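
Stepping back to the `BaseOutput` container added in `local_diffusers/utils/outputs.py` above, here is a quick illustration of how it behaves. This is a sketch, not part of any patch: `ImageOutput` is a made-up subclass (real pipelines define e.g. `StableDiffusionPipelineOutput`), and the vendored `local_diffusers` package is assumed to be on the Python path.

```
from dataclasses import dataclass
from typing import List, Optional

import numpy as np

from local_diffusers.utils.outputs import BaseOutput


@dataclass
class ImageOutput(BaseOutput):
    images: List[np.ndarray]
    nsfw_content_detected: Optional[List[bool]] = None


out = ImageOutput(images=[np.zeros((8, 8, 3))])
assert out["images"] is out.images   # dict-style and attribute access return the same object
assert out[0] is out.images          # integer indexing goes through to_tuple()
assert len(out.to_tuple()) == 1      # fields left as None are dropped
```
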
From a30527f8428ec21170f3fc883f79607b1bd689f6 Mon Sep 17 00:00:00 2001
From: Olivier Teytaud
Date: Sun, 2 Oct 2022 11:47:32 +0200
Subject: [PATCH 59/76] fix

---
 README.md | 3 +++
 1 file changed, 3 insertions(+)

diff --git a/README.md b/README.md
index 343fed800..b89cfc055 100644
--- a/README.md
+++ b/README.md
@@ -13,7 +13,10 @@
 Ping us at the Nevergrad user group if you need help, I'll do my best.
+## Get a HuggingFace token!
 This is a fork of HuggingFace's stablediffusion.
+Just click here and copy-paste your token:
+[**Hugging Face tokens**](https://huggingface.co/login?next=%2Fsettings%2Ftokens)
## Install StableDiffusion as usual, plus a few more stuff. Basically: From 47013929d8da5854c9b9e3d74bbf6d7a2cfd32b5 Mon Sep 17 00:00:00 2001 From: Olivier Teytaud Date: Sun, 2 Oct 2022 16:54:33 +0200 Subject: [PATCH 60/76] fix --- diffusers/__init__.py | 60 + diffusers/commands/__init__.py | 27 + diffusers/commands/diffusers_cli.py | 41 + diffusers/commands/env.py | 70 + diffusers/configuration_utils.py | 403 +++++ diffusers/dependency_versions_check.py | 47 + diffusers/dependency_versions_table.py | 26 + diffusers/dynamic_modules_utils.py | 335 ++++ diffusers/hub_utils.py | 197 +++ diffusers/modeling_utils.py | 542 ++++++ diffusers/models/__init__.py | 17 + diffusers/models/attention.py | 333 ++++ diffusers/models/embeddings.py | 115 ++ diffusers/models/resnet.py | 483 ++++++ diffusers/models/unet_2d.py | 246 +++ diffusers/models/unet_2d_condition.py | 270 +++ diffusers/models/unet_blocks.py | 1481 +++++++++++++++++ diffusers/models/vae.py | 581 +++++++ diffusers/onnx_utils.py | 189 +++ diffusers/optimization.py | 275 +++ diffusers/pipeline_utils.py | 417 +++++ diffusers/pipelines/__init__.py | 19 + diffusers/pipelines/ddim/__init__.py | 2 + diffusers/pipelines/ddim/pipeline_ddim.py | 117 ++ diffusers/pipelines/ddpm/__init__.py | 2 + diffusers/pipelines/ddpm/pipeline_ddpm.py | 106 ++ .../pipelines/latent_diffusion/__init__.py | 6 + .../pipeline_latent_diffusion.py | 705 ++++++++ .../latent_diffusion_uncond/__init__.py | 2 + .../pipeline_latent_diffusion_uncond.py | 108 ++ diffusers/pipelines/pndm/__init__.py | 2 + diffusers/pipelines/pndm/pipeline_pndm.py | 111 ++ diffusers/pipelines/score_sde_ve/__init__.py | 2 + .../score_sde_ve/pipeline_score_sde_ve.py | 101 ++ .../pipelines/stable_diffusion/__init__.py | 37 + .../pipeline_stable_diffusion.py | 398 +++++ .../pipeline_stable_diffusion_img2img.py | 291 ++++ .../pipeline_stable_diffusion_inpaint.py | 309 ++++ .../pipeline_stable_diffusion_onnx.py | 165 ++ .../stable_diffusion/safety_checker.py | 106 ++ .../stochastic_karras_ve/__init__.py | 2 + .../pipeline_stochastic_karras_ve.py | 129 ++ diffusers/schedulers/__init__.py | 28 + diffusers/schedulers/scheduling_ddim.py | 261 +++ diffusers/schedulers/scheduling_ddpm.py | 264 +++ diffusers/schedulers/scheduling_karras_ve.py | 208 +++ .../schedulers/scheduling_lms_discrete.py | 193 +++ diffusers/schedulers/scheduling_pndm.py | 378 +++++ diffusers/schedulers/scheduling_sde_ve.py | 283 ++++ diffusers/schedulers/scheduling_sde_vp.py | 81 + diffusers/schedulers/scheduling_utils.py | 125 ++ diffusers/testing_utils.py | 61 + diffusers/training_utils.py | 125 ++ diffusers/utils/__init__.py | 53 + diffusers/utils/dummy_scipy_objects.py | 11 + ...rmers_and_inflect_and_unidecode_objects.py | 10 + .../dummy_transformers_and_onnx_objects.py | 11 + diffusers/utils/dummy_transformers_objects.py | 32 + diffusers/utils/import_utils.py | 274 +++ diffusers/utils/logging.py | 344 ++++ diffusers/utils/outputs.py | 109 ++ 61 files changed, 11726 insertions(+) create mode 100644 diffusers/__init__.py create mode 100644 diffusers/commands/__init__.py create mode 100644 diffusers/commands/diffusers_cli.py create mode 100644 diffusers/commands/env.py create mode 100644 diffusers/configuration_utils.py create mode 100644 diffusers/dependency_versions_check.py create mode 100644 diffusers/dependency_versions_table.py create mode 100644 diffusers/dynamic_modules_utils.py create mode 100644 diffusers/hub_utils.py create mode 100644 diffusers/modeling_utils.py create mode 100644 diffusers/models/__init__.py create mode 
100644 diffusers/models/attention.py create mode 100644 diffusers/models/embeddings.py create mode 100644 diffusers/models/resnet.py create mode 100644 diffusers/models/unet_2d.py create mode 100644 diffusers/models/unet_2d_condition.py create mode 100644 diffusers/models/unet_blocks.py create mode 100644 diffusers/models/vae.py create mode 100644 diffusers/onnx_utils.py create mode 100644 diffusers/optimization.py create mode 100644 diffusers/pipeline_utils.py create mode 100644 diffusers/pipelines/__init__.py create mode 100644 diffusers/pipelines/ddim/__init__.py create mode 100644 diffusers/pipelines/ddim/pipeline_ddim.py create mode 100644 diffusers/pipelines/ddpm/__init__.py create mode 100644 diffusers/pipelines/ddpm/pipeline_ddpm.py create mode 100644 diffusers/pipelines/latent_diffusion/__init__.py create mode 100644 diffusers/pipelines/latent_diffusion/pipeline_latent_diffusion.py create mode 100644 diffusers/pipelines/latent_diffusion_uncond/__init__.py create mode 100644 diffusers/pipelines/latent_diffusion_uncond/pipeline_latent_diffusion_uncond.py create mode 100644 diffusers/pipelines/pndm/__init__.py create mode 100644 diffusers/pipelines/pndm/pipeline_pndm.py create mode 100644 diffusers/pipelines/score_sde_ve/__init__.py create mode 100644 diffusers/pipelines/score_sde_ve/pipeline_score_sde_ve.py create mode 100644 diffusers/pipelines/stable_diffusion/__init__.py create mode 100644 diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py create mode 100644 diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py create mode 100644 diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py create mode 100644 diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_onnx.py create mode 100644 diffusers/pipelines/stable_diffusion/safety_checker.py create mode 100644 diffusers/pipelines/stochastic_karras_ve/__init__.py create mode 100644 diffusers/pipelines/stochastic_karras_ve/pipeline_stochastic_karras_ve.py create mode 100644 diffusers/schedulers/__init__.py create mode 100644 diffusers/schedulers/scheduling_ddim.py create mode 100644 diffusers/schedulers/scheduling_ddpm.py create mode 100644 diffusers/schedulers/scheduling_karras_ve.py create mode 100644 diffusers/schedulers/scheduling_lms_discrete.py create mode 100644 diffusers/schedulers/scheduling_pndm.py create mode 100644 diffusers/schedulers/scheduling_sde_ve.py create mode 100644 diffusers/schedulers/scheduling_sde_vp.py create mode 100644 diffusers/schedulers/scheduling_utils.py create mode 100644 diffusers/testing_utils.py create mode 100644 diffusers/training_utils.py create mode 100644 diffusers/utils/__init__.py create mode 100644 diffusers/utils/dummy_scipy_objects.py create mode 100644 diffusers/utils/dummy_transformers_and_inflect_and_unidecode_objects.py create mode 100644 diffusers/utils/dummy_transformers_and_onnx_objects.py create mode 100644 diffusers/utils/dummy_transformers_objects.py create mode 100644 diffusers/utils/import_utils.py create mode 100644 diffusers/utils/logging.py create mode 100644 diffusers/utils/outputs.py diff --git a/diffusers/__init__.py b/diffusers/__init__.py new file mode 100644 index 000000000..bf2f183c9 --- /dev/null +++ b/diffusers/__init__.py @@ -0,0 +1,60 @@ +from .utils import ( + is_inflect_available, + is_onnx_available, + is_scipy_available, + is_transformers_available, + is_unidecode_available, +) + + +__version__ = "0.3.0" + +from .configuration_utils import ConfigMixin +from .modeling_utils import ModelMixin +from 
.models import AutoencoderKL, UNet2DConditionModel, UNet2DModel, VQModel +from .onnx_utils import OnnxRuntimeModel +from .optimization import ( + get_constant_schedule, + get_constant_schedule_with_warmup, + get_cosine_schedule_with_warmup, + get_cosine_with_hard_restarts_schedule_with_warmup, + get_linear_schedule_with_warmup, + get_polynomial_decay_schedule_with_warmup, + get_scheduler, +) +from .pipeline_utils import DiffusionPipeline +from .pipelines import DDIMPipeline, DDPMPipeline, KarrasVePipeline, LDMPipeline, PNDMPipeline, ScoreSdeVePipeline +from .schedulers import ( + DDIMScheduler, + DDPMScheduler, + KarrasVeScheduler, + PNDMScheduler, + SchedulerMixin, + ScoreSdeVeScheduler, +) +from .utils import logging + + +if is_scipy_available(): + from .schedulers import LMSDiscreteScheduler +else: + from .utils.dummy_scipy_objects import * # noqa F403 + +from .training_utils import EMAModel + + +if is_transformers_available(): + from .pipelines import ( + LDMTextToImagePipeline, + StableDiffusionImg2ImgPipeline, + StableDiffusionInpaintPipeline, + StableDiffusionPipeline, + ) +else: + from .utils.dummy_transformers_objects import * # noqa F403 + + +if is_transformers_available() and is_onnx_available(): + from .pipelines import StableDiffusionOnnxPipeline +else: + from .utils.dummy_transformers_and_onnx_objects import * # noqa F403 diff --git a/diffusers/commands/__init__.py b/diffusers/commands/__init__.py new file mode 100644 index 000000000..902bd46ce --- /dev/null +++ b/diffusers/commands/__init__.py @@ -0,0 +1,27 @@ +# Copyright 2022 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from abc import ABC, abstractmethod +from argparse import ArgumentParser + + +class BaseDiffusersCLICommand(ABC): + @staticmethod + @abstractmethod + def register_subcommand(parser: ArgumentParser): + raise NotImplementedError() + + @abstractmethod + def run(self): + raise NotImplementedError() diff --git a/diffusers/commands/diffusers_cli.py b/diffusers/commands/diffusers_cli.py new file mode 100644 index 000000000..30084e55b --- /dev/null +++ b/diffusers/commands/diffusers_cli.py @@ -0,0 +1,41 @@ +#!/usr/bin/env python +# Copyright 2022 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
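
As an aside, a short sketch of how the abstract `BaseDiffusersCLICommand` above is meant to be extended. This is not part of the patch: the `hello` subcommand is made up for illustration (the real example is `EnvironmentCommand` just below), and the vendored `diffusers/` directory added by this patch is assumed to be importable.

```
from argparse import ArgumentParser

from diffusers.commands import BaseDiffusersCLICommand


def hello_command_factory(args):
    return HelloCommand(args.name)


class HelloCommand(BaseDiffusersCLICommand):
    def __init__(self, name: str):
        self.name = name

    @staticmethod
    def register_subcommand(parser: ArgumentParser):
        # Attach a "hello" sub-parser and point it at the factory above.
        hello_parser = parser.add_parser("hello")
        hello_parser.add_argument("--name", type=str, default="world")
        hello_parser.set_defaults(func=hello_command_factory)

    def run(self):
        print(f"Hello, {self.name}!")
```
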
+ +from argparse import ArgumentParser + +from .env import EnvironmentCommand + + +def main(): + parser = ArgumentParser("Diffusers CLI tool", usage="diffusers-cli []") + commands_parser = parser.add_subparsers(help="diffusers-cli command helpers") + + # Register commands + EnvironmentCommand.register_subcommand(commands_parser) + + # Let's go + args = parser.parse_args() + + if not hasattr(args, "func"): + parser.print_help() + exit(1) + + # Run + service = args.func(args) + service.run() + + +if __name__ == "__main__": + main() diff --git a/diffusers/commands/env.py b/diffusers/commands/env.py new file mode 100644 index 000000000..81a878bff --- /dev/null +++ b/diffusers/commands/env.py @@ -0,0 +1,70 @@ +# Copyright 2022 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import platform +from argparse import ArgumentParser + +import huggingface_hub + +from .. import __version__ as version +from ..utils import is_torch_available, is_transformers_available +from . import BaseDiffusersCLICommand + + +def info_command_factory(_): + return EnvironmentCommand() + + +class EnvironmentCommand(BaseDiffusersCLICommand): + @staticmethod + def register_subcommand(parser: ArgumentParser): + download_parser = parser.add_parser("env") + download_parser.set_defaults(func=info_command_factory) + + def run(self): + hub_version = huggingface_hub.__version__ + + pt_version = "not installed" + pt_cuda_available = "NA" + if is_torch_available(): + import torch + + pt_version = torch.__version__ + pt_cuda_available = torch.cuda.is_available() + + transformers_version = "not installed" + if is_transformers_available: + import transformers + + transformers_version = transformers.__version__ + + info = { + "`diffusers` version": version, + "Platform": platform.platform(), + "Python version": platform.python_version(), + "PyTorch version (GPU?)": f"{pt_version} ({pt_cuda_available})", + "Huggingface_hub version": hub_version, + "Transformers version": transformers_version, + "Using GPU in script?": "", + "Using distributed or parallel set-up in script?": "", + } + + print("\nCopy-and-paste the text below in your GitHub issue and FILL OUT the two last points.\n") + print(self.format_dict(info)) + + return info + + @staticmethod + def format_dict(d): + return "\n".join([f"- {prop}: {val}" for prop, val in d.items()]) + "\n" diff --git a/diffusers/configuration_utils.py b/diffusers/configuration_utils.py new file mode 100644 index 000000000..fbe75f3f1 --- /dev/null +++ b/diffusers/configuration_utils.py @@ -0,0 +1,403 @@ +# coding=utf-8 +# Copyright 2022 The HuggingFace Inc. team. +# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" ConfigMixinuration base class and utilities.""" +import functools +import inspect +import json +import os +import re +from collections import OrderedDict +from typing import Any, Dict, Tuple, Union + +from huggingface_hub import hf_hub_download +from huggingface_hub.utils import EntryNotFoundError, RepositoryNotFoundError, RevisionNotFoundError +from requests import HTTPError + +from . import __version__ +from .utils import DIFFUSERS_CACHE, HUGGINGFACE_CO_RESOLVE_ENDPOINT, logging + + +logger = logging.get_logger(__name__) + +_re_configuration_file = re.compile(r"config\.(.*)\.json") + + +class ConfigMixin: + r""" + Base class for all configuration classes. Stores all configuration parameters under `self.config` Also handles all + methods for loading/downloading/saving classes inheriting from [`ConfigMixin`] with + - [`~ConfigMixin.from_config`] + - [`~ConfigMixin.save_config`] + + Class attributes: + - **config_name** (`str`) -- A filename under which the config should stored when calling + [`~ConfigMixin.save_config`] (should be overriden by parent class). + - **ignore_for_config** (`List[str]`) -- A list of attributes that should not be saved in the config (should be + overriden by parent class). + """ + config_name = None + ignore_for_config = [] + + def register_to_config(self, **kwargs): + if self.config_name is None: + raise NotImplementedError(f"Make sure that {self.__class__} has defined a class name `config_name`") + kwargs["_class_name"] = self.__class__.__name__ + kwargs["_diffusers_version"] = __version__ + + for key, value in kwargs.items(): + try: + setattr(self, key, value) + except AttributeError as err: + logger.error(f"Can't set {key} with value {value} for {self}") + raise err + + if not hasattr(self, "_internal_dict"): + internal_dict = kwargs + else: + previous_dict = dict(self._internal_dict) + internal_dict = {**self._internal_dict, **kwargs} + logger.debug(f"Updating config from {previous_dict} to {internal_dict}") + + self._internal_dict = FrozenDict(internal_dict) + + def save_config(self, save_directory: Union[str, os.PathLike], push_to_hub: bool = False, **kwargs): + """ + Save a configuration object to the directory `save_directory`, so that it can be re-loaded using the + [`~ConfigMixin.from_config`] class method. + + Args: + save_directory (`str` or `os.PathLike`): + Directory where the configuration JSON file will be saved (will be created if it does not exist). + """ + if os.path.isfile(save_directory): + raise AssertionError(f"Provided path ({save_directory}) should be a directory, not a file") + + os.makedirs(save_directory, exist_ok=True) + + # If we save using the predefined names, we can load using `from_config` + output_config_file = os.path.join(save_directory, self.config_name) + + self.to_json_file(output_config_file) + logger.info(f"ConfigMixinuration saved in {output_config_file}") + + @classmethod + def from_config(cls, pretrained_model_name_or_path: Union[str, os.PathLike], return_unused_kwargs=False, **kwargs): + r""" + Instantiate a Python class from a pre-defined JSON-file. 
+ + Parameters: + pretrained_model_name_or_path (`str` or `os.PathLike`, *optional*): + Can be either: + + - A string, the *model id* of a model repo on huggingface.co. Valid model ids should have an + organization name, like `google/ddpm-celebahq-256`. + - A path to a *directory* containing model weights saved using [`~ConfigMixin.save_config`], e.g., + `./my_model_directory/`. + + cache_dir (`Union[str, os.PathLike]`, *optional*): + Path to a directory in which a downloaded pretrained model configuration should be cached if the + standard cache should not be used. + ignore_mismatched_sizes (`bool`, *optional*, defaults to `False`): + Whether or not to raise an error if some of the weights from the checkpoint do not have the same size + as the weights of the model (if for instance, you are instantiating a model with 10 labels from a + checkpoint with 3 labels). + force_download (`bool`, *optional*, defaults to `False`): + Whether or not to force the (re-)download of the model weights and configuration files, overriding the + cached versions if they exist. + resume_download (`bool`, *optional*, defaults to `False`): + Whether or not to delete incompletely received files. Will attempt to resume the download if such a + file exists. + proxies (`Dict[str, str]`, *optional*): + A dictionary of proxy servers to use by protocol or endpoint, e.g., `{'http': 'foo.bar:3128', + 'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request. + output_loading_info(`bool`, *optional*, defaults to `False`): + Whether ot not to also return a dictionary containing missing keys, unexpected keys and error messages. + local_files_only(`bool`, *optional*, defaults to `False`): + Whether or not to only look at local files (i.e., do not try to download the model). + use_auth_token (`str` or *bool*, *optional*): + The token to use as HTTP bearer authorization for remote files. If `True`, will use the token generated + when running `transformers-cli login` (stored in `~/.huggingface`). + revision (`str`, *optional*, defaults to `"main"`): + The specific model version to use. It can be a branch name, a tag name, or a commit id, since we use a + git-based system for storing models and other artifacts on huggingface.co, so `revision` can be any + identifier allowed by git. + mirror (`str`, *optional*): + Mirror source to accelerate downloads in China. If you are from China and have an accessibility + problem, you can set this option to resolve it. Note that we do not guarantee the timeliness or safety. + Please refer to the mirror site for more information. + + + + Passing `use_auth_token=True`` is required when you want to use a private model. + + + + + + Activate the special ["offline-mode"](https://huggingface.co/transformers/installation.html#offline-mode) to + use this method in a firewalled environment. 
+ + + + """ + config_dict = cls.get_config_dict(pretrained_model_name_or_path=pretrained_model_name_or_path, **kwargs) + + init_dict, unused_kwargs = cls.extract_init_dict(config_dict, **kwargs) + + model = cls(**init_dict) + + if return_unused_kwargs: + return model, unused_kwargs + else: + return model + + @classmethod + def get_config_dict( + cls, pretrained_model_name_or_path: Union[str, os.PathLike], **kwargs + ) -> Tuple[Dict[str, Any], Dict[str, Any]]: + cache_dir = kwargs.pop("cache_dir", DIFFUSERS_CACHE) + force_download = kwargs.pop("force_download", False) + resume_download = kwargs.pop("resume_download", False) + proxies = kwargs.pop("proxies", None) + use_auth_token = kwargs.pop("use_auth_token", None) + local_files_only = kwargs.pop("local_files_only", False) + revision = kwargs.pop("revision", None) + subfolder = kwargs.pop("subfolder", None) + + user_agent = {"file_type": "config"} + + pretrained_model_name_or_path = str(pretrained_model_name_or_path) + + if cls.config_name is None: + raise ValueError( + "`self.config_name` is not defined. Note that one should not load a config from " + "`ConfigMixin`. Please make sure to define `config_name` in a class inheriting from `ConfigMixin`" + ) + + if os.path.isfile(pretrained_model_name_or_path): + config_file = pretrained_model_name_or_path + elif os.path.isdir(pretrained_model_name_or_path): + if os.path.isfile(os.path.join(pretrained_model_name_or_path, cls.config_name)): + # Load from a PyTorch checkpoint + config_file = os.path.join(pretrained_model_name_or_path, cls.config_name) + elif subfolder is not None and os.path.isfile( + os.path.join(pretrained_model_name_or_path, subfolder, cls.config_name) + ): + config_file = os.path.join(pretrained_model_name_or_path, subfolder, cls.config_name) + else: + raise EnvironmentError( + f"Error no file named {cls.config_name} found in directory {pretrained_model_name_or_path}." + ) + else: + try: + # Load from URL or cache if already cached + config_file = hf_hub_download( + pretrained_model_name_or_path, + filename=cls.config_name, + cache_dir=cache_dir, + force_download=force_download, + proxies=proxies, + resume_download=resume_download, + local_files_only=local_files_only, + use_auth_token=use_auth_token, + user_agent=user_agent, + subfolder=subfolder, + revision=revision, + ) + + except RepositoryNotFoundError: + raise EnvironmentError( + f"{pretrained_model_name_or_path} is not a local folder and is not a valid model identifier" + " listed on 'https://huggingface.co/models'\nIf this is a private repository, make sure to pass a" + " token having permission to this repo with `use_auth_token` or log in with `huggingface-cli" + " login` and pass `use_auth_token=True`." + ) + except RevisionNotFoundError: + raise EnvironmentError( + f"{revision} is not a valid git identifier (branch name, tag name or commit id) that exists for" + " this model name. Check the model page at" + f" 'https://huggingface.co/{pretrained_model_name_or_path}' for available revisions." + ) + except EntryNotFoundError: + raise EnvironmentError( + f"{pretrained_model_name_or_path} does not appear to have a file named {cls.config_name}." 
+ ) + except HTTPError as err: + raise EnvironmentError( + "There was a specific connection error when trying to load" + f" {pretrained_model_name_or_path}:\n{err}" + ) + except ValueError: + raise EnvironmentError( + f"We couldn't connect to '{HUGGINGFACE_CO_RESOLVE_ENDPOINT}' to load this model, couldn't find it" + f" in the cached files and it looks like {pretrained_model_name_or_path} is not the path to a" + f" directory containing a {cls.config_name} file.\nCheckout your internet connection or see how to" + " run the library in offline mode at" + " 'https://huggingface.co/docs/diffusers/installation#offline-mode'." + ) + except EnvironmentError: + raise EnvironmentError( + f"Can't load config for '{pretrained_model_name_or_path}'. If you were trying to load it from " + "'https://huggingface.co/models', make sure you don't have a local directory with the same name. " + f"Otherwise, make sure '{pretrained_model_name_or_path}' is the correct path to a directory " + f"containing a {cls.config_name} file" + ) + + try: + # Load config dict + config_dict = cls._dict_from_json_file(config_file) + except (json.JSONDecodeError, UnicodeDecodeError): + raise EnvironmentError(f"It looks like the config file at '{config_file}' is not a valid JSON file.") + + return config_dict + + @classmethod + def extract_init_dict(cls, config_dict, **kwargs): + expected_keys = set(dict(inspect.signature(cls.__init__).parameters).keys()) + expected_keys.remove("self") + # remove general kwargs if present in dict + if "kwargs" in expected_keys: + expected_keys.remove("kwargs") + # remove keys to be ignored + if len(cls.ignore_for_config) > 0: + expected_keys = expected_keys - set(cls.ignore_for_config) + init_dict = {} + for key in expected_keys: + if key in kwargs: + # overwrite key + init_dict[key] = kwargs.pop(key) + elif key in config_dict: + # use value from config dict + init_dict[key] = config_dict.pop(key) + + unused_kwargs = config_dict.update(kwargs) + + passed_keys = set(init_dict.keys()) + if len(expected_keys - passed_keys) > 0: + logger.warning( + f"{expected_keys - passed_keys} was not found in config. Values will be initialized to default values." + ) + + return init_dict, unused_kwargs + + @classmethod + def _dict_from_json_file(cls, json_file: Union[str, os.PathLike]): + with open(json_file, "r", encoding="utf-8") as reader: + text = reader.read() + return json.loads(text) + + def __repr__(self): + return f"{self.__class__.__name__} {self.to_json_string()}" + + @property + def config(self) -> Dict[str, Any]: + return self._internal_dict + + def to_json_string(self) -> str: + """ + Serializes this instance to a JSON string. + + Returns: + `str`: String containing all the attributes that make up this configuration instance in JSON format. + """ + config_dict = self._internal_dict if hasattr(self, "_internal_dict") else {} + return json.dumps(config_dict, indent=2, sort_keys=True) + "\n" + + def to_json_file(self, json_file_path: Union[str, os.PathLike]): + """ + Save this instance to a JSON file. + + Args: + json_file_path (`str` or `os.PathLike`): + Path to the JSON file in which this configuration instance's parameters will be saved. 
+ """ + with open(json_file_path, "w", encoding="utf-8") as writer: + writer.write(self.to_json_string()) + + +class FrozenDict(OrderedDict): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + for key, value in self.items(): + setattr(self, key, value) + + self.__frozen = True + + def __delitem__(self, *args, **kwargs): + raise Exception(f"You cannot use ``__delitem__`` on a {self.__class__.__name__} instance.") + + def setdefault(self, *args, **kwargs): + raise Exception(f"You cannot use ``setdefault`` on a {self.__class__.__name__} instance.") + + def pop(self, *args, **kwargs): + raise Exception(f"You cannot use ``pop`` on a {self.__class__.__name__} instance.") + + def update(self, *args, **kwargs): + raise Exception(f"You cannot use ``update`` on a {self.__class__.__name__} instance.") + + def __setattr__(self, name, value): + if hasattr(self, "__frozen") and self.__frozen: + raise Exception(f"You cannot use ``__setattr__`` on a {self.__class__.__name__} instance.") + super().__setattr__(name, value) + + def __setitem__(self, name, value): + if hasattr(self, "__frozen") and self.__frozen: + raise Exception(f"You cannot use ``__setattr__`` on a {self.__class__.__name__} instance.") + super().__setitem__(name, value) + + +def register_to_config(init): + r""" + Decorator to apply on the init of classes inheriting from [`ConfigMixin`] so that all the arguments are + automatically sent to `self.register_for_config`. To ignore a specific argument accepted by the init but that + shouldn't be registered in the config, use the `ignore_for_config` class variable + + Warning: Once decorated, all private arguments (beginning with an underscore) are trashed and not sent to the init! + """ + + @functools.wraps(init) + def inner_init(self, *args, **kwargs): + # Ignore private kwargs in the init. + init_kwargs = {k: v for k, v in kwargs.items() if not k.startswith("_")} + init(self, *args, **init_kwargs) + if not isinstance(self, ConfigMixin): + raise RuntimeError( + f"`@register_for_config` was applied to {self.__class__.__name__} init method, but this class does " + "not inherit from `ConfigMixin`." + ) + + ignore = getattr(self, "ignore_for_config", []) + # Get positional arguments aligned with kwargs + new_kwargs = {} + signature = inspect.signature(init) + parameters = { + name: p.default for i, (name, p) in enumerate(signature.parameters.items()) if i > 0 and name not in ignore + } + for arg, name in zip(args, parameters.keys()): + new_kwargs[name] = arg + + # Then add all kwargs + new_kwargs.update( + { + k: init_kwargs.get(k, default) + for k, default in parameters.items() + if k not in ignore and k not in new_kwargs + } + ) + getattr(self, "register_to_config")(**new_kwargs) + + return inner_init diff --git a/diffusers/dependency_versions_check.py b/diffusers/dependency_versions_check.py new file mode 100644 index 000000000..bbf863222 --- /dev/null +++ b/diffusers/dependency_versions_check.py @@ -0,0 +1,47 @@ +# Copyright 2020 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. +import sys + +from .dependency_versions_table import deps +from .utils.versions import require_version, require_version_core + + +# define which module versions we always want to check at run time +# (usually the ones defined in `install_requires` in setup.py) +# +# order specific notes: +# - tqdm must be checked before tokenizers + +pkgs_to_check_at_runtime = "python tqdm regex requests packaging filelock numpy tokenizers".split() +if sys.version_info < (3, 7): + pkgs_to_check_at_runtime.append("dataclasses") +if sys.version_info < (3, 8): + pkgs_to_check_at_runtime.append("importlib_metadata") + +for pkg in pkgs_to_check_at_runtime: + if pkg in deps: + if pkg == "tokenizers": + # must be loaded here, or else tqdm check may fail + from .utils import is_tokenizers_available + + if not is_tokenizers_available(): + continue # not required, check version only if installed + + require_version_core(deps[pkg]) + else: + raise ValueError(f"can't find {pkg} in {deps.keys()}, check dependency_versions_table.py") + + +def dep_version_check(pkg, hint=None): + require_version(deps[pkg], hint) diff --git a/diffusers/dependency_versions_table.py b/diffusers/dependency_versions_table.py new file mode 100644 index 000000000..74c5331e5 --- /dev/null +++ b/diffusers/dependency_versions_table.py @@ -0,0 +1,26 @@ +# THIS FILE HAS BEEN AUTOGENERATED. To update: +# 1. modify the `_deps` dict in setup.py +# 2. run `make deps_table_update`` +deps = { + "Pillow": "Pillow", + "accelerate": "accelerate>=0.11.0", + "black": "black==22.3", + "datasets": "datasets", + "filelock": "filelock", + "flake8": "flake8>=3.8.3", + "hf-doc-builder": "hf-doc-builder>=0.3.0", + "huggingface-hub": "huggingface-hub>=0.8.1", + "importlib_metadata": "importlib_metadata", + "isort": "isort>=5.5.4", + "modelcards": "modelcards==0.1.4", + "numpy": "numpy", + "pytest": "pytest", + "pytest-timeout": "pytest-timeout", + "pytest-xdist": "pytest-xdist", + "scipy": "scipy", + "regex": "regex!=2019.12.17", + "requests": "requests", + "tensorboard": "tensorboard", + "torch": "torch>=1.4", + "transformers": "transformers>=4.21.0", +} diff --git a/diffusers/dynamic_modules_utils.py b/diffusers/dynamic_modules_utils.py new file mode 100644 index 000000000..0ebf916e7 --- /dev/null +++ b/diffusers/dynamic_modules_utils.py @@ -0,0 +1,335 @@ +# coding=utf-8 +# Copyright 2021 The HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
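
Before the dynamic-module utilities, a rough sketch of the kind of runtime check `dep_version_check` above performs. This is not part of the patch; it is written directly against `packaging` and `importlib.metadata` because `diffusers/utils/versions.py`, which provides `require_version`, is not included in this vendored copy.

```
# Python >= 3.8; older interpreters would use the importlib_metadata backport,
# as in import_utils.py above.
import importlib.metadata as importlib_metadata

from packaging import version

from diffusers.dependency_versions_table import deps

requirement = deps["torch"]                 # "torch>=1.4"
pkg, wanted = requirement.split(">=")
installed = importlib_metadata.version(pkg)
if version.parse(installed) < version.parse(wanted):
    raise ImportError(f"{pkg}>={wanted} is required by diffusers, but {installed} is installed.")
print(f"{pkg} {installed} satisfies {requirement}")
```
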
+"""Utilities to dynamically load objects from the Hub.""" + +import importlib +import os +import re +import shutil +import sys +from pathlib import Path +from typing import Dict, Optional, Union + +from huggingface_hub import cached_download + +from .utils import DIFFUSERS_DYNAMIC_MODULE_NAME, HF_MODULES_CACHE, logging + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + +def init_hf_modules(): + """ + Creates the cache directory for modules with an init, and adds it to the Python path. + """ + # This function has already been executed if HF_MODULES_CACHE already is in the Python path. + if HF_MODULES_CACHE in sys.path: + return + + sys.path.append(HF_MODULES_CACHE) + os.makedirs(HF_MODULES_CACHE, exist_ok=True) + init_path = Path(HF_MODULES_CACHE) / "__init__.py" + if not init_path.exists(): + init_path.touch() + + +def create_dynamic_module(name: Union[str, os.PathLike]): + """ + Creates a dynamic module in the cache directory for modules. + """ + init_hf_modules() + dynamic_module_path = Path(HF_MODULES_CACHE) / name + # If the parent module does not exist yet, recursively create it. + if not dynamic_module_path.parent.exists(): + create_dynamic_module(dynamic_module_path.parent) + os.makedirs(dynamic_module_path, exist_ok=True) + init_path = dynamic_module_path / "__init__.py" + if not init_path.exists(): + init_path.touch() + + +def get_relative_imports(module_file): + """ + Get the list of modules that are relatively imported in a module file. + + Args: + module_file (`str` or `os.PathLike`): The module file to inspect. + """ + with open(module_file, "r", encoding="utf-8") as f: + content = f.read() + + # Imports of the form `import .xxx` + relative_imports = re.findall("^\s*import\s+\.(\S+)\s*$", content, flags=re.MULTILINE) + # Imports of the form `from .xxx import yyy` + relative_imports += re.findall("^\s*from\s+\.(\S+)\s+import", content, flags=re.MULTILINE) + # Unique-ify + return list(set(relative_imports)) + + +def get_relative_import_files(module_file): + """ + Get the list of all files that are needed for a given module. Note that this function recurses through the relative + imports (if a imports b and b imports c, it will return module files for b and c). + + Args: + module_file (`str` or `os.PathLike`): The module file to inspect. + """ + no_change = False + files_to_check = [module_file] + all_relative_imports = [] + + # Let's recurse through all relative imports + while not no_change: + new_imports = [] + for f in files_to_check: + new_imports.extend(get_relative_imports(f)) + + module_path = Path(module_file).parent + new_import_files = [str(module_path / m) for m in new_imports] + new_import_files = [f for f in new_import_files if f not in all_relative_imports] + files_to_check = [f"{f}.py" for f in new_import_files] + + no_change = len(new_import_files) == 0 + all_relative_imports.extend(files_to_check) + + return all_relative_imports + + +def check_imports(filename): + """ + Check if the current Python environment contains all the libraries that are imported in a file. 
+ """ + with open(filename, "r", encoding="utf-8") as f: + content = f.read() + + # Imports of the form `import xxx` + imports = re.findall("^\s*import\s+(\S+)\s*$", content, flags=re.MULTILINE) + # Imports of the form `from xxx import yyy` + imports += re.findall("^\s*from\s+(\S+)\s+import", content, flags=re.MULTILINE) + # Only keep the top-level module + imports = [imp.split(".")[0] for imp in imports if not imp.startswith(".")] + + # Unique-ify and test we got them all + imports = list(set(imports)) + missing_packages = [] + for imp in imports: + try: + importlib.import_module(imp) + except ImportError: + missing_packages.append(imp) + + if len(missing_packages) > 0: + raise ImportError( + "This modeling file requires the following packages that were not found in your environment: " + f"{', '.join(missing_packages)}. Run `pip install {' '.join(missing_packages)}`" + ) + + return get_relative_imports(filename) + + +def get_class_in_module(class_name, module_path): + """ + Import a module on the cache directory for modules and extract a class from it. + """ + module_path = module_path.replace(os.path.sep, ".") + module = importlib.import_module(module_path) + return getattr(module, class_name) + + +def get_cached_module_file( + pretrained_model_name_or_path: Union[str, os.PathLike], + module_file: str, + cache_dir: Optional[Union[str, os.PathLike]] = None, + force_download: bool = False, + resume_download: bool = False, + proxies: Optional[Dict[str, str]] = None, + use_auth_token: Optional[Union[bool, str]] = None, + revision: Optional[str] = None, + local_files_only: bool = False, +): + """ + Prepares Downloads a module from a local folder or a distant repo and returns its path inside the cached + Transformers module. + + Args: + pretrained_model_name_or_path (`str` or `os.PathLike`): + This can be either: + + - a string, the *model id* of a pretrained model configuration hosted inside a model repo on + huggingface.co. Valid model ids can be located at the root-level, like `bert-base-uncased`, or namespaced + under a user or organization name, like `dbmdz/bert-base-german-cased`. + - a path to a *directory* containing a configuration file saved using the + [`~PreTrainedTokenizer.save_pretrained`] method, e.g., `./my_model_directory/`. + + module_file (`str`): + The name of the module file containing the class to look for. + cache_dir (`str` or `os.PathLike`, *optional*): + Path to a directory in which a downloaded pretrained model configuration should be cached if the standard + cache should not be used. + force_download (`bool`, *optional*, defaults to `False`): + Whether or not to force to (re-)download the configuration files and override the cached versions if they + exist. + resume_download (`bool`, *optional*, defaults to `False`): + Whether or not to delete incompletely received file. Attempts to resume the download if such a file exists. + proxies (`Dict[str, str]`, *optional*): + A dictionary of proxy servers to use by protocol or endpoint, e.g., `{'http': 'foo.bar:3128', + 'http://hostname': 'foo.bar:4012'}.` The proxies are used on each request. + use_auth_token (`str` or *bool*, *optional*): + The token to use as HTTP bearer authorization for remote files. If `True`, will use the token generated + when running `transformers-cli login` (stored in `~/.huggingface`). + revision (`str`, *optional*, defaults to `"main"`): + The specific model version to use. 
It can be a branch name, a tag name, or a commit id, since we use a + git-based system for storing models and other artifacts on huggingface.co, so `revision` can be any + identifier allowed by git. + local_files_only (`bool`, *optional*, defaults to `False`): + If `True`, will only try to load the tokenizer configuration from local files. + + + + Passing `use_auth_token=True` is required when you want to use a private model. + + + + Returns: + `str`: The path to the module inside the cache. + """ + # Download and cache module_file from the repo `pretrained_model_name_or_path` of grab it if it's a local file. + pretrained_model_name_or_path = str(pretrained_model_name_or_path) + module_file_or_url = os.path.join(pretrained_model_name_or_path, module_file) + submodule = "local" + + if os.path.isfile(module_file_or_url): + resolved_module_file = module_file_or_url + else: + try: + # Load from URL or cache if already cached + resolved_module_file = cached_download( + module_file_or_url, + cache_dir=cache_dir, + force_download=force_download, + proxies=proxies, + resume_download=resume_download, + local_files_only=local_files_only, + use_auth_token=use_auth_token, + ) + + except EnvironmentError: + logger.error(f"Could not locate the {module_file} inside {pretrained_model_name_or_path}.") + raise + + # Check we have all the requirements in our environment + modules_needed = check_imports(resolved_module_file) + + # Now we move the module inside our cached dynamic modules. + full_submodule = DIFFUSERS_DYNAMIC_MODULE_NAME + os.path.sep + submodule + create_dynamic_module(full_submodule) + submodule_path = Path(HF_MODULES_CACHE) / full_submodule + # We always copy local files (we could hash the file to see if there was a change, and give them the name of + # that hash, to only copy when there is a modification but it seems overkill for now). + # The only reason we do the copy is to avoid putting too many folders in sys.path. + shutil.copy(resolved_module_file, submodule_path / module_file) + for module_needed in modules_needed: + module_needed = f"{module_needed}.py" + shutil.copy(os.path.join(pretrained_model_name_or_path, module_needed), submodule_path / module_needed) + return os.path.join(full_submodule, module_file) + + +def get_class_from_dynamic_module( + pretrained_model_name_or_path: Union[str, os.PathLike], + module_file: str, + class_name: str, + cache_dir: Optional[Union[str, os.PathLike]] = None, + force_download: bool = False, + resume_download: bool = False, + proxies: Optional[Dict[str, str]] = None, + use_auth_token: Optional[Union[bool, str]] = None, + revision: Optional[str] = None, + local_files_only: bool = False, + **kwargs, +): + """ + Extracts a class from a module file, present in the local folder or repository of a model. + + + + Calling this function will execute the code in the module file found locally or downloaded from the Hub. It should + therefore only be called on trusted repos. + + + + Args: + pretrained_model_name_or_path (`str` or `os.PathLike`): + This can be either: + + - a string, the *model id* of a pretrained model configuration hosted inside a model repo on + huggingface.co. Valid model ids can be located at the root-level, like `bert-base-uncased`, or namespaced + under a user or organization name, like `dbmdz/bert-base-german-cased`. + - a path to a *directory* containing a configuration file saved using the + [`~PreTrainedTokenizer.save_pretrained`] method, e.g., `./my_model_directory/`. 
+ + module_file (`str`): + The name of the module file containing the class to look for. + class_name (`str`): + The name of the class to import in the module. + cache_dir (`str` or `os.PathLike`, *optional*): + Path to a directory in which a downloaded pretrained model configuration should be cached if the standard + cache should not be used. + force_download (`bool`, *optional*, defaults to `False`): + Whether or not to force to (re-)download the configuration files and override the cached versions if they + exist. + resume_download (`bool`, *optional*, defaults to `False`): + Whether or not to delete incompletely received file. Attempts to resume the download if such a file exists. + proxies (`Dict[str, str]`, *optional*): + A dictionary of proxy servers to use by protocol or endpoint, e.g., `{'http': 'foo.bar:3128', + 'http://hostname': 'foo.bar:4012'}.` The proxies are used on each request. + use_auth_token (`str` or `bool`, *optional*): + The token to use as HTTP bearer authorization for remote files. If `True`, will use the token generated + when running `transformers-cli login` (stored in `~/.huggingface`). + revision (`str`, *optional*, defaults to `"main"`): + The specific model version to use. It can be a branch name, a tag name, or a commit id, since we use a + git-based system for storing models and other artifacts on huggingface.co, so `revision` can be any + identifier allowed by git. + local_files_only (`bool`, *optional*, defaults to `False`): + If `True`, will only try to load the tokenizer configuration from local files. + + + + Passing `use_auth_token=True` is required when you want to use a private model. + + + + Returns: + `type`: The class, dynamically imported from the module. + + Examples: + + ```python + # Download module `modeling.py` from huggingface.co and cache then extract the class `MyBertModel` from this + # module. + cls = get_class_from_dynamic_module("sgugger/my-bert-model", "modeling.py", "MyBertModel") + ```""" + # And lastly we get the class inside our newly created module + final_module = get_cached_module_file( + pretrained_model_name_or_path, + module_file, + cache_dir=cache_dir, + force_download=force_download, + resume_download=resume_download, + proxies=proxies, + use_auth_token=use_auth_token, + revision=revision, + local_files_only=local_files_only, + ) + return get_class_in_module(class_name, final_module.replace(".py", "")) diff --git a/diffusers/hub_utils.py b/diffusers/hub_utils.py new file mode 100644 index 000000000..c07329e36 --- /dev/null +++ b/diffusers/hub_utils.py @@ -0,0 +1,197 @@ +# coding=utf-8 +# Copyright 2022 The HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
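+
+# Utilities for pushing trained pipelines to the Hugging Face Hub: `init_git_repo` clones or creates the
+# target repository in `args.output_dir`, `push_to_hub` saves a `DiffusionPipeline` there and pushes it, and
+# `create_model_card` renders a README.md from the bundled model card template.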
+ + +import os +import shutil +from pathlib import Path +from typing import Optional + +from huggingface_hub import HfFolder, Repository, whoami + +from .pipeline_utils import DiffusionPipeline +from .utils import is_modelcards_available, logging + + +if is_modelcards_available(): + from modelcards import CardData, ModelCard + + +logger = logging.get_logger(__name__) + + +MODEL_CARD_TEMPLATE_PATH = Path(__file__).parent / "utils" / "model_card_template.md" + + +def get_full_repo_name(model_id: str, organization: Optional[str] = None, token: Optional[str] = None): + if token is None: + token = HfFolder.get_token() + if organization is None: + username = whoami(token)["name"] + return f"{username}/{model_id}" + else: + return f"{organization}/{model_id}" + + +def init_git_repo(args, at_init: bool = False): + """ + Args: + Initializes a git repo in `args.hub_model_id`. + at_init (`bool`, *optional*, defaults to `False`): + Whether this function is called before any training or not. If `self.args.overwrite_output_dir` is `True` + and `at_init` is `True`, the path to the repo (which is `self.args.output_dir`) might be wiped out. + """ + if hasattr(args, "local_rank") and args.local_rank not in [-1, 0]: + return + hub_token = args.hub_token if hasattr(args, "hub_token") else None + use_auth_token = True if hub_token is None else hub_token + if not hasattr(args, "hub_model_id") or args.hub_model_id is None: + repo_name = Path(args.output_dir).absolute().name + else: + repo_name = args.hub_model_id + if "/" not in repo_name: + repo_name = get_full_repo_name(repo_name, token=hub_token) + + try: + repo = Repository( + args.output_dir, + clone_from=repo_name, + use_auth_token=use_auth_token, + private=args.hub_private_repo, + ) + except EnvironmentError: + if args.overwrite_output_dir and at_init: + # Try again after wiping output_dir + shutil.rmtree(args.output_dir) + repo = Repository( + args.output_dir, + clone_from=repo_name, + use_auth_token=use_auth_token, + ) + else: + raise + + repo.git_pull() + + # By default, ignore the checkpoint folders + if not os.path.exists(os.path.join(args.output_dir, ".gitignore")): + with open(os.path.join(args.output_dir, ".gitignore"), "w", encoding="utf-8") as writer: + writer.writelines(["checkpoint-*/"]) + + return repo + + +def push_to_hub( + args, + pipeline: DiffusionPipeline, + repo: Repository, + commit_message: Optional[str] = "End of training", + blocking: bool = True, + **kwargs, +) -> str: + """ + Parameters: + Upload *self.model* and *self.tokenizer* to the 🤗 model hub on the repo *self.args.hub_model_id*. + commit_message (`str`, *optional*, defaults to `"End of training"`): + Message to commit while pushing. + blocking (`bool`, *optional*, defaults to `True`): + Whether the function should return only when the `git push` has finished. + kwargs: + Additional keyword arguments passed along to [`create_model_card`]. + Returns: + The url of the commit of your model in the given repository if `blocking=False`, a tuple with the url of the + commit and an object to track the progress of the commit if `blocking=True` + """ + + if not hasattr(args, "hub_model_id") or args.hub_model_id is None: + model_name = Path(args.output_dir).name + else: + model_name = args.hub_model_id.split("/")[-1] + + output_dir = args.output_dir + os.makedirs(output_dir, exist_ok=True) + logger.info(f"Saving pipeline checkpoint to {output_dir}") + pipeline.save_pretrained(output_dir) + + # Only push from one node. 
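+    # `local_rank` is -1 in single-process runs and 0 on the main process of a distributed run.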
+    if hasattr(args, "local_rank") and args.local_rank not in [-1, 0]:
+        return
+
+    # Cancel any async push in progress if blocking=True. The commits will all be pushed together.
+    if (
+        blocking
+        and len(repo.command_queue) > 0
+        and repo.command_queue[-1] is not None
+        and not repo.command_queue[-1].is_done
+    ):
+        repo.command_queue[-1]._process.kill()
+
+    git_head_commit_url = repo.push_to_hub(commit_message=commit_message, blocking=blocking, auto_lfs_prune=True)
+    # push separately the model card to be independent from the rest of the model
+    create_model_card(args, model_name=model_name)
+    try:
+        repo.push_to_hub(commit_message="update model card README.md", blocking=blocking, auto_lfs_prune=True)
+    except EnvironmentError as exc:
+        logger.error(f"Error pushing update to the model card. Please read logs and retry.\n{exc}")
+
+    return git_head_commit_url
+
+
+def create_model_card(args, model_name):
+    if not is_modelcards_available():
+        raise ValueError(
+            "Please make sure to have `modelcards` installed when using the `create_model_card` function. You can"
+            " install the package with `pip install modelcards`."
+        )
+
+    if hasattr(args, "local_rank") and args.local_rank not in [-1, 0]:
+        return
+
+    hub_token = args.hub_token if hasattr(args, "hub_token") else None
+    repo_name = get_full_repo_name(model_name, token=hub_token)
+
+    model_card = ModelCard.from_template(
+        card_data=CardData(  # Card metadata object that will be converted to YAML block
+            language="en",
+            license="apache-2.0",
+            library_name="diffusers",
+            tags=[],
+            datasets=args.dataset_name,
+            metrics=[],
+        ),
+        template_path=MODEL_CARD_TEMPLATE_PATH,
+        model_name=model_name,
+        repo_name=repo_name,
+        dataset_name=args.dataset_name if hasattr(args, "dataset_name") else None,
+        learning_rate=args.learning_rate,
+        train_batch_size=args.train_batch_size,
+        eval_batch_size=args.eval_batch_size,
+        gradient_accumulation_steps=args.gradient_accumulation_steps
+        if hasattr(args, "gradient_accumulation_steps")
+        else None,
+        adam_beta1=args.adam_beta1 if hasattr(args, "adam_beta1") else None,
+        adam_beta2=args.adam_beta2 if hasattr(args, "adam_beta2") else None,
+        adam_weight_decay=args.adam_weight_decay if hasattr(args, "adam_weight_decay") else None,
+        adam_epsilon=args.adam_epsilon if hasattr(args, "adam_epsilon") else None,
+        lr_scheduler=args.lr_scheduler if hasattr(args, "lr_scheduler") else None,
+        lr_warmup_steps=args.lr_warmup_steps if hasattr(args, "lr_warmup_steps") else None,
+        ema_inv_gamma=args.ema_inv_gamma if hasattr(args, "ema_inv_gamma") else None,
+        ema_power=args.ema_power if hasattr(args, "ema_power") else None,
+        ema_max_decay=args.ema_max_decay if hasattr(args, "ema_max_decay") else None,
+        mixed_precision=args.mixed_precision,
+    )
+
+    card_path = os.path.join(args.output_dir, "README.md")
+    model_card.save(card_path)
diff --git a/diffusers/modeling_utils.py b/diffusers/modeling_utils.py
new file mode 100644
index 000000000..fb613614a
--- /dev/null
+++ b/diffusers/modeling_utils.py
@@ -0,0 +1,542 @@
+# coding=utf-8
+# Copyright 2022 The HuggingFace Inc. team.
+# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os +from typing import Callable, List, Optional, Tuple, Union + +import torch +from torch import Tensor, device + +from huggingface_hub import hf_hub_download +from huggingface_hub.utils import EntryNotFoundError, RepositoryNotFoundError, RevisionNotFoundError +from requests import HTTPError + +from .utils import CONFIG_NAME, DIFFUSERS_CACHE, HUGGINGFACE_CO_RESOLVE_ENDPOINT, logging + + +WEIGHTS_NAME = "diffusion_pytorch_model.bin" + + +logger = logging.get_logger(__name__) + + +def get_parameter_device(parameter: torch.nn.Module): + try: + return next(parameter.parameters()).device + except StopIteration: + # For torch.nn.DataParallel compatibility in PyTorch 1.5 + + def find_tensor_attributes(module: torch.nn.Module) -> List[Tuple[str, Tensor]]: + tuples = [(k, v) for k, v in module.__dict__.items() if torch.is_tensor(v)] + return tuples + + gen = parameter._named_members(get_members_fn=find_tensor_attributes) + first_tuple = next(gen) + return first_tuple[1].device + + +def get_parameter_dtype(parameter: torch.nn.Module): + try: + return next(parameter.parameters()).dtype + except StopIteration: + # For torch.nn.DataParallel compatibility in PyTorch 1.5 + + def find_tensor_attributes(module: torch.nn.Module) -> List[Tuple[str, Tensor]]: + tuples = [(k, v) for k, v in module.__dict__.items() if torch.is_tensor(v)] + return tuples + + gen = parameter._named_members(get_members_fn=find_tensor_attributes) + first_tuple = next(gen) + return first_tuple[1].dtype + + +def load_state_dict(checkpoint_file: Union[str, os.PathLike]): + """ + Reads a PyTorch checkpoint file, returning properly formatted errors if they arise. + """ + try: + return torch.load(checkpoint_file, map_location="cpu") + except Exception as e: + try: + with open(checkpoint_file) as f: + if f.read().startswith("version"): + raise OSError( + "You seem to have cloned a repository without having git-lfs installed. Please install " + "git-lfs and run `git lfs install` followed by `git lfs pull` in the folder " + "you cloned." + ) + else: + raise ValueError( + f"Unable to locate the file {checkpoint_file} which is necessary to load this pretrained " + "model. Make sure you have saved the model properly." + ) from e + except (UnicodeDecodeError, ValueError): + raise OSError( + f"Unable to load weights from pytorch checkpoint file for '{checkpoint_file}' " + f"at '{checkpoint_file}'. " + "If you tried to load a PyTorch model from a TF 2.0 checkpoint, please set from_tf=True." + ) + + +def _load_state_dict_into_model(model_to_load, state_dict): + # Convert old format to new format if needed from a PyTorch state_dict + # copy state_dict so _load_from_state_dict can modify it + state_dict = state_dict.copy() + error_msgs = [] + + # PyTorch's `_load_from_state_dict` does not copy parameters in a module's descendants + # so we need to apply the function recursively. 
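+    # The `prefix` argument grows as we descend (e.g. "" -> "encoder." -> "encoder.conv1.") so that each call to
+    # `_load_from_state_dict` only consumes the keys that belong to the current submodule.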
+ def load(module: torch.nn.Module, prefix=""): + args = (state_dict, prefix, {}, True, [], [], error_msgs) + module._load_from_state_dict(*args) + + for name, child in module._modules.items(): + if child is not None: + load(child, prefix + name + ".") + + load(model_to_load) + + return error_msgs + + +class ModelMixin(torch.nn.Module): + r""" + Base class for all models. + + [`ModelMixin`] takes care of storing the configuration of the models and handles methods for loading, downloading + and saving models. + + - **config_name** ([`str`]) -- A filename under which the model should be stored when calling + [`~modeling_utils.ModelMixin.save_pretrained`]. + """ + config_name = CONFIG_NAME + _automatically_saved_args = ["_diffusers_version", "_class_name", "_name_or_path"] + + def __init__(self): + super().__init__() + + def save_pretrained( + self, + save_directory: Union[str, os.PathLike], + is_main_process: bool = True, + save_function: Callable = torch.save, + ): + """ + Save a model and its configuration file to a directory, so that it can be re-loaded using the + `[`~modeling_utils.ModelMixin.from_pretrained`]` class method. + + Arguments: + save_directory (`str` or `os.PathLike`): + Directory to which to save. Will be created if it doesn't exist. + is_main_process (`bool`, *optional*, defaults to `True`): + Whether the process calling this is the main process or not. Useful when in distributed training like + TPUs and need to call this function on all processes. In this case, set `is_main_process=True` only on + the main process to avoid race conditions. + save_function (`Callable`): + The function to use to save the state dictionary. Useful on distributed training like TPUs when one + need to replace `torch.save` by another method. + """ + if os.path.isfile(save_directory): + logger.error(f"Provided path ({save_directory}) should be a directory, not a file") + return + + os.makedirs(save_directory, exist_ok=True) + + model_to_save = self + + # Attach architecture to the config + # Save the config + if is_main_process: + model_to_save.save_config(save_directory) + + # Save the model + state_dict = model_to_save.state_dict() + + # Clean the folder from a previous save + for filename in os.listdir(save_directory): + full_filename = os.path.join(save_directory, filename) + # If we have a shard file that is not going to be replaced, we delete it, but only from the main process + # in distributed settings to avoid race conditions. + if filename.startswith(WEIGHTS_NAME[:-4]) and os.path.isfile(full_filename) and is_main_process: + os.remove(full_filename) + + # Save the model + save_function(state_dict, os.path.join(save_directory, WEIGHTS_NAME)) + + logger.info(f"Model weights saved in {os.path.join(save_directory, WEIGHTS_NAME)}") + + @classmethod + def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.PathLike]], **kwargs): + r""" + Instantiate a pretrained pytorch model from a pre-trained model configuration. + + The model is set in evaluation mode by default using `model.eval()` (Dropout modules are deactivated). To train + the model, you should first set it back in training mode with `model.train()`. + + The warning *Weights from XXX not initialized from pretrained model* means that the weights of XXX do not come + pretrained with the rest of the model. It is up to you to train those weights with a downstream fine-tuning + task. 
+
+        The warning *Weights from XXX not used in YYY* means that the layer XXX is not used by YYY, therefore those
+        weights are discarded.
+
+        Parameters:
+            pretrained_model_name_or_path (`str` or `os.PathLike`, *optional*):
+                Can be either:
+
+                    - A string, the *model id* of a pretrained model hosted inside a model repo on huggingface.co.
+                      Valid model ids should have an organization name, like `google/ddpm-celebahq-256`.
+                    - A path to a *directory* containing model weights saved using [`~ModelMixin.save_config`], e.g.,
+                      `./my_model_directory/`.
+
+            cache_dir (`Union[str, os.PathLike]`, *optional*):
+                Path to a directory in which a downloaded pretrained model configuration should be cached if the
+                standard cache should not be used.
+            torch_dtype (`str` or `torch.dtype`, *optional*):
+                Override the default `torch.dtype` and load the model under this dtype. If `"auto"` is passed the dtype
+                will be automatically derived from the model's weights.
+            force_download (`bool`, *optional*, defaults to `False`):
+                Whether or not to force the (re-)download of the model weights and configuration files, overriding the
+                cached versions if they exist.
+            resume_download (`bool`, *optional*, defaults to `False`):
+                Whether or not to delete incompletely received files. Will attempt to resume the download if such a
+                file exists.
+            proxies (`Dict[str, str]`, *optional*):
+                A dictionary of proxy servers to use by protocol or endpoint, e.g., `{'http': 'foo.bar:3128',
+                'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.
+            output_loading_info(`bool`, *optional*, defaults to `False`):
+                Whether or not to also return a dictionary containing missing keys, unexpected keys and error messages.
+            local_files_only(`bool`, *optional*, defaults to `False`):
+                Whether or not to only look at local files (i.e., do not try to download the model).
+            use_auth_token (`str` or *bool*, *optional*):
+                The token to use as HTTP bearer authorization for remote files. If `True`, will use the token generated
+                when running `diffusers-cli login` (stored in `~/.huggingface`).
+            revision (`str`, *optional*, defaults to `"main"`):
+                The specific model version to use. It can be a branch name, a tag name, or a commit id, since we use a
+                git-based system for storing models and other artifacts on huggingface.co, so `revision` can be any
+                identifier allowed by git.
+            mirror (`str`, *optional*):
+                Mirror source to accelerate downloads in China. If you are from China and have an accessibility
+                problem, you can set this option to resolve it. Note that we do not guarantee the timeliness or safety.
+                Please refer to the mirror site for more information.
+
+
+
+        Passing `use_auth_token=True` is required when you want to use a private model.
+
+
+
+
+
+        Activate the special ["offline-mode"](https://huggingface.co/diffusers/installation.html#offline-mode) to use
+        this method in a firewalled environment.
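+
+        Example (an illustrative sketch; the model id is the one mentioned above, and `subfolder` may be needed
+        depending on how the repository is laid out):
+
+        ```python
+        from diffusers.models import UNet2DModel
+
+        # Hypothetical usage: load pretrained weights from the Hub (or pass a local directory instead).
+        model = UNet2DModel.from_pretrained("google/ddpm-celebahq-256", subfolder="unet")
+        ```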
+ + + + """ + cache_dir = kwargs.pop("cache_dir", DIFFUSERS_CACHE) + ignore_mismatched_sizes = kwargs.pop("ignore_mismatched_sizes", False) + force_download = kwargs.pop("force_download", False) + resume_download = kwargs.pop("resume_download", False) + proxies = kwargs.pop("proxies", None) + output_loading_info = kwargs.pop("output_loading_info", False) + local_files_only = kwargs.pop("local_files_only", False) + use_auth_token = kwargs.pop("use_auth_token", None) + revision = kwargs.pop("revision", None) + from_auto_class = kwargs.pop("_from_auto", False) + torch_dtype = kwargs.pop("torch_dtype", None) + subfolder = kwargs.pop("subfolder", None) + + user_agent = {"file_type": "model", "framework": "pytorch", "from_auto_class": from_auto_class} + + # Load config if we don't provide a configuration + config_path = pretrained_model_name_or_path + model, unused_kwargs = cls.from_config( + config_path, + cache_dir=cache_dir, + return_unused_kwargs=True, + force_download=force_download, + resume_download=resume_download, + proxies=proxies, + local_files_only=local_files_only, + use_auth_token=use_auth_token, + revision=revision, + subfolder=subfolder, + **kwargs, + ) + + if torch_dtype is not None and not isinstance(torch_dtype, torch.dtype): + raise ValueError( + f"{torch_dtype} needs to be of type `torch.dtype`, e.g. `torch.float16`, but is {type(torch_dtype)}." + ) + elif torch_dtype is not None: + model = model.to(torch_dtype) + + model.register_to_config(_name_or_path=pretrained_model_name_or_path) + # This variable will flag if we're loading a sharded checkpoint. In this case the archive file is just the + # Load model + pretrained_model_name_or_path = str(pretrained_model_name_or_path) + if os.path.isdir(pretrained_model_name_or_path): + if os.path.isfile(os.path.join(pretrained_model_name_or_path, WEIGHTS_NAME)): + # Load from a PyTorch checkpoint + model_file = os.path.join(pretrained_model_name_or_path, WEIGHTS_NAME) + elif subfolder is not None and os.path.isfile( + os.path.join(pretrained_model_name_or_path, subfolder, WEIGHTS_NAME) + ): + model_file = os.path.join(pretrained_model_name_or_path, subfolder, WEIGHTS_NAME) + else: + raise EnvironmentError( + f"Error no file named {WEIGHTS_NAME} found in directory {pretrained_model_name_or_path}." + ) + else: + try: + # Load from URL or cache if already cached + model_file = hf_hub_download( + pretrained_model_name_or_path, + filename=WEIGHTS_NAME, + cache_dir=cache_dir, + force_download=force_download, + proxies=proxies, + resume_download=resume_download, + local_files_only=local_files_only, + use_auth_token=use_auth_token, + user_agent=user_agent, + subfolder=subfolder, + revision=revision, + ) + + except RepositoryNotFoundError: + raise EnvironmentError( + f"{pretrained_model_name_or_path} is not a local folder and is not a valid model identifier " + "listed on 'https://huggingface.co/models'\nIf this is a private repository, make sure to pass a " + "token having permission to this repo with `use_auth_token` or log in with `huggingface-cli " + "login` and pass `use_auth_token=True`." + ) + except RevisionNotFoundError: + raise EnvironmentError( + f"{revision} is not a valid git identifier (branch name, tag name or commit id) that exists for " + "this model name. Check the model page at " + f"'https://huggingface.co/{pretrained_model_name_or_path}' for available revisions." + ) + except EntryNotFoundError: + raise EnvironmentError( + f"{pretrained_model_name_or_path} does not appear to have a file named {WEIGHTS_NAME}." 
+ ) + except HTTPError as err: + raise EnvironmentError( + "There was a specific connection error when trying to load" + f" {pretrained_model_name_or_path}:\n{err}" + ) + except ValueError: + raise EnvironmentError( + f"We couldn't connect to '{HUGGINGFACE_CO_RESOLVE_ENDPOINT}' to load this model, couldn't find it" + f" in the cached files and it looks like {pretrained_model_name_or_path} is not the path to a" + f" directory containing a file named {WEIGHTS_NAME} or" + " \nCheckout your internet connection or see how to run the library in" + " offline mode at 'https://huggingface.co/docs/diffusers/installation#offline-mode'." + ) + except EnvironmentError: + raise EnvironmentError( + f"Can't load the model for '{pretrained_model_name_or_path}'. If you were trying to load it from " + "'https://huggingface.co/models', make sure you don't have a local directory with the same name. " + f"Otherwise, make sure '{pretrained_model_name_or_path}' is the correct path to a directory " + f"containing a file named {WEIGHTS_NAME}" + ) + + # restore default dtype + state_dict = load_state_dict(model_file) + model, missing_keys, unexpected_keys, mismatched_keys, error_msgs = cls._load_pretrained_model( + model, + state_dict, + model_file, + pretrained_model_name_or_path, + ignore_mismatched_sizes=ignore_mismatched_sizes, + ) + + # Set model in evaluation mode to deactivate DropOut modules by default + model.eval() + + if output_loading_info: + loading_info = { + "missing_keys": missing_keys, + "unexpected_keys": unexpected_keys, + "mismatched_keys": mismatched_keys, + "error_msgs": error_msgs, + } + return model, loading_info + + return model + + @classmethod + def _load_pretrained_model( + cls, + model, + state_dict, + resolved_archive_file, + pretrained_model_name_or_path, + ignore_mismatched_sizes=False, + ): + # Retrieve missing & unexpected_keys + model_state_dict = model.state_dict() + loaded_keys = [k for k in state_dict.keys()] + + expected_keys = list(model_state_dict.keys()) + + original_loaded_keys = loaded_keys + + missing_keys = list(set(expected_keys) - set(loaded_keys)) + unexpected_keys = list(set(loaded_keys) - set(expected_keys)) + + # Make sure we are able to load base models as well as derived models (with heads) + model_to_load = model + + def _find_mismatched_keys( + state_dict, + model_state_dict, + loaded_keys, + ignore_mismatched_sizes, + ): + mismatched_keys = [] + if ignore_mismatched_sizes: + for checkpoint_key in loaded_keys: + model_key = checkpoint_key + + if ( + model_key in model_state_dict + and state_dict[checkpoint_key].shape != model_state_dict[model_key].shape + ): + mismatched_keys.append( + (checkpoint_key, state_dict[checkpoint_key].shape, model_state_dict[model_key].shape) + ) + del state_dict[checkpoint_key] + return mismatched_keys + + if state_dict is not None: + # Whole checkpoint + mismatched_keys = _find_mismatched_keys( + state_dict, + model_state_dict, + original_loaded_keys, + ignore_mismatched_sizes, + ) + error_msgs = _load_state_dict_into_model(model_to_load, state_dict) + + if len(error_msgs) > 0: + error_msg = "\n\t".join(error_msgs) + if "size mismatch" in error_msg: + error_msg += ( + "\n\tYou may consider adding `ignore_mismatched_sizes=True` in the model `from_pretrained` method." 
+ ) + raise RuntimeError(f"Error(s) in loading state_dict for {model.__class__.__name__}:\n\t{error_msg}") + + if len(unexpected_keys) > 0: + logger.warning( + f"Some weights of the model checkpoint at {pretrained_model_name_or_path} were not used when" + f" initializing {model.__class__.__name__}: {unexpected_keys}\n- This IS expected if you are" + f" initializing {model.__class__.__name__} from the checkpoint of a model trained on another task" + " or with another architecture (e.g. initializing a BertForSequenceClassification model from a" + " BertForPreTraining model).\n- This IS NOT expected if you are initializing" + f" {model.__class__.__name__} from the checkpoint of a model that you expect to be exactly" + " identical (initializing a BertForSequenceClassification model from a" + " BertForSequenceClassification model)." + ) + else: + logger.info(f"All model checkpoint weights were used when initializing {model.__class__.__name__}.\n") + if len(missing_keys) > 0: + logger.warning( + f"Some weights of {model.__class__.__name__} were not initialized from the model checkpoint at" + f" {pretrained_model_name_or_path} and are newly initialized: {missing_keys}\nYou should probably" + " TRAIN this model on a down-stream task to be able to use it for predictions and inference." + ) + elif len(mismatched_keys) == 0: + logger.info( + f"All the weights of {model.__class__.__name__} were initialized from the model checkpoint at" + f" {pretrained_model_name_or_path}.\nIf your task is similar to the task the model of the" + f" checkpoint was trained on, you can already use {model.__class__.__name__} for predictions" + " without further training." + ) + if len(mismatched_keys) > 0: + mismatched_warning = "\n".join( + [ + f"- {key}: found shape {shape1} in the checkpoint and {shape2} in the model instantiated" + for key, shape1, shape2 in mismatched_keys + ] + ) + logger.warning( + f"Some weights of {model.__class__.__name__} were not initialized from the model checkpoint at" + f" {pretrained_model_name_or_path} and are newly initialized because the shapes did not" + f" match:\n{mismatched_warning}\nYou should probably TRAIN this model on a down-stream task to be" + " able to use it for predictions and inference." + ) + + return model, missing_keys, unexpected_keys, mismatched_keys, error_msgs + + @property + def device(self) -> device: + """ + `torch.device`: The device on which the module is (assuming that all the module parameters are on the same + device). + """ + return get_parameter_device(self) + + @property + def dtype(self) -> torch.dtype: + """ + `torch.dtype`: The dtype of the module (assuming that all the module parameters have the same dtype). + """ + return get_parameter_dtype(self) + + def num_parameters(self, only_trainable: bool = False, exclude_embeddings: bool = False) -> int: + """ + Get number of (optionally, trainable or non-embeddings) parameters in the module. + + Args: + only_trainable (`bool`, *optional*, defaults to `False`): + Whether or not to return only the number of trainable parameters + + exclude_embeddings (`bool`, *optional*, defaults to `False`): + Whether or not to return only the number of non-embeddings parameters + + Returns: + `int`: The number of parameters. 
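+
+        Example (an illustrative sketch; assumes `UNet2DModel` can be instantiated with its default configuration):
+
+        ```python
+        from diffusers.models import UNet2DModel
+
+        model = UNet2DModel()  # default configuration, purely for illustration
+        print(model.num_parameters(only_trainable=True))
+        ```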
+ """ + + if exclude_embeddings: + embedding_param_names = [ + f"{name}.weight" + for name, module_type in self.named_modules() + if isinstance(module_type, torch.nn.Embedding) + ] + non_embedding_parameters = [ + parameter for name, parameter in self.named_parameters() if name not in embedding_param_names + ] + return sum(p.numel() for p in non_embedding_parameters if p.requires_grad or not only_trainable) + else: + return sum(p.numel() for p in self.parameters() if p.requires_grad or not only_trainable) + + +def unwrap_model(model: torch.nn.Module) -> torch.nn.Module: + """ + Recursively unwraps a model from potential containers (as used in distributed training). + + Args: + model (`torch.nn.Module`): The model to unwrap. + """ + # since there could be multiple levels of wrapping, unwrap recursively + if hasattr(model, "module"): + return unwrap_model(model.module) + else: + return model diff --git a/diffusers/models/__init__.py b/diffusers/models/__init__.py new file mode 100644 index 000000000..e0ac5c8d5 --- /dev/null +++ b/diffusers/models/__init__.py @@ -0,0 +1,17 @@ +# Copyright 2022 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from .unet_2d import UNet2DModel +from .unet_2d_condition import UNet2DConditionModel +from .vae import AutoencoderKL, VQModel diff --git a/diffusers/models/attention.py b/diffusers/models/attention.py new file mode 100644 index 000000000..de9c92691 --- /dev/null +++ b/diffusers/models/attention.py @@ -0,0 +1,333 @@ +import math +from typing import Optional + +import torch +import torch.nn.functional as F +from torch import nn + + +class AttentionBlock(nn.Module): + """ + An attention block that allows spatial positions to attend to each other. Originally ported from here, but adapted + to the N-d case. + https://github.com/hojonathanho/diffusion/blob/1e0dceb3b3495bbe19116a5e1b3596cd0706c543/diffusion_tf/models/unet.py#L66. + Uses three q, k, v linear layers to compute attention. + + Parameters: + channels (:obj:`int`): The number of channels in the input and output. + num_head_channels (:obj:`int`, *optional*): + The number of channels in each head. If None, then `num_heads` = 1. + num_groups (:obj:`int`, *optional*, defaults to 32): The number of groups to use for group norm. + rescale_output_factor (:obj:`float`, *optional*, defaults to 1.0): The factor to rescale the output by. + eps (:obj:`float`, *optional*, defaults to 1e-5): The epsilon value to use for group norm. 
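+
+    Example (an illustrative sketch; shapes and channel counts are arbitrary):
+
+    ```python
+    import torch
+
+    from diffusers.models.attention import AttentionBlock
+
+    attn = AttentionBlock(channels=64, num_head_channels=32)
+    out = attn(torch.randn(1, 64, 16, 16))  # output keeps the input shape (1, 64, 16, 16)
+    ```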
+ """ + + def __init__( + self, + channels: int, + num_head_channels: Optional[int] = None, + num_groups: int = 32, + rescale_output_factor: float = 1.0, + eps: float = 1e-5, + ): + super().__init__() + self.channels = channels + + self.num_heads = channels // num_head_channels if num_head_channels is not None else 1 + self.num_head_size = num_head_channels + self.group_norm = nn.GroupNorm(num_channels=channels, num_groups=num_groups, eps=eps, affine=True) + + # define q,k,v as linear layers + self.query = nn.Linear(channels, channels) + self.key = nn.Linear(channels, channels) + self.value = nn.Linear(channels, channels) + + self.rescale_output_factor = rescale_output_factor + self.proj_attn = nn.Linear(channels, channels, 1) + + def transpose_for_scores(self, projection: torch.Tensor) -> torch.Tensor: + new_projection_shape = projection.size()[:-1] + (self.num_heads, -1) + # move heads to 2nd position (B, T, H * D) -> (B, T, H, D) -> (B, H, T, D) + new_projection = projection.view(new_projection_shape).permute(0, 2, 1, 3) + return new_projection + + def forward(self, hidden_states): + residual = hidden_states + batch, channel, height, width = hidden_states.shape + + # norm + hidden_states = self.group_norm(hidden_states) + + hidden_states = hidden_states.view(batch, channel, height * width).transpose(1, 2) + + # proj to q, k, v + query_proj = self.query(hidden_states) + key_proj = self.key(hidden_states) + value_proj = self.value(hidden_states) + + # transpose + query_states = self.transpose_for_scores(query_proj) + key_states = self.transpose_for_scores(key_proj) + value_states = self.transpose_for_scores(value_proj) + + # get scores + scale = 1 / math.sqrt(math.sqrt(self.channels / self.num_heads)) + + attention_scores = torch.matmul(query_states * scale, key_states.transpose(-1, -2) * scale) + attention_probs = torch.softmax(attention_scores.float(), dim=-1).type(attention_scores.dtype) + + # compute attention output + hidden_states = torch.matmul(attention_probs, value_states) + + hidden_states = hidden_states.permute(0, 2, 1, 3).contiguous() + new_hidden_states_shape = hidden_states.size()[:-2] + (self.channels,) + hidden_states = hidden_states.view(new_hidden_states_shape) + + # compute next hidden_states + hidden_states = self.proj_attn(hidden_states) + hidden_states = hidden_states.transpose(-1, -2).reshape(batch, channel, height, width) + + # res connect and rescale + hidden_states = (hidden_states + residual) / self.rescale_output_factor + return hidden_states + + +class SpatialTransformer(nn.Module): + """ + Transformer block for image-like data. First, project the input (aka embedding) and reshape to b, t, d. Then apply + standard transformer action. Finally, reshape to image. + + Parameters: + in_channels (:obj:`int`): The number of channels in the input and output. + n_heads (:obj:`int`): The number of heads to use for multi-head attention. + d_head (:obj:`int`): The number of channels in each head. + depth (:obj:`int`, *optional*, defaults to 1): The number of layers of Transformer blocks to use. + dropout (:obj:`float`, *optional*, defaults to 0.1): The dropout probability to use. + context_dim (:obj:`int`, *optional*): The number of context dimensions to use. 
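+
+    Example (an illustrative sketch; the context length 77 mimics a text-encoder sequence, shapes are arbitrary):
+
+    ```python
+    import torch
+
+    from diffusers.models.attention import SpatialTransformer
+
+    block = SpatialTransformer(in_channels=32, n_heads=4, d_head=8, context_dim=64)
+    x = torch.randn(2, 32, 16, 16)
+    context = torch.randn(2, 77, 64)
+    out = block(x, context=context)  # (2, 32, 16, 16)
+    ```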
+ """ + + def __init__( + self, + in_channels: int, + n_heads: int, + d_head: int, + depth: int = 1, + dropout: float = 0.0, + context_dim: Optional[int] = None, + ): + super().__init__() + self.n_heads = n_heads + self.d_head = d_head + self.in_channels = in_channels + inner_dim = n_heads * d_head + self.norm = torch.nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True) + + self.proj_in = nn.Conv2d(in_channels, inner_dim, kernel_size=1, stride=1, padding=0) + + self.transformer_blocks = nn.ModuleList( + [ + BasicTransformerBlock(inner_dim, n_heads, d_head, dropout=dropout, context_dim=context_dim) + for d in range(depth) + ] + ) + + self.proj_out = nn.Conv2d(inner_dim, in_channels, kernel_size=1, stride=1, padding=0) + + def _set_attention_slice(self, slice_size): + for block in self.transformer_blocks: + block._set_attention_slice(slice_size) + + def forward(self, x, context=None): + # note: if no context is given, cross-attention defaults to self-attention + b, c, h, w = x.shape + x_in = x + x = self.norm(x) + x = self.proj_in(x) + x = x.permute(0, 2, 3, 1).reshape(b, h * w, c) + for block in self.transformer_blocks: + x = block(x, context=context) + x = x.reshape(b, h, w, c).permute(0, 3, 1, 2) + x = self.proj_out(x) + return x + x_in + + +class BasicTransformerBlock(nn.Module): + r""" + A basic Transformer block. + + Parameters: + dim (:obj:`int`): The number of channels in the input and output. + n_heads (:obj:`int`): The number of heads to use for multi-head attention. + d_head (:obj:`int`): The number of channels in each head. + dropout (:obj:`float`, *optional*, defaults to 0.0): The dropout probability to use. + context_dim (:obj:`int`, *optional*): The size of the context vector for cross attention. + gated_ff (:obj:`bool`, *optional*, defaults to :obj:`False`): Whether to use a gated feed-forward network. + checkpoint (:obj:`bool`, *optional*, defaults to :obj:`False`): Whether to use checkpointing. + """ + + def __init__( + self, + dim: int, + n_heads: int, + d_head: int, + dropout=0.0, + context_dim: Optional[int] = None, + gated_ff: bool = True, + checkpoint: bool = True, + ): + super().__init__() + self.attn1 = CrossAttention( + query_dim=dim, heads=n_heads, dim_head=d_head, dropout=dropout + ) # is a self-attention + self.ff = FeedForward(dim, dropout=dropout, glu=gated_ff) + self.attn2 = CrossAttention( + query_dim=dim, context_dim=context_dim, heads=n_heads, dim_head=d_head, dropout=dropout + ) # is self-attn if context is none + self.norm1 = nn.LayerNorm(dim) + self.norm2 = nn.LayerNorm(dim) + self.norm3 = nn.LayerNorm(dim) + self.checkpoint = checkpoint + + def _set_attention_slice(self, slice_size): + self.attn1._slice_size = slice_size + self.attn2._slice_size = slice_size + + def forward(self, x, context=None): + x = x.contiguous() if x.device.type == "mps" else x + x = self.attn1(self.norm1(x)) + x + x = self.attn2(self.norm2(x), context=context) + x + x = self.ff(self.norm3(x)) + x + return x + + +class CrossAttention(nn.Module): + r""" + A cross attention layer. + + Parameters: + query_dim (:obj:`int`): The number of channels in the query. + context_dim (:obj:`int`, *optional*): + The number of channels in the context. If not given, defaults to `query_dim`. + heads (:obj:`int`, *optional*, defaults to 8): The number of heads to use for multi-head attention. + dim_head (:obj:`int`, *optional*, defaults to 64): The number of channels in each head. + dropout (:obj:`float`, *optional*, defaults to 0.0): The dropout probability to use. 
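+
+    Example (an illustrative sketch; shapes are arbitrary):
+
+    ```python
+    import torch
+
+    from diffusers.models.attention import CrossAttention
+
+    attn = CrossAttention(query_dim=32, context_dim=64, heads=4, dim_head=8)
+    x = torch.randn(2, 256, 32)
+    context = torch.randn(2, 77, 64)
+    out = attn(x, context=context)  # (2, 256, 32)
+    ```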
+ """ + + def __init__( + self, query_dim: int, context_dim: Optional[int] = None, heads: int = 8, dim_head: int = 64, dropout: int = 0.0 + ): + super().__init__() + inner_dim = dim_head * heads + context_dim = context_dim if context_dim is not None else query_dim + + self.scale = dim_head**-0.5 + self.heads = heads + # for slice_size > 0 the attention score computation + # is split across the batch axis to save memory + # You can set slice_size with `set_attention_slice` + self._slice_size = None + + self.to_q = nn.Linear(query_dim, inner_dim, bias=False) + self.to_k = nn.Linear(context_dim, inner_dim, bias=False) + self.to_v = nn.Linear(context_dim, inner_dim, bias=False) + + self.to_out = nn.Sequential(nn.Linear(inner_dim, query_dim), nn.Dropout(dropout)) + + def reshape_heads_to_batch_dim(self, tensor): + batch_size, seq_len, dim = tensor.shape + head_size = self.heads + tensor = tensor.reshape(batch_size, seq_len, head_size, dim // head_size) + tensor = tensor.permute(0, 2, 1, 3).reshape(batch_size * head_size, seq_len, dim // head_size) + return tensor + + def reshape_batch_dim_to_heads(self, tensor): + batch_size, seq_len, dim = tensor.shape + head_size = self.heads + tensor = tensor.reshape(batch_size // head_size, head_size, seq_len, dim) + tensor = tensor.permute(0, 2, 1, 3).reshape(batch_size // head_size, seq_len, dim * head_size) + return tensor + + def forward(self, x, context=None, mask=None): + batch_size, sequence_length, dim = x.shape + + q = self.to_q(x) + context = context if context is not None else x + k = self.to_k(context) + v = self.to_v(context) + + q = self.reshape_heads_to_batch_dim(q) + k = self.reshape_heads_to_batch_dim(k) + v = self.reshape_heads_to_batch_dim(v) + + # TODO(PVP) - mask is currently never used. Remember to re-implement when used + + # attention, what we cannot get enough of + hidden_states = self._attention(q, k, v, sequence_length, dim) + + return self.to_out(hidden_states) + + def _attention(self, query, key, value, sequence_length, dim): + batch_size_attention = query.shape[0] + hidden_states = torch.zeros( + (batch_size_attention, sequence_length, dim // self.heads), device=query.device, dtype=query.dtype + ) + slice_size = self._slice_size if self._slice_size is not None else hidden_states.shape[0] + for i in range(hidden_states.shape[0] // slice_size): + start_idx = i * slice_size + end_idx = (i + 1) * slice_size + attn_slice = ( + torch.einsum("b i d, b j d -> b i j", query[start_idx:end_idx], key[start_idx:end_idx]) * self.scale + ) + attn_slice = attn_slice.softmax(dim=-1) + attn_slice = torch.einsum("b i j, b j d -> b i d", attn_slice, value[start_idx:end_idx]) + + hidden_states[start_idx:end_idx] = attn_slice + + # reshape hidden_states + hidden_states = self.reshape_batch_dim_to_heads(hidden_states) + return hidden_states + + +class FeedForward(nn.Module): + r""" + A feed-forward layer. + + Parameters: + dim (:obj:`int`): The number of channels in the input. + dim_out (:obj:`int`, *optional*): The number of channels in the output. If not given, defaults to `dim`. + mult (:obj:`int`, *optional*, defaults to 4): The multiplier to use for the hidden dimension. + glu (:obj:`bool`, *optional*, defaults to :obj:`False`): Whether to use GLU activation. + dropout (:obj:`float`, *optional*, defaults to 0.0): The dropout probability to use. 
+ """ + + def __init__( + self, dim: int, dim_out: Optional[int] = None, mult: int = 4, glu: bool = False, dropout: float = 0.0 + ): + super().__init__() + inner_dim = int(dim * mult) + dim_out = dim_out if dim_out is not None else dim + project_in = GEGLU(dim, inner_dim) + + self.net = nn.Sequential(project_in, nn.Dropout(dropout), nn.Linear(inner_dim, dim_out)) + + def forward(self, x): + return self.net(x) + + +# feedforward +class GEGLU(nn.Module): + r""" + A variant of the gated linear unit activation function from https://arxiv.org/abs/2002.05202. + + Parameters: + dim_in (:obj:`int`): The number of channels in the input. + dim_out (:obj:`int`): The number of channels in the output. + """ + + def __init__(self, dim_in: int, dim_out: int): + super().__init__() + self.proj = nn.Linear(dim_in, dim_out * 2) + + def forward(self, x): + x, gate = self.proj(x).chunk(2, dim=-1) + return x * F.gelu(gate) diff --git a/diffusers/models/embeddings.py b/diffusers/models/embeddings.py new file mode 100644 index 000000000..86ac074c1 --- /dev/null +++ b/diffusers/models/embeddings.py @@ -0,0 +1,115 @@ +# Copyright 2022 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import math + +import numpy as np +import torch +from torch import nn + + +def get_timestep_embedding( + timesteps: torch.Tensor, + embedding_dim: int, + flip_sin_to_cos: bool = False, + downscale_freq_shift: float = 1, + scale: float = 1, + max_period: int = 10000, +): + """ + This matches the implementation in Denoising Diffusion Probabilistic Models: Create sinusoidal timestep embeddings. + + :param timesteps: a 1-D Tensor of N indices, one per batch element. + These may be fractional. + :param embedding_dim: the dimension of the output. :param max_period: controls the minimum frequency of the + embeddings. :return: an [N x dim] Tensor of positional embeddings. 
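+
+    Example (an illustrative sketch):
+
+    ```python
+    import torch
+
+    from diffusers.models.embeddings import get_timestep_embedding
+
+    emb = get_timestep_embedding(torch.arange(4), embedding_dim=32)  # shape (4, 32)
+    ```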
+ """ + assert len(timesteps.shape) == 1, "Timesteps should be a 1d-array" + + half_dim = embedding_dim // 2 + exponent = -math.log(max_period) * torch.arange(start=0, end=half_dim, dtype=torch.float32) + exponent = exponent / (half_dim - downscale_freq_shift) + + emb = torch.exp(exponent).to(device=timesteps.device) + emb = timesteps[:, None].float() * emb[None, :] + + # scale embeddings + emb = scale * emb + + # concat sine and cosine embeddings + emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=-1) + + # flip sine and cosine embeddings + if flip_sin_to_cos: + emb = torch.cat([emb[:, half_dim:], emb[:, :half_dim]], dim=-1) + + # zero pad + if embedding_dim % 2 == 1: + emb = torch.nn.functional.pad(emb, (0, 1, 0, 0)) + return emb + + +class TimestepEmbedding(nn.Module): + def __init__(self, channel: int, time_embed_dim: int, act_fn: str = "silu"): + super().__init__() + + self.linear_1 = nn.Linear(channel, time_embed_dim) + self.act = None + if act_fn == "silu": + self.act = nn.SiLU() + self.linear_2 = nn.Linear(time_embed_dim, time_embed_dim) + + def forward(self, sample): + sample = self.linear_1(sample) + + if self.act is not None: + sample = self.act(sample) + + sample = self.linear_2(sample) + return sample + + +class Timesteps(nn.Module): + def __init__(self, num_channels: int, flip_sin_to_cos: bool, downscale_freq_shift: float): + super().__init__() + self.num_channels = num_channels + self.flip_sin_to_cos = flip_sin_to_cos + self.downscale_freq_shift = downscale_freq_shift + + def forward(self, timesteps): + t_emb = get_timestep_embedding( + timesteps, + self.num_channels, + flip_sin_to_cos=self.flip_sin_to_cos, + downscale_freq_shift=self.downscale_freq_shift, + ) + return t_emb + + +class GaussianFourierProjection(nn.Module): + """Gaussian Fourier embeddings for noise levels.""" + + def __init__(self, embedding_size: int = 256, scale: float = 1.0): + super().__init__() + self.weight = nn.Parameter(torch.randn(embedding_size) * scale, requires_grad=False) + + # to delete later + self.W = nn.Parameter(torch.randn(embedding_size) * scale, requires_grad=False) + + self.weight = self.W + + def forward(self, x): + x = torch.log(x) + x_proj = x[:, None] * self.weight[None, :] * 2 * np.pi + out = torch.cat([torch.sin(x_proj), torch.cos(x_proj)], dim=-1) + return out diff --git a/diffusers/models/resnet.py b/diffusers/models/resnet.py new file mode 100644 index 000000000..27fae24f7 --- /dev/null +++ b/diffusers/models/resnet.py @@ -0,0 +1,483 @@ +from functools import partial + +import numpy as np +import torch +import torch.nn as nn +import torch.nn.functional as F + + +class Upsample2D(nn.Module): + """ + An upsampling layer with an optional convolution. + + :param channels: channels in the inputs and outputs. :param use_conv: a bool determining if a convolution is + applied. :param dims: determines if the signal is 1D, 2D, or 3D. If 3D, then + upsampling occurs in the inner-two dimensions. 
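+
+    Example (an illustrative sketch; shapes are arbitrary):
+
+    ```python
+    import torch
+
+    from diffusers.models.resnet import Upsample2D
+
+    up = Upsample2D(channels=32, use_conv=True)
+    out = up(torch.randn(1, 32, 8, 8))  # spatial size doubles: (1, 32, 16, 16)
+    ```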
+ """ + + def __init__(self, channels, use_conv=False, use_conv_transpose=False, out_channels=None, name="conv"): + super().__init__() + self.channels = channels + self.out_channels = out_channels or channels + self.use_conv = use_conv + self.use_conv_transpose = use_conv_transpose + self.name = name + + conv = None + if use_conv_transpose: + conv = nn.ConvTranspose2d(channels, self.out_channels, 4, 2, 1) + elif use_conv: + conv = nn.Conv2d(self.channels, self.out_channels, 3, padding=1) + + # TODO(Suraj, Patrick) - clean up after weight dicts are correctly renamed + if name == "conv": + self.conv = conv + else: + self.Conv2d_0 = conv + + def forward(self, x): + assert x.shape[1] == self.channels + if self.use_conv_transpose: + return self.conv(x) + + x = F.interpolate(x, scale_factor=2.0, mode="nearest") + + # TODO(Suraj, Patrick) - clean up after weight dicts are correctly renamed + if self.use_conv: + if self.name == "conv": + x = self.conv(x) + else: + x = self.Conv2d_0(x) + + return x + + +class Downsample2D(nn.Module): + """ + A downsampling layer with an optional convolution. + + :param channels: channels in the inputs and outputs. :param use_conv: a bool determining if a convolution is + applied. :param dims: determines if the signal is 1D, 2D, or 3D. If 3D, then + downsampling occurs in the inner-two dimensions. + """ + + def __init__(self, channels, use_conv=False, out_channels=None, padding=1, name="conv"): + super().__init__() + self.channels = channels + self.out_channels = out_channels or channels + self.use_conv = use_conv + self.padding = padding + stride = 2 + self.name = name + + if use_conv: + conv = nn.Conv2d(self.channels, self.out_channels, 3, stride=stride, padding=padding) + else: + assert self.channels == self.out_channels + conv = nn.AvgPool2d(kernel_size=stride, stride=stride) + + # TODO(Suraj, Patrick) - clean up after weight dicts are correctly renamed + if name == "conv": + self.Conv2d_0 = conv + self.conv = conv + elif name == "Conv2d_0": + self.conv = conv + else: + self.conv = conv + + def forward(self, x): + assert x.shape[1] == self.channels + if self.use_conv and self.padding == 0: + pad = (0, 1, 0, 1) + x = F.pad(x, pad, mode="constant", value=0) + + assert x.shape[1] == self.channels + x = self.conv(x) + + return x + + +class FirUpsample2D(nn.Module): + def __init__(self, channels=None, out_channels=None, use_conv=False, fir_kernel=(1, 3, 3, 1)): + super().__init__() + out_channels = out_channels if out_channels else channels + if use_conv: + self.Conv2d_0 = nn.Conv2d(channels, out_channels, kernel_size=3, stride=1, padding=1) + self.use_conv = use_conv + self.fir_kernel = fir_kernel + self.out_channels = out_channels + + def _upsample_2d(self, x, weight=None, kernel=None, factor=2, gain=1): + """Fused `upsample_2d()` followed by `Conv2d()`. + + Args: + Padding is performed only once at the beginning, not between the operations. The fused op is considerably more + efficient than performing the same calculation using standard TensorFlow ops. It supports gradients of arbitrary: + order. + x: Input tensor of the shape `[N, C, H, W]` or `[N, H, W, + C]`. + weight: Weight tensor of the shape `[filterH, filterW, inChannels, + outChannels]`. Grouped convolution can be performed by `inChannels = x.shape[0] // numGroups`. + kernel: FIR filter of the shape `[firH, firW]` or `[firN]` + (separable). The default is `[1] * factor`, which corresponds to nearest-neighbor upsampling. + factor: Integer upsampling factor (default: 2). 
gain: Scaling factor for signal magnitude (default: 1.0). + + Returns: + Tensor of the shape `[N, C, H * factor, W * factor]` or `[N, H * factor, W * factor, C]`, and same datatype as + `x`. + """ + + assert isinstance(factor, int) and factor >= 1 + + # Setup filter kernel. + if kernel is None: + kernel = [1] * factor + + # setup kernel + kernel = np.asarray(kernel, dtype=np.float32) + if kernel.ndim == 1: + kernel = np.outer(kernel, kernel) + kernel /= np.sum(kernel) + + kernel = kernel * (gain * (factor**2)) + + if self.use_conv: + convH = weight.shape[2] + convW = weight.shape[3] + inC = weight.shape[1] + + p = (kernel.shape[0] - factor) - (convW - 1) + + stride = (factor, factor) + # Determine data dimensions. + stride = [1, 1, factor, factor] + output_shape = ((x.shape[2] - 1) * factor + convH, (x.shape[3] - 1) * factor + convW) + output_padding = ( + output_shape[0] - (x.shape[2] - 1) * stride[0] - convH, + output_shape[1] - (x.shape[3] - 1) * stride[1] - convW, + ) + assert output_padding[0] >= 0 and output_padding[1] >= 0 + inC = weight.shape[1] + num_groups = x.shape[1] // inC + + # Transpose weights. + weight = torch.reshape(weight, (num_groups, -1, inC, convH, convW)) + weight = weight[..., ::-1, ::-1].permute(0, 2, 1, 3, 4) + weight = torch.reshape(weight, (num_groups * inC, -1, convH, convW)) + + x = F.conv_transpose2d(x, weight, stride=stride, output_padding=output_padding, padding=0) + + x = upfirdn2d_native(x, torch.tensor(kernel, device=x.device), pad=((p + 1) // 2 + factor - 1, p // 2 + 1)) + else: + p = kernel.shape[0] - factor + x = upfirdn2d_native( + x, torch.tensor(kernel, device=x.device), up=factor, pad=((p + 1) // 2 + factor - 1, p // 2) + ) + + return x + + def forward(self, x): + if self.use_conv: + height = self._upsample_2d(x, self.Conv2d_0.weight, kernel=self.fir_kernel) + height = height + self.Conv2d_0.bias.reshape(1, -1, 1, 1) + else: + height = self._upsample_2d(x, kernel=self.fir_kernel, factor=2) + + return height + + +class FirDownsample2D(nn.Module): + def __init__(self, channels=None, out_channels=None, use_conv=False, fir_kernel=(1, 3, 3, 1)): + super().__init__() + out_channels = out_channels if out_channels else channels + if use_conv: + self.Conv2d_0 = nn.Conv2d(channels, out_channels, kernel_size=3, stride=1, padding=1) + self.fir_kernel = fir_kernel + self.use_conv = use_conv + self.out_channels = out_channels + + def _downsample_2d(self, x, weight=None, kernel=None, factor=2, gain=1): + """Fused `Conv2d()` followed by `downsample_2d()`. + + Args: + Padding is performed only once at the beginning, not between the operations. The fused op is considerably more + efficient than performing the same calculation using standard TensorFlow ops. It supports gradients of arbitrary: + order. + x: Input tensor of the shape `[N, C, H, W]` or `[N, H, W, C]`. w: Weight tensor of the shape `[filterH, + filterW, inChannels, outChannels]`. Grouped convolution can be performed by `inChannels = x.shape[0] // + numGroups`. k: FIR filter of the shape `[firH, firW]` or `[firN]` (separable). The default is `[1] * + factor`, which corresponds to average pooling. factor: Integer downsampling factor (default: 2). gain: + Scaling factor for signal magnitude (default: 1.0). + + Returns: + Tensor of the shape `[N, C, H // factor, W // factor]` or `[N, H // factor, W // factor, C]`, and same + datatype as `x`. 
+ """ + + assert isinstance(factor, int) and factor >= 1 + if kernel is None: + kernel = [1] * factor + + # setup kernel + kernel = np.asarray(kernel, dtype=np.float32) + if kernel.ndim == 1: + kernel = np.outer(kernel, kernel) + kernel /= np.sum(kernel) + + kernel = kernel * gain + + if self.use_conv: + _, _, convH, convW = weight.shape + p = (kernel.shape[0] - factor) + (convW - 1) + s = [factor, factor] + x = upfirdn2d_native(x, torch.tensor(kernel, device=x.device), pad=((p + 1) // 2, p // 2)) + x = F.conv2d(x, weight, stride=s, padding=0) + else: + p = kernel.shape[0] - factor + x = upfirdn2d_native(x, torch.tensor(kernel, device=x.device), down=factor, pad=((p + 1) // 2, p // 2)) + + return x + + def forward(self, x): + if self.use_conv: + x = self._downsample_2d(x, weight=self.Conv2d_0.weight, kernel=self.fir_kernel) + x = x + self.Conv2d_0.bias.reshape(1, -1, 1, 1) + else: + x = self._downsample_2d(x, kernel=self.fir_kernel, factor=2) + + return x + + +class ResnetBlock2D(nn.Module): + def __init__( + self, + *, + in_channels, + out_channels=None, + conv_shortcut=False, + dropout=0.0, + temb_channels=512, + groups=32, + groups_out=None, + pre_norm=True, + eps=1e-6, + non_linearity="swish", + time_embedding_norm="default", + kernel=None, + output_scale_factor=1.0, + use_nin_shortcut=None, + up=False, + down=False, + ): + super().__init__() + self.pre_norm = pre_norm + self.pre_norm = True + self.in_channels = in_channels + out_channels = in_channels if out_channels is None else out_channels + self.out_channels = out_channels + self.use_conv_shortcut = conv_shortcut + self.time_embedding_norm = time_embedding_norm + self.up = up + self.down = down + self.output_scale_factor = output_scale_factor + + if groups_out is None: + groups_out = groups + + self.norm1 = torch.nn.GroupNorm(num_groups=groups, num_channels=in_channels, eps=eps, affine=True) + + self.conv1 = torch.nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1) + + if temb_channels is not None: + self.time_emb_proj = torch.nn.Linear(temb_channels, out_channels) + else: + self.time_emb_proj = None + + self.norm2 = torch.nn.GroupNorm(num_groups=groups_out, num_channels=out_channels, eps=eps, affine=True) + self.dropout = torch.nn.Dropout(dropout) + self.conv2 = torch.nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1) + + if non_linearity == "swish": + self.nonlinearity = lambda x: F.silu(x) + elif non_linearity == "mish": + self.nonlinearity = Mish() + elif non_linearity == "silu": + self.nonlinearity = nn.SiLU() + + self.upsample = self.downsample = None + if self.up: + if kernel == "fir": + fir_kernel = (1, 3, 3, 1) + self.upsample = lambda x: upsample_2d(x, kernel=fir_kernel) + elif kernel == "sde_vp": + self.upsample = partial(F.interpolate, scale_factor=2.0, mode="nearest") + else: + self.upsample = Upsample2D(in_channels, use_conv=False) + elif self.down: + if kernel == "fir": + fir_kernel = (1, 3, 3, 1) + self.downsample = lambda x: downsample_2d(x, kernel=fir_kernel) + elif kernel == "sde_vp": + self.downsample = partial(F.avg_pool2d, kernel_size=2, stride=2) + else: + self.downsample = Downsample2D(in_channels, use_conv=False, padding=1, name="op") + + self.use_nin_shortcut = self.in_channels != self.out_channels if use_nin_shortcut is None else use_nin_shortcut + + self.conv_shortcut = None + if self.use_nin_shortcut: + self.conv_shortcut = torch.nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, padding=0) + + def forward(self, x, temb): + hidden_states = x + + # 
make sure hidden states is in float32 + # when running in half-precision + hidden_states = self.norm1(hidden_states.float()).type(hidden_states.dtype) + hidden_states = self.nonlinearity(hidden_states) + + if self.upsample is not None: + x = self.upsample(x) + hidden_states = self.upsample(hidden_states) + elif self.downsample is not None: + x = self.downsample(x) + hidden_states = self.downsample(hidden_states) + + hidden_states = self.conv1(hidden_states) + + if temb is not None: + temb = self.time_emb_proj(self.nonlinearity(temb))[:, :, None, None] + hidden_states = hidden_states + temb + + # make sure hidden states is in float32 + # when running in half-precision + hidden_states = self.norm2(hidden_states.float()).type(hidden_states.dtype) + hidden_states = self.nonlinearity(hidden_states) + + hidden_states = self.dropout(hidden_states) + hidden_states = self.conv2(hidden_states) + + if self.conv_shortcut is not None: + x = self.conv_shortcut(x) + + out = (x + hidden_states) / self.output_scale_factor + + return out + + +class Mish(torch.nn.Module): + def forward(self, x): + return x * torch.tanh(torch.nn.functional.softplus(x)) + + +def upsample_2d(x, kernel=None, factor=2, gain=1): + r"""Upsample2D a batch of 2D images with the given filter. + + Args: + Accepts a batch of 2D images of the shape `[N, C, H, W]` or `[N, H, W, C]` and upsamples each image with the given + filter. The filter is normalized so that if the input pixels are constant, they will be scaled by the specified + `gain`. Pixels outside the image are assumed to be zero, and the filter is padded with zeros so that its shape is a: + multiple of the upsampling factor. + x: Input tensor of the shape `[N, C, H, W]` or `[N, H, W, + C]`. + k: FIR filter of the shape `[firH, firW]` or `[firN]` + (separable). The default is `[1] * factor`, which corresponds to nearest-neighbor upsampling. + factor: Integer upsampling factor (default: 2). gain: Scaling factor for signal magnitude (default: 1.0). + + Returns: + Tensor of the shape `[N, C, H * factor, W * factor]` + """ + assert isinstance(factor, int) and factor >= 1 + if kernel is None: + kernel = [1] * factor + + kernel = np.asarray(kernel, dtype=np.float32) + if kernel.ndim == 1: + kernel = np.outer(kernel, kernel) + kernel /= np.sum(kernel) + + kernel = kernel * (gain * (factor**2)) + p = kernel.shape[0] - factor + return upfirdn2d_native( + x, torch.tensor(kernel, device=x.device), up=factor, pad=((p + 1) // 2 + factor - 1, p // 2) + ) + + +def downsample_2d(x, kernel=None, factor=2, gain=1): + r"""Downsample2D a batch of 2D images with the given filter. + + Args: + Accepts a batch of 2D images of the shape `[N, C, H, W]` or `[N, H, W, C]` and downsamples each image with the + given filter. The filter is normalized so that if the input pixels are constant, they will be scaled by the + specified `gain`. Pixels outside the image are assumed to be zero, and the filter is padded with zeros so that its + shape is a multiple of the downsampling factor. + x: Input tensor of the shape `[N, C, H, W]` or `[N, H, W, + C]`. + kernel: FIR filter of the shape `[firH, firW]` or `[firN]` + (separable). The default is `[1] * factor`, which corresponds to average pooling. + factor: Integer downsampling factor (default: 2). gain: Scaling factor for signal magnitude (default: 1.0). 
+ + Returns: + Tensor of the shape `[N, C, H // factor, W // factor]` + """ + + assert isinstance(factor, int) and factor >= 1 + if kernel is None: + kernel = [1] * factor + + kernel = np.asarray(kernel, dtype=np.float32) + if kernel.ndim == 1: + kernel = np.outer(kernel, kernel) + kernel /= np.sum(kernel) + + kernel = kernel * gain + p = kernel.shape[0] - factor + return upfirdn2d_native(x, torch.tensor(kernel, device=x.device), down=factor, pad=((p + 1) // 2, p // 2)) + + +def upfirdn2d_native(input, kernel, up=1, down=1, pad=(0, 0)): + up_x = up_y = up + down_x = down_y = down + pad_x0 = pad_y0 = pad[0] + pad_x1 = pad_y1 = pad[1] + + _, channel, in_h, in_w = input.shape + input = input.reshape(-1, in_h, in_w, 1) + + _, in_h, in_w, minor = input.shape + kernel_h, kernel_w = kernel.shape + + out = input.view(-1, in_h, 1, in_w, 1, minor) + + # Temporary workaround for mps specific issue: https://github.com/pytorch/pytorch/issues/84535 + if input.device.type == "mps": + out = out.to("cpu") + out = F.pad(out, [0, 0, 0, up_x - 1, 0, 0, 0, up_y - 1]) + out = out.view(-1, in_h * up_y, in_w * up_x, minor) + + out = F.pad(out, [0, 0, max(pad_x0, 0), max(pad_x1, 0), max(pad_y0, 0), max(pad_y1, 0)]) + out = out.to(input.device) # Move back to mps if necessary + out = out[ + :, + max(-pad_y0, 0) : out.shape[1] - max(-pad_y1, 0), + max(-pad_x0, 0) : out.shape[2] - max(-pad_x1, 0), + :, + ] + + out = out.permute(0, 3, 1, 2) + out = out.reshape([-1, 1, in_h * up_y + pad_y0 + pad_y1, in_w * up_x + pad_x0 + pad_x1]) + w = torch.flip(kernel, [0, 1]).view(1, 1, kernel_h, kernel_w) + out = F.conv2d(out, w) + out = out.reshape( + -1, + minor, + in_h * up_y + pad_y0 + pad_y1 - kernel_h + 1, + in_w * up_x + pad_x0 + pad_x1 - kernel_w + 1, + ) + out = out.permute(0, 2, 3, 1) + out = out[:, ::down_y, ::down_x, :] + + out_h = (in_h * up_y + pad_y0 + pad_y1 - kernel_h) // down_y + 1 + out_w = (in_w * up_x + pad_x0 + pad_x1 - kernel_w) // down_x + 1 + + return out.view(-1, channel, out_h, out_w) diff --git a/diffusers/models/unet_2d.py b/diffusers/models/unet_2d.py new file mode 100644 index 000000000..c3ab621a2 --- /dev/null +++ b/diffusers/models/unet_2d.py @@ -0,0 +1,246 @@ +from dataclasses import dataclass +from typing import Optional, Tuple, Union + +import torch +import torch.nn as nn + +from ..configuration_utils import ConfigMixin, register_to_config +from ..modeling_utils import ModelMixin +from ..utils import BaseOutput +from .embeddings import GaussianFourierProjection, TimestepEmbedding, Timesteps +from .unet_blocks import UNetMidBlock2D, get_down_block, get_up_block + + +@dataclass +class UNet2DOutput(BaseOutput): + """ + Args: + sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`): + Hidden states output. Output of last layer of model. + """ + + sample: torch.FloatTensor + + +class UNet2DModel(ModelMixin, ConfigMixin): + r""" + UNet2DModel is a 2D UNet model that takes in a noisy sample and a timestep and returns sample shaped output. + + This model inherits from [`ModelMixin`]. Check the superclass documentation for the generic methods the library + implements for all the model (such as downloading or saving, etc.) + + Parameters: + sample_size (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`, *optional*): + Input sample size. + in_channels (`int`, *optional*, defaults to 3): Number of channels in the input image. + out_channels (`int`, *optional*, defaults to 3): Number of channels in the output. 
+ center_input_sample (`bool`, *optional*, defaults to `False`): Whether to center the input sample. + time_embedding_type (`str`, *optional*, defaults to `"positional"`): Type of time embedding to use. + freq_shift (`int`, *optional*, defaults to 0): Frequency shift for fourier time embedding. + flip_sin_to_cos (`bool`, *optional*, defaults to : + obj:`False`): Whether to flip sin to cos for fourier time embedding. + down_block_types (`Tuple[str]`, *optional*, defaults to : + obj:`("DownBlock2D", "AttnDownBlock2D", "AttnDownBlock2D", "AttnDownBlock2D")`): Tuple of downsample block + types. + up_block_types (`Tuple[str]`, *optional*, defaults to : + obj:`("AttnUpBlock2D", "AttnUpBlock2D", "AttnUpBlock2D", "UpBlock2D")`): Tuple of upsample block types. + block_out_channels (`Tuple[int]`, *optional*, defaults to : + obj:`(224, 448, 672, 896)`): Tuple of block output channels. + layers_per_block (`int`, *optional*, defaults to `2`): The number of layers per block. + mid_block_scale_factor (`float`, *optional*, defaults to `1`): The scale factor for the mid block. + downsample_padding (`int`, *optional*, defaults to `1`): The padding for the downsample convolution. + act_fn (`str`, *optional*, defaults to `"silu"`): The activation function to use. + attention_head_dim (`int`, *optional*, defaults to `8`): The attention head dimension. + norm_num_groups (`int`, *optional*, defaults to `32`): The number of groups for the normalization. + norm_eps (`float`, *optional*, defaults to `1e-5`): The epsilon for the normalization. + """ + + @register_to_config + def __init__( + self, + sample_size: Optional[int] = None, + in_channels: int = 3, + out_channels: int = 3, + center_input_sample: bool = False, + time_embedding_type: str = "positional", + freq_shift: int = 0, + flip_sin_to_cos: bool = True, + down_block_types: Tuple[str] = ("DownBlock2D", "AttnDownBlock2D", "AttnDownBlock2D", "AttnDownBlock2D"), + up_block_types: Tuple[str] = ("AttnUpBlock2D", "AttnUpBlock2D", "AttnUpBlock2D", "UpBlock2D"), + block_out_channels: Tuple[int] = (224, 448, 672, 896), + layers_per_block: int = 2, + mid_block_scale_factor: float = 1, + downsample_padding: int = 1, + act_fn: str = "silu", + attention_head_dim: int = 8, + norm_num_groups: int = 32, + norm_eps: float = 1e-5, + ): + super().__init__() + + self.sample_size = sample_size + time_embed_dim = block_out_channels[0] * 4 + + # input + self.conv_in = nn.Conv2d(in_channels, block_out_channels[0], kernel_size=3, padding=(1, 1)) + + # time + if time_embedding_type == "fourier": + self.time_proj = GaussianFourierProjection(embedding_size=block_out_channels[0], scale=16) + timestep_input_dim = 2 * block_out_channels[0] + elif time_embedding_type == "positional": + self.time_proj = Timesteps(block_out_channels[0], flip_sin_to_cos, freq_shift) + timestep_input_dim = block_out_channels[0] + + self.time_embedding = TimestepEmbedding(timestep_input_dim, time_embed_dim) + + self.down_blocks = nn.ModuleList([]) + self.mid_block = None + self.up_blocks = nn.ModuleList([]) + + # down + output_channel = block_out_channels[0] + for i, down_block_type in enumerate(down_block_types): + input_channel = output_channel + output_channel = block_out_channels[i] + is_final_block = i == len(block_out_channels) - 1 + + down_block = get_down_block( + down_block_type, + num_layers=layers_per_block, + in_channels=input_channel, + out_channels=output_channel, + temb_channels=time_embed_dim, + add_downsample=not is_final_block, + resnet_eps=norm_eps, + resnet_act_fn=act_fn, + 
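                # attn_num_head_channels is only consumed by the attention-based down blocks
                # (get_down_block does not pass it to a plain DownBlock2D); downsample_padding
                # is forwarded to the block's Downsample2D convolution.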
attn_num_head_channels=attention_head_dim, + downsample_padding=downsample_padding, + ) + self.down_blocks.append(down_block) + + # mid + self.mid_block = UNetMidBlock2D( + in_channels=block_out_channels[-1], + temb_channels=time_embed_dim, + resnet_eps=norm_eps, + resnet_act_fn=act_fn, + output_scale_factor=mid_block_scale_factor, + resnet_time_scale_shift="default", + attn_num_head_channels=attention_head_dim, + resnet_groups=norm_num_groups, + ) + + # up + reversed_block_out_channels = list(reversed(block_out_channels)) + output_channel = reversed_block_out_channels[0] + for i, up_block_type in enumerate(up_block_types): + prev_output_channel = output_channel + output_channel = reversed_block_out_channels[i] + input_channel = reversed_block_out_channels[min(i + 1, len(block_out_channels) - 1)] + + is_final_block = i == len(block_out_channels) - 1 + + up_block = get_up_block( + up_block_type, + num_layers=layers_per_block + 1, + in_channels=input_channel, + out_channels=output_channel, + prev_output_channel=prev_output_channel, + temb_channels=time_embed_dim, + add_upsample=not is_final_block, + resnet_eps=norm_eps, + resnet_act_fn=act_fn, + attn_num_head_channels=attention_head_dim, + ) + self.up_blocks.append(up_block) + prev_output_channel = output_channel + + # out + num_groups_out = norm_num_groups if norm_num_groups is not None else min(block_out_channels[0] // 4, 32) + self.conv_norm_out = nn.GroupNorm(num_channels=block_out_channels[0], num_groups=num_groups_out, eps=norm_eps) + self.conv_act = nn.SiLU() + self.conv_out = nn.Conv2d(block_out_channels[0], out_channels, 3, padding=1) + + def forward( + self, + sample: torch.FloatTensor, + timestep: Union[torch.Tensor, float, int], + return_dict: bool = True, + ) -> Union[UNet2DOutput, Tuple]: + """r + Args: + sample (`torch.FloatTensor`): (batch, channel, height, width) noisy inputs tensor + timestep (`torch.FloatTensor` or `float` or `int): (batch) timesteps + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~models.unet_2d.UNet2DOutput`] instead of a plain tuple. + + Returns: + [`~models.unet_2d.UNet2DOutput`] or `tuple`: [`~models.unet_2d.UNet2DOutput`] if `return_dict` is True, + otherwise a `tuple`. When returning a tuple, the first element is the sample tensor. + """ + # 0. center input if necessary + if self.config.center_input_sample: + sample = 2 * sample - 1.0 + + # 1. time + timesteps = timestep + if not torch.is_tensor(timesteps): + timesteps = torch.tensor([timesteps], dtype=torch.long, device=sample.device) + elif torch.is_tensor(timesteps) and len(timesteps.shape) == 0: + timesteps = timesteps[None].to(sample.device) + + # broadcast to batch dimension in a way that's compatible with ONNX/Core ML + timesteps = timesteps * torch.ones(sample.shape[0], dtype=timesteps.dtype, device=timesteps.device) + + t_emb = self.time_proj(timesteps) + emb = self.time_embedding(t_emb) + + # 2. pre-process + skip_sample = sample + sample = self.conv_in(sample) + + # 3. down + down_block_res_samples = (sample,) + for downsample_block in self.down_blocks: + if hasattr(downsample_block, "skip_conv"): + sample, res_samples, skip_sample = downsample_block( + hidden_states=sample, temb=emb, skip_sample=skip_sample + ) + else: + sample, res_samples = downsample_block(hidden_states=sample, temb=emb) + + down_block_res_samples += res_samples + + # 4. mid + sample = self.mid_block(sample, emb) + + # 5. 
up + skip_sample = None + for upsample_block in self.up_blocks: + res_samples = down_block_res_samples[-len(upsample_block.resnets) :] + down_block_res_samples = down_block_res_samples[: -len(upsample_block.resnets)] + + if hasattr(upsample_block, "skip_conv"): + sample, skip_sample = upsample_block(sample, res_samples, emb, skip_sample) + else: + sample = upsample_block(sample, res_samples, emb) + + # 6. post-process + # make sure hidden states is in float32 + # when running in half-precision + sample = self.conv_norm_out(sample.float()).type(sample.dtype) + sample = self.conv_act(sample) + sample = self.conv_out(sample) + + if skip_sample is not None: + sample += skip_sample + + if self.config.time_embedding_type == "fourier": + timesteps = timesteps.reshape((sample.shape[0], *([1] * len(sample.shape[1:])))) + sample = sample / timesteps + + if not return_dict: + return (sample,) + + return UNet2DOutput(sample=sample) diff --git a/diffusers/models/unet_2d_condition.py b/diffusers/models/unet_2d_condition.py new file mode 100644 index 000000000..92caaca92 --- /dev/null +++ b/diffusers/models/unet_2d_condition.py @@ -0,0 +1,270 @@ +from dataclasses import dataclass +from typing import Optional, Tuple, Union + +import torch +import torch.nn as nn + +from ..configuration_utils import ConfigMixin, register_to_config +from ..modeling_utils import ModelMixin +from ..utils import BaseOutput +from .embeddings import TimestepEmbedding, Timesteps +from .unet_blocks import UNetMidBlock2DCrossAttn, get_down_block, get_up_block + + +@dataclass +class UNet2DConditionOutput(BaseOutput): + """ + Args: + sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`): + Hidden states conditioned on `encoder_hidden_states` input. Output of last layer of model. + """ + + sample: torch.FloatTensor + + +class UNet2DConditionModel(ModelMixin, ConfigMixin): + r""" + UNet2DConditionModel is a conditional 2D UNet model that takes in a noisy sample, conditional state, and a timestep + and returns sample shaped output. + + This model inherits from [`ModelMixin`]. Check the superclass documentation for the generic methods the library + implements for all the model (such as downloading or saving, etc.) + + Parameters: + sample_size (`int`, *optional*): The size of the input sample. + in_channels (`int`, *optional*, defaults to 4): The number of channels in the input sample. + out_channels (`int`, *optional*, defaults to 4): The number of channels in the output. + center_input_sample (`bool`, *optional*, defaults to `False`): Whether to center the input sample. + flip_sin_to_cos (`bool`, *optional*, defaults to `False`): + Whether to flip the sin to cos in the time embedding. + freq_shift (`int`, *optional*, defaults to 0): The frequency shift to apply to the time embedding. + down_block_types (`Tuple[str]`, *optional*, defaults to `("CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "DownBlock2D")`): + The tuple of downsample blocks to use. + up_block_types (`Tuple[str]`, *optional*, defaults to `("UpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D",)`): + The tuple of upsample blocks to use. + block_out_channels (`Tuple[int]`, *optional*, defaults to `(320, 640, 1280, 1280)`): + The tuple of output channels for each block. + layers_per_block (`int`, *optional*, defaults to 2): The number of layers per block. + downsample_padding (`int`, *optional*, defaults to 1): The padding to use for the downsampling convolution. 
+ mid_block_scale_factor (`float`, *optional*, defaults to 1.0): The scale factor to use for the mid block. + act_fn (`str`, *optional*, defaults to `"silu"`): The activation function to use. + norm_num_groups (`int`, *optional*, defaults to 32): The number of groups to use for the normalization. + norm_eps (`float`, *optional*, defaults to 1e-5): The epsilon to use for the normalization. + cross_attention_dim (`int`, *optional*, defaults to 1280): The dimension of the cross attention features. + attention_head_dim (`int`, *optional*, defaults to 8): The dimension of the attention heads. + """ + + @register_to_config + def __init__( + self, + sample_size: Optional[int] = None, + in_channels: int = 4, + out_channels: int = 4, + center_input_sample: bool = False, + flip_sin_to_cos: bool = True, + freq_shift: int = 0, + down_block_types: Tuple[str] = ( + "CrossAttnDownBlock2D", + "CrossAttnDownBlock2D", + "CrossAttnDownBlock2D", + "DownBlock2D", + ), + up_block_types: Tuple[str] = ("UpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D"), + block_out_channels: Tuple[int] = (320, 640, 1280, 1280), + layers_per_block: int = 2, + downsample_padding: int = 1, + mid_block_scale_factor: float = 1, + act_fn: str = "silu", + norm_num_groups: int = 32, + norm_eps: float = 1e-5, + cross_attention_dim: int = 1280, + attention_head_dim: int = 8, + ): + super().__init__() + + self.sample_size = sample_size + time_embed_dim = block_out_channels[0] * 4 + + # input + self.conv_in = nn.Conv2d(in_channels, block_out_channels[0], kernel_size=3, padding=(1, 1)) + + # time + self.time_proj = Timesteps(block_out_channels[0], flip_sin_to_cos, freq_shift) + timestep_input_dim = block_out_channels[0] + + self.time_embedding = TimestepEmbedding(timestep_input_dim, time_embed_dim) + + self.down_blocks = nn.ModuleList([]) + self.mid_block = None + self.up_blocks = nn.ModuleList([]) + + # down + output_channel = block_out_channels[0] + for i, down_block_type in enumerate(down_block_types): + input_channel = output_channel + output_channel = block_out_channels[i] + is_final_block = i == len(block_out_channels) - 1 + + down_block = get_down_block( + down_block_type, + num_layers=layers_per_block, + in_channels=input_channel, + out_channels=output_channel, + temb_channels=time_embed_dim, + add_downsample=not is_final_block, + resnet_eps=norm_eps, + resnet_act_fn=act_fn, + cross_attention_dim=cross_attention_dim, + attn_num_head_channels=attention_head_dim, + downsample_padding=downsample_padding, + ) + self.down_blocks.append(down_block) + + # mid + self.mid_block = UNetMidBlock2DCrossAttn( + in_channels=block_out_channels[-1], + temb_channels=time_embed_dim, + resnet_eps=norm_eps, + resnet_act_fn=act_fn, + output_scale_factor=mid_block_scale_factor, + resnet_time_scale_shift="default", + cross_attention_dim=cross_attention_dim, + attn_num_head_channels=attention_head_dim, + resnet_groups=norm_num_groups, + ) + + # up + reversed_block_out_channels = list(reversed(block_out_channels)) + output_channel = reversed_block_out_channels[0] + for i, up_block_type in enumerate(up_block_types): + prev_output_channel = output_channel + output_channel = reversed_block_out_channels[i] + input_channel = reversed_block_out_channels[min(i + 1, len(block_out_channels) - 1)] + + is_final_block = i == len(block_out_channels) - 1 + + up_block = get_up_block( + up_block_type, + num_layers=layers_per_block + 1, + in_channels=input_channel, + out_channels=output_channel, + prev_output_channel=prev_output_channel, + 
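                # prev_output_channel is the channel width of the block below; the up block's
                # first resnet consumes it concatenated with the matching skip connection.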
temb_channels=time_embed_dim, + add_upsample=not is_final_block, + resnet_eps=norm_eps, + resnet_act_fn=act_fn, + cross_attention_dim=cross_attention_dim, + attn_num_head_channels=attention_head_dim, + ) + self.up_blocks.append(up_block) + prev_output_channel = output_channel + + # out + self.conv_norm_out = nn.GroupNorm(num_channels=block_out_channels[0], num_groups=norm_num_groups, eps=norm_eps) + self.conv_act = nn.SiLU() + self.conv_out = nn.Conv2d(block_out_channels[0], out_channels, 3, padding=1) + + def set_attention_slice(self, slice_size): + if slice_size is not None and self.config.attention_head_dim % slice_size != 0: + raise ValueError( + f"Make sure slice_size {slice_size} is a divisor of " + f"the number of heads used in cross_attention {self.config.attention_head_dim}" + ) + if slice_size is not None and slice_size > self.config.attention_head_dim: + raise ValueError( + f"Chunk_size {slice_size} has to be smaller or equal to " + f"the number of heads used in cross_attention {self.config.attention_head_dim}" + ) + + for block in self.down_blocks: + if hasattr(block, "attentions") and block.attentions is not None: + block.set_attention_slice(slice_size) + + self.mid_block.set_attention_slice(slice_size) + + for block in self.up_blocks: + if hasattr(block, "attentions") and block.attentions is not None: + block.set_attention_slice(slice_size) + + def forward( + self, + sample: torch.FloatTensor, + timestep: Union[torch.Tensor, float, int], + encoder_hidden_states: torch.Tensor, + return_dict: bool = True, + ) -> Union[UNet2DConditionOutput, Tuple]: + """r + Args: + sample (`torch.FloatTensor`): (batch, channel, height, width) noisy inputs tensor + timestep (`torch.FloatTensor` or `float` or `int): (batch) timesteps + encoder_hidden_states (`torch.FloatTensor`): (batch, channel, height, width) encoder hidden states + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`models.unet_2d_condition.UNet2DConditionOutput`] instead of a plain tuple. + + Returns: + [`~models.unet_2d_condition.UNet2DConditionOutput`] or `tuple`: + [`~models.unet_2d_condition.UNet2DConditionOutput`] if `return_dict` is True, otherwise a `tuple`. When + returning a tuple, the first element is the sample tensor. + """ + # 0. center input if necessary + if self.config.center_input_sample: + sample = 2 * sample - 1.0 + + # 1. time + timesteps = timestep + if not torch.is_tensor(timesteps): + timesteps = torch.tensor([timesteps], dtype=torch.long, device=sample.device) + elif torch.is_tensor(timesteps) and len(timesteps.shape) == 0: + timesteps = timesteps.to(dtype=torch.float32) + timesteps = timesteps[None].to(device=sample.device) + + # broadcast to batch dimension in a way that's compatible with ONNX/Core ML + timesteps = timesteps.expand(sample.shape[0]) + + t_emb = self.time_proj(timesteps) + emb = self.time_embedding(t_emb) + + # 2. pre-process + sample = self.conv_in(sample) + + # 3. down + down_block_res_samples = (sample,) + for downsample_block in self.down_blocks: + if hasattr(downsample_block, "attentions") and downsample_block.attentions is not None: + sample, res_samples = downsample_block( + hidden_states=sample, temb=emb, encoder_hidden_states=encoder_hidden_states + ) + else: + sample, res_samples = downsample_block(hidden_states=sample, temb=emb) + + down_block_res_samples += res_samples + + # 4. mid + sample = self.mid_block(sample, emb, encoder_hidden_states=encoder_hidden_states) + + # 5. 
up + for upsample_block in self.up_blocks: + res_samples = down_block_res_samples[-len(upsample_block.resnets) :] + down_block_res_samples = down_block_res_samples[: -len(upsample_block.resnets)] + + if hasattr(upsample_block, "attentions") and upsample_block.attentions is not None: + sample = upsample_block( + hidden_states=sample, + temb=emb, + res_hidden_states_tuple=res_samples, + encoder_hidden_states=encoder_hidden_states, + ) + else: + sample = upsample_block(hidden_states=sample, temb=emb, res_hidden_states_tuple=res_samples) + + # 6. post-process + # make sure hidden states is in float32 + # when running in half-precision + sample = self.conv_norm_out(sample.float()).type(sample.dtype) + sample = self.conv_act(sample) + sample = self.conv_out(sample) + + if not return_dict: + return (sample,) + + return UNet2DConditionOutput(sample=sample) diff --git a/diffusers/models/unet_blocks.py b/diffusers/models/unet_blocks.py new file mode 100644 index 000000000..9e0621653 --- /dev/null +++ b/diffusers/models/unet_blocks.py @@ -0,0 +1,1481 @@ +# Copyright 2022 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and + +import numpy as np + +# limitations under the License. +import torch +from torch import nn + +from .attention import AttentionBlock, SpatialTransformer +from .resnet import Downsample2D, FirDownsample2D, FirUpsample2D, ResnetBlock2D, Upsample2D + + +def get_down_block( + down_block_type, + num_layers, + in_channels, + out_channels, + temb_channels, + add_downsample, + resnet_eps, + resnet_act_fn, + attn_num_head_channels, + cross_attention_dim=None, + downsample_padding=None, +): + down_block_type = down_block_type[7:] if down_block_type.startswith("UNetRes") else down_block_type + if down_block_type == "DownBlock2D": + return DownBlock2D( + num_layers=num_layers, + in_channels=in_channels, + out_channels=out_channels, + temb_channels=temb_channels, + add_downsample=add_downsample, + resnet_eps=resnet_eps, + resnet_act_fn=resnet_act_fn, + downsample_padding=downsample_padding, + ) + elif down_block_type == "AttnDownBlock2D": + return AttnDownBlock2D( + num_layers=num_layers, + in_channels=in_channels, + out_channels=out_channels, + temb_channels=temb_channels, + add_downsample=add_downsample, + resnet_eps=resnet_eps, + resnet_act_fn=resnet_act_fn, + downsample_padding=downsample_padding, + attn_num_head_channels=attn_num_head_channels, + ) + elif down_block_type == "CrossAttnDownBlock2D": + if cross_attention_dim is None: + raise ValueError("cross_attention_dim must be specified for CrossAttnDownBlock2D") + return CrossAttnDownBlock2D( + num_layers=num_layers, + in_channels=in_channels, + out_channels=out_channels, + temb_channels=temb_channels, + add_downsample=add_downsample, + resnet_eps=resnet_eps, + resnet_act_fn=resnet_act_fn, + downsample_padding=downsample_padding, + cross_attention_dim=cross_attention_dim, + attn_num_head_channels=attn_num_head_channels, + ) + elif down_block_type == "SkipDownBlock2D": + return SkipDownBlock2D( + num_layers=num_layers, + 
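            # Skip blocks additionally thread a copy of the model input ("skip_sample")
            # through the down path, FIR-downsampling it and merging it via skip_conv.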
in_channels=in_channels, + out_channels=out_channels, + temb_channels=temb_channels, + add_downsample=add_downsample, + resnet_eps=resnet_eps, + resnet_act_fn=resnet_act_fn, + downsample_padding=downsample_padding, + ) + elif down_block_type == "AttnSkipDownBlock2D": + return AttnSkipDownBlock2D( + num_layers=num_layers, + in_channels=in_channels, + out_channels=out_channels, + temb_channels=temb_channels, + add_downsample=add_downsample, + resnet_eps=resnet_eps, + resnet_act_fn=resnet_act_fn, + downsample_padding=downsample_padding, + attn_num_head_channels=attn_num_head_channels, + ) + elif down_block_type == "DownEncoderBlock2D": + return DownEncoderBlock2D( + num_layers=num_layers, + in_channels=in_channels, + out_channels=out_channels, + add_downsample=add_downsample, + resnet_eps=resnet_eps, + resnet_act_fn=resnet_act_fn, + downsample_padding=downsample_padding, + ) + + +def get_up_block( + up_block_type, + num_layers, + in_channels, + out_channels, + prev_output_channel, + temb_channels, + add_upsample, + resnet_eps, + resnet_act_fn, + attn_num_head_channels, + cross_attention_dim=None, +): + up_block_type = up_block_type[7:] if up_block_type.startswith("UNetRes") else up_block_type + if up_block_type == "UpBlock2D": + return UpBlock2D( + num_layers=num_layers, + in_channels=in_channels, + out_channels=out_channels, + prev_output_channel=prev_output_channel, + temb_channels=temb_channels, + add_upsample=add_upsample, + resnet_eps=resnet_eps, + resnet_act_fn=resnet_act_fn, + ) + elif up_block_type == "CrossAttnUpBlock2D": + if cross_attention_dim is None: + raise ValueError("cross_attention_dim must be specified for CrossAttnUpBlock2D") + return CrossAttnUpBlock2D( + num_layers=num_layers, + in_channels=in_channels, + out_channels=out_channels, + prev_output_channel=prev_output_channel, + temb_channels=temb_channels, + add_upsample=add_upsample, + resnet_eps=resnet_eps, + resnet_act_fn=resnet_act_fn, + cross_attention_dim=cross_attention_dim, + attn_num_head_channels=attn_num_head_channels, + ) + elif up_block_type == "AttnUpBlock2D": + return AttnUpBlock2D( + num_layers=num_layers, + in_channels=in_channels, + out_channels=out_channels, + prev_output_channel=prev_output_channel, + temb_channels=temb_channels, + add_upsample=add_upsample, + resnet_eps=resnet_eps, + resnet_act_fn=resnet_act_fn, + attn_num_head_channels=attn_num_head_channels, + ) + elif up_block_type == "SkipUpBlock2D": + return SkipUpBlock2D( + num_layers=num_layers, + in_channels=in_channels, + out_channels=out_channels, + prev_output_channel=prev_output_channel, + temb_channels=temb_channels, + add_upsample=add_upsample, + resnet_eps=resnet_eps, + resnet_act_fn=resnet_act_fn, + ) + elif up_block_type == "AttnSkipUpBlock2D": + return AttnSkipUpBlock2D( + num_layers=num_layers, + in_channels=in_channels, + out_channels=out_channels, + prev_output_channel=prev_output_channel, + temb_channels=temb_channels, + add_upsample=add_upsample, + resnet_eps=resnet_eps, + resnet_act_fn=resnet_act_fn, + attn_num_head_channels=attn_num_head_channels, + ) + elif up_block_type == "UpDecoderBlock2D": + return UpDecoderBlock2D( + num_layers=num_layers, + in_channels=in_channels, + out_channels=out_channels, + add_upsample=add_upsample, + resnet_eps=resnet_eps, + resnet_act_fn=resnet_act_fn, + ) + raise ValueError(f"{up_block_type} does not exist.") + + +class UNetMidBlock2D(nn.Module): + def __init__( + self, + in_channels: int, + temb_channels: int, + dropout: float = 0.0, + num_layers: int = 1, + resnet_eps: float = 1e-6, + 
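        # resnet_time_scale_shift is passed through to every ResnetBlock2D as time_embedding_norm.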
resnet_time_scale_shift: str = "default", + resnet_act_fn: str = "swish", + resnet_groups: int = 32, + resnet_pre_norm: bool = True, + attn_num_head_channels=1, + attention_type="default", + output_scale_factor=1.0, + **kwargs, + ): + super().__init__() + + self.attention_type = attention_type + resnet_groups = resnet_groups if resnet_groups is not None else min(in_channels // 4, 32) + + # there is always at least one resnet + resnets = [ + ResnetBlock2D( + in_channels=in_channels, + out_channels=in_channels, + temb_channels=temb_channels, + eps=resnet_eps, + groups=resnet_groups, + dropout=dropout, + time_embedding_norm=resnet_time_scale_shift, + non_linearity=resnet_act_fn, + output_scale_factor=output_scale_factor, + pre_norm=resnet_pre_norm, + ) + ] + attentions = [] + + for _ in range(num_layers): + attentions.append( + AttentionBlock( + in_channels, + num_head_channels=attn_num_head_channels, + rescale_output_factor=output_scale_factor, + eps=resnet_eps, + num_groups=resnet_groups, + ) + ) + resnets.append( + ResnetBlock2D( + in_channels=in_channels, + out_channels=in_channels, + temb_channels=temb_channels, + eps=resnet_eps, + groups=resnet_groups, + dropout=dropout, + time_embedding_norm=resnet_time_scale_shift, + non_linearity=resnet_act_fn, + output_scale_factor=output_scale_factor, + pre_norm=resnet_pre_norm, + ) + ) + + self.attentions = nn.ModuleList(attentions) + self.resnets = nn.ModuleList(resnets) + + def forward(self, hidden_states, temb=None, encoder_states=None): + hidden_states = self.resnets[0](hidden_states, temb) + for attn, resnet in zip(self.attentions, self.resnets[1:]): + if self.attention_type == "default": + hidden_states = attn(hidden_states) + else: + hidden_states = attn(hidden_states, encoder_states) + hidden_states = resnet(hidden_states, temb) + + return hidden_states + + +class UNetMidBlock2DCrossAttn(nn.Module): + def __init__( + self, + in_channels: int, + temb_channels: int, + dropout: float = 0.0, + num_layers: int = 1, + resnet_eps: float = 1e-6, + resnet_time_scale_shift: str = "default", + resnet_act_fn: str = "swish", + resnet_groups: int = 32, + resnet_pre_norm: bool = True, + attn_num_head_channels=1, + attention_type="default", + output_scale_factor=1.0, + cross_attention_dim=1280, + **kwargs, + ): + super().__init__() + + self.attention_type = attention_type + self.attn_num_head_channels = attn_num_head_channels + resnet_groups = resnet_groups if resnet_groups is not None else min(in_channels // 4, 32) + + # there is always at least one resnet + resnets = [ + ResnetBlock2D( + in_channels=in_channels, + out_channels=in_channels, + temb_channels=temb_channels, + eps=resnet_eps, + groups=resnet_groups, + dropout=dropout, + time_embedding_norm=resnet_time_scale_shift, + non_linearity=resnet_act_fn, + output_scale_factor=output_scale_factor, + pre_norm=resnet_pre_norm, + ) + ] + attentions = [] + + for _ in range(num_layers): + attentions.append( + SpatialTransformer( + in_channels, + attn_num_head_channels, + in_channels // attn_num_head_channels, + depth=1, + context_dim=cross_attention_dim, + ) + ) + resnets.append( + ResnetBlock2D( + in_channels=in_channels, + out_channels=in_channels, + temb_channels=temb_channels, + eps=resnet_eps, + groups=resnet_groups, + dropout=dropout, + time_embedding_norm=resnet_time_scale_shift, + non_linearity=resnet_act_fn, + output_scale_factor=output_scale_factor, + pre_norm=resnet_pre_norm, + ) + ) + + self.attentions = nn.ModuleList(attentions) + self.resnets = nn.ModuleList(resnets) + + def 
set_attention_slice(self, slice_size): + if slice_size is not None and self.attn_num_head_channels % slice_size != 0: + raise ValueError( + f"Make sure slice_size {slice_size} is a divisor of " + f"the number of heads used in cross_attention {self.attn_num_head_channels}" + ) + if slice_size is not None and slice_size > self.attn_num_head_channels: + raise ValueError( + f"Chunk_size {slice_size} has to be smaller or equal to " + f"the number of heads used in cross_attention {self.attn_num_head_channels}" + ) + + for attn in self.attentions: + attn._set_attention_slice(slice_size) + + def forward(self, hidden_states, temb=None, encoder_hidden_states=None): + hidden_states = self.resnets[0](hidden_states, temb) + for attn, resnet in zip(self.attentions, self.resnets[1:]): + hidden_states = attn(hidden_states, encoder_hidden_states) + hidden_states = resnet(hidden_states, temb) + + return hidden_states + + +class AttnDownBlock2D(nn.Module): + def __init__( + self, + in_channels: int, + out_channels: int, + temb_channels: int, + dropout: float = 0.0, + num_layers: int = 1, + resnet_eps: float = 1e-6, + resnet_time_scale_shift: str = "default", + resnet_act_fn: str = "swish", + resnet_groups: int = 32, + resnet_pre_norm: bool = True, + attn_num_head_channels=1, + attention_type="default", + output_scale_factor=1.0, + downsample_padding=1, + add_downsample=True, + ): + super().__init__() + resnets = [] + attentions = [] + + self.attention_type = attention_type + + for i in range(num_layers): + in_channels = in_channels if i == 0 else out_channels + resnets.append( + ResnetBlock2D( + in_channels=in_channels, + out_channels=out_channels, + temb_channels=temb_channels, + eps=resnet_eps, + groups=resnet_groups, + dropout=dropout, + time_embedding_norm=resnet_time_scale_shift, + non_linearity=resnet_act_fn, + output_scale_factor=output_scale_factor, + pre_norm=resnet_pre_norm, + ) + ) + attentions.append( + AttentionBlock( + out_channels, + num_head_channels=attn_num_head_channels, + rescale_output_factor=output_scale_factor, + eps=resnet_eps, + ) + ) + + self.attentions = nn.ModuleList(attentions) + self.resnets = nn.ModuleList(resnets) + + if add_downsample: + self.downsamplers = nn.ModuleList( + [ + Downsample2D( + in_channels, use_conv=True, out_channels=out_channels, padding=downsample_padding, name="op" + ) + ] + ) + else: + self.downsamplers = None + + def forward(self, hidden_states, temb=None): + output_states = () + + for resnet, attn in zip(self.resnets, self.attentions): + hidden_states = resnet(hidden_states, temb) + hidden_states = attn(hidden_states) + output_states += (hidden_states,) + + if self.downsamplers is not None: + for downsampler in self.downsamplers: + hidden_states = downsampler(hidden_states) + + output_states += (hidden_states,) + + return hidden_states, output_states + + +class CrossAttnDownBlock2D(nn.Module): + def __init__( + self, + in_channels: int, + out_channels: int, + temb_channels: int, + dropout: float = 0.0, + num_layers: int = 1, + resnet_eps: float = 1e-6, + resnet_time_scale_shift: str = "default", + resnet_act_fn: str = "swish", + resnet_groups: int = 32, + resnet_pre_norm: bool = True, + attn_num_head_channels=1, + cross_attention_dim=1280, + attention_type="default", + output_scale_factor=1.0, + downsample_padding=1, + add_downsample=True, + ): + super().__init__() + resnets = [] + attentions = [] + + self.attention_type = attention_type + self.attn_num_head_channels = attn_num_head_channels + + for i in range(num_layers): + in_channels = in_channels if 
i == 0 else out_channels + resnets.append( + ResnetBlock2D( + in_channels=in_channels, + out_channels=out_channels, + temb_channels=temb_channels, + eps=resnet_eps, + groups=resnet_groups, + dropout=dropout, + time_embedding_norm=resnet_time_scale_shift, + non_linearity=resnet_act_fn, + output_scale_factor=output_scale_factor, + pre_norm=resnet_pre_norm, + ) + ) + attentions.append( + SpatialTransformer( + out_channels, + attn_num_head_channels, + out_channels // attn_num_head_channels, + depth=1, + context_dim=cross_attention_dim, + ) + ) + self.attentions = nn.ModuleList(attentions) + self.resnets = nn.ModuleList(resnets) + + if add_downsample: + self.downsamplers = nn.ModuleList( + [ + Downsample2D( + in_channels, use_conv=True, out_channels=out_channels, padding=downsample_padding, name="op" + ) + ] + ) + else: + self.downsamplers = None + + def set_attention_slice(self, slice_size): + if slice_size is not None and self.attn_num_head_channels % slice_size != 0: + raise ValueError( + f"Make sure slice_size {slice_size} is a divisor of " + f"the number of heads used in cross_attention {self.attn_num_head_channels}" + ) + if slice_size is not None and slice_size > self.attn_num_head_channels: + raise ValueError( + f"Chunk_size {slice_size} has to be smaller or equal to " + f"the number of heads used in cross_attention {self.attn_num_head_channels}" + ) + + for attn in self.attentions: + attn._set_attention_slice(slice_size) + + def forward(self, hidden_states, temb=None, encoder_hidden_states=None): + output_states = () + + for resnet, attn in zip(self.resnets, self.attentions): + hidden_states = resnet(hidden_states, temb) + hidden_states = attn(hidden_states, context=encoder_hidden_states) + output_states += (hidden_states,) + + if self.downsamplers is not None: + for downsampler in self.downsamplers: + hidden_states = downsampler(hidden_states) + + output_states += (hidden_states,) + + return hidden_states, output_states + + +class DownBlock2D(nn.Module): + def __init__( + self, + in_channels: int, + out_channels: int, + temb_channels: int, + dropout: float = 0.0, + num_layers: int = 1, + resnet_eps: float = 1e-6, + resnet_time_scale_shift: str = "default", + resnet_act_fn: str = "swish", + resnet_groups: int = 32, + resnet_pre_norm: bool = True, + output_scale_factor=1.0, + add_downsample=True, + downsample_padding=1, + ): + super().__init__() + resnets = [] + + for i in range(num_layers): + in_channels = in_channels if i == 0 else out_channels + resnets.append( + ResnetBlock2D( + in_channels=in_channels, + out_channels=out_channels, + temb_channels=temb_channels, + eps=resnet_eps, + groups=resnet_groups, + dropout=dropout, + time_embedding_norm=resnet_time_scale_shift, + non_linearity=resnet_act_fn, + output_scale_factor=output_scale_factor, + pre_norm=resnet_pre_norm, + ) + ) + + self.resnets = nn.ModuleList(resnets) + + if add_downsample: + self.downsamplers = nn.ModuleList( + [ + Downsample2D( + in_channels, use_conv=True, out_channels=out_channels, padding=downsample_padding, name="op" + ) + ] + ) + else: + self.downsamplers = None + + def forward(self, hidden_states, temb=None): + output_states = () + + for resnet in self.resnets: + hidden_states = resnet(hidden_states, temb) + output_states += (hidden_states,) + + if self.downsamplers is not None: + for downsampler in self.downsamplers: + hidden_states = downsampler(hidden_states) + + output_states += (hidden_states,) + + return hidden_states, output_states + + +class DownEncoderBlock2D(nn.Module): + def __init__( + self, + 
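        # Encoder-side variant: same resnet stack as DownBlock2D, but built with
        # temb_channels=None, i.e. without timestep conditioning.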
in_channels: int, + out_channels: int, + dropout: float = 0.0, + num_layers: int = 1, + resnet_eps: float = 1e-6, + resnet_time_scale_shift: str = "default", + resnet_act_fn: str = "swish", + resnet_groups: int = 32, + resnet_pre_norm: bool = True, + output_scale_factor=1.0, + add_downsample=True, + downsample_padding=1, + ): + super().__init__() + resnets = [] + + for i in range(num_layers): + in_channels = in_channels if i == 0 else out_channels + resnets.append( + ResnetBlock2D( + in_channels=in_channels, + out_channels=out_channels, + temb_channels=None, + eps=resnet_eps, + groups=resnet_groups, + dropout=dropout, + time_embedding_norm=resnet_time_scale_shift, + non_linearity=resnet_act_fn, + output_scale_factor=output_scale_factor, + pre_norm=resnet_pre_norm, + ) + ) + + self.resnets = nn.ModuleList(resnets) + + if add_downsample: + self.downsamplers = nn.ModuleList( + [ + Downsample2D( + in_channels, use_conv=True, out_channels=out_channels, padding=downsample_padding, name="op" + ) + ] + ) + else: + self.downsamplers = None + + def forward(self, hidden_states): + for resnet in self.resnets: + hidden_states = resnet(hidden_states, temb=None) + + if self.downsamplers is not None: + for downsampler in self.downsamplers: + hidden_states = downsampler(hidden_states) + + return hidden_states + + +class AttnDownEncoderBlock2D(nn.Module): + def __init__( + self, + in_channels: int, + out_channels: int, + dropout: float = 0.0, + num_layers: int = 1, + resnet_eps: float = 1e-6, + resnet_time_scale_shift: str = "default", + resnet_act_fn: str = "swish", + resnet_groups: int = 32, + resnet_pre_norm: bool = True, + attn_num_head_channels=1, + output_scale_factor=1.0, + add_downsample=True, + downsample_padding=1, + ): + super().__init__() + resnets = [] + attentions = [] + + for i in range(num_layers): + in_channels = in_channels if i == 0 else out_channels + resnets.append( + ResnetBlock2D( + in_channels=in_channels, + out_channels=out_channels, + temb_channels=None, + eps=resnet_eps, + groups=resnet_groups, + dropout=dropout, + time_embedding_norm=resnet_time_scale_shift, + non_linearity=resnet_act_fn, + output_scale_factor=output_scale_factor, + pre_norm=resnet_pre_norm, + ) + ) + attentions.append( + AttentionBlock( + out_channels, + num_head_channels=attn_num_head_channels, + rescale_output_factor=output_scale_factor, + eps=resnet_eps, + num_groups=resnet_groups, + ) + ) + + self.attentions = nn.ModuleList(attentions) + self.resnets = nn.ModuleList(resnets) + + if add_downsample: + self.downsamplers = nn.ModuleList( + [ + Downsample2D( + in_channels, use_conv=True, out_channels=out_channels, padding=downsample_padding, name="op" + ) + ] + ) + else: + self.downsamplers = None + + def forward(self, hidden_states): + for resnet, attn in zip(self.resnets, self.attentions): + hidden_states = resnet(hidden_states, temb=None) + hidden_states = attn(hidden_states) + + if self.downsamplers is not None: + for downsampler in self.downsamplers: + hidden_states = downsampler(hidden_states) + + return hidden_states + + +class AttnSkipDownBlock2D(nn.Module): + def __init__( + self, + in_channels: int, + out_channels: int, + temb_channels: int, + dropout: float = 0.0, + num_layers: int = 1, + resnet_eps: float = 1e-6, + resnet_time_scale_shift: str = "default", + resnet_act_fn: str = "swish", + resnet_pre_norm: bool = True, + attn_num_head_channels=1, + attention_type="default", + output_scale_factor=np.sqrt(2.0), + downsample_padding=1, + add_downsample=True, + ): + super().__init__() + self.attentions = 
nn.ModuleList([]) + self.resnets = nn.ModuleList([]) + + self.attention_type = attention_type + + for i in range(num_layers): + in_channels = in_channels if i == 0 else out_channels + self.resnets.append( + ResnetBlock2D( + in_channels=in_channels, + out_channels=out_channels, + temb_channels=temb_channels, + eps=resnet_eps, + groups=min(in_channels // 4, 32), + groups_out=min(out_channels // 4, 32), + dropout=dropout, + time_embedding_norm=resnet_time_scale_shift, + non_linearity=resnet_act_fn, + output_scale_factor=output_scale_factor, + pre_norm=resnet_pre_norm, + ) + ) + self.attentions.append( + AttentionBlock( + out_channels, + num_head_channels=attn_num_head_channels, + rescale_output_factor=output_scale_factor, + eps=resnet_eps, + ) + ) + + if add_downsample: + self.resnet_down = ResnetBlock2D( + in_channels=out_channels, + out_channels=out_channels, + temb_channels=temb_channels, + eps=resnet_eps, + groups=min(out_channels // 4, 32), + dropout=dropout, + time_embedding_norm=resnet_time_scale_shift, + non_linearity=resnet_act_fn, + output_scale_factor=output_scale_factor, + pre_norm=resnet_pre_norm, + use_nin_shortcut=True, + down=True, + kernel="fir", + ) + self.downsamplers = nn.ModuleList([FirDownsample2D(in_channels, out_channels=out_channels)]) + self.skip_conv = nn.Conv2d(3, out_channels, kernel_size=(1, 1), stride=(1, 1)) + else: + self.resnet_down = None + self.downsamplers = None + self.skip_conv = None + + def forward(self, hidden_states, temb=None, skip_sample=None): + output_states = () + + for resnet, attn in zip(self.resnets, self.attentions): + hidden_states = resnet(hidden_states, temb) + hidden_states = attn(hidden_states) + output_states += (hidden_states,) + + if self.downsamplers is not None: + hidden_states = self.resnet_down(hidden_states, temb) + for downsampler in self.downsamplers: + skip_sample = downsampler(skip_sample) + + hidden_states = self.skip_conv(skip_sample) + hidden_states + + output_states += (hidden_states,) + + return hidden_states, output_states, skip_sample + + +class SkipDownBlock2D(nn.Module): + def __init__( + self, + in_channels: int, + out_channels: int, + temb_channels: int, + dropout: float = 0.0, + num_layers: int = 1, + resnet_eps: float = 1e-6, + resnet_time_scale_shift: str = "default", + resnet_act_fn: str = "swish", + resnet_pre_norm: bool = True, + output_scale_factor=np.sqrt(2.0), + add_downsample=True, + downsample_padding=1, + ): + super().__init__() + self.resnets = nn.ModuleList([]) + + for i in range(num_layers): + in_channels = in_channels if i == 0 else out_channels + self.resnets.append( + ResnetBlock2D( + in_channels=in_channels, + out_channels=out_channels, + temb_channels=temb_channels, + eps=resnet_eps, + groups=min(in_channels // 4, 32), + groups_out=min(out_channels // 4, 32), + dropout=dropout, + time_embedding_norm=resnet_time_scale_shift, + non_linearity=resnet_act_fn, + output_scale_factor=output_scale_factor, + pre_norm=resnet_pre_norm, + ) + ) + + if add_downsample: + self.resnet_down = ResnetBlock2D( + in_channels=out_channels, + out_channels=out_channels, + temb_channels=temb_channels, + eps=resnet_eps, + groups=min(out_channels // 4, 32), + dropout=dropout, + time_embedding_norm=resnet_time_scale_shift, + non_linearity=resnet_act_fn, + output_scale_factor=output_scale_factor, + pre_norm=resnet_pre_norm, + use_nin_shortcut=True, + down=True, + kernel="fir", + ) + self.downsamplers = nn.ModuleList([FirDownsample2D(in_channels, out_channels=out_channels)]) + self.skip_conv = nn.Conv2d(3, out_channels, 
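                # 3 input channels: skip_sample is the 3-channel model input, FIR-downsampled
                # alongside the feature maps.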
kernel_size=(1, 1), stride=(1, 1)) + else: + self.resnet_down = None + self.downsamplers = None + self.skip_conv = None + + def forward(self, hidden_states, temb=None, skip_sample=None): + output_states = () + + for resnet in self.resnets: + hidden_states = resnet(hidden_states, temb) + output_states += (hidden_states,) + + if self.downsamplers is not None: + hidden_states = self.resnet_down(hidden_states, temb) + for downsampler in self.downsamplers: + skip_sample = downsampler(skip_sample) + + hidden_states = self.skip_conv(skip_sample) + hidden_states + + output_states += (hidden_states,) + + return hidden_states, output_states, skip_sample + + +class AttnUpBlock2D(nn.Module): + def __init__( + self, + in_channels: int, + prev_output_channel: int, + out_channels: int, + temb_channels: int, + dropout: float = 0.0, + num_layers: int = 1, + resnet_eps: float = 1e-6, + resnet_time_scale_shift: str = "default", + resnet_act_fn: str = "swish", + resnet_groups: int = 32, + resnet_pre_norm: bool = True, + attention_type="default", + attn_num_head_channels=1, + output_scale_factor=1.0, + add_upsample=True, + ): + super().__init__() + resnets = [] + attentions = [] + + self.attention_type = attention_type + + for i in range(num_layers): + res_skip_channels = in_channels if (i == num_layers - 1) else out_channels + resnet_in_channels = prev_output_channel if i == 0 else out_channels + + resnets.append( + ResnetBlock2D( + in_channels=resnet_in_channels + res_skip_channels, + out_channels=out_channels, + temb_channels=temb_channels, + eps=resnet_eps, + groups=resnet_groups, + dropout=dropout, + time_embedding_norm=resnet_time_scale_shift, + non_linearity=resnet_act_fn, + output_scale_factor=output_scale_factor, + pre_norm=resnet_pre_norm, + ) + ) + attentions.append( + AttentionBlock( + out_channels, + num_head_channels=attn_num_head_channels, + rescale_output_factor=output_scale_factor, + eps=resnet_eps, + ) + ) + + self.attentions = nn.ModuleList(attentions) + self.resnets = nn.ModuleList(resnets) + + if add_upsample: + self.upsamplers = nn.ModuleList([Upsample2D(out_channels, use_conv=True, out_channels=out_channels)]) + else: + self.upsamplers = None + + def forward(self, hidden_states, res_hidden_states_tuple, temb=None): + for resnet, attn in zip(self.resnets, self.attentions): + + # pop res hidden states + res_hidden_states = res_hidden_states_tuple[-1] + res_hidden_states_tuple = res_hidden_states_tuple[:-1] + hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1) + + hidden_states = resnet(hidden_states, temb) + hidden_states = attn(hidden_states) + + if self.upsamplers is not None: + for upsampler in self.upsamplers: + hidden_states = upsampler(hidden_states) + + return hidden_states + + +class CrossAttnUpBlock2D(nn.Module): + def __init__( + self, + in_channels: int, + out_channels: int, + prev_output_channel: int, + temb_channels: int, + dropout: float = 0.0, + num_layers: int = 1, + resnet_eps: float = 1e-6, + resnet_time_scale_shift: str = "default", + resnet_act_fn: str = "swish", + resnet_groups: int = 32, + resnet_pre_norm: bool = True, + attn_num_head_channels=1, + cross_attention_dim=1280, + attention_type="default", + output_scale_factor=1.0, + downsample_padding=1, + add_upsample=True, + ): + super().__init__() + resnets = [] + attentions = [] + + self.attention_type = attention_type + self.attn_num_head_channels = attn_num_head_channels + + for i in range(num_layers): + res_skip_channels = in_channels if (i == num_layers - 1) else out_channels + resnet_in_channels 
= prev_output_channel if i == 0 else out_channels + + resnets.append( + ResnetBlock2D( + in_channels=resnet_in_channels + res_skip_channels, + out_channels=out_channels, + temb_channels=temb_channels, + eps=resnet_eps, + groups=resnet_groups, + dropout=dropout, + time_embedding_norm=resnet_time_scale_shift, + non_linearity=resnet_act_fn, + output_scale_factor=output_scale_factor, + pre_norm=resnet_pre_norm, + ) + ) + attentions.append( + SpatialTransformer( + out_channels, + attn_num_head_channels, + out_channels // attn_num_head_channels, + depth=1, + context_dim=cross_attention_dim, + ) + ) + self.attentions = nn.ModuleList(attentions) + self.resnets = nn.ModuleList(resnets) + + if add_upsample: + self.upsamplers = nn.ModuleList([Upsample2D(out_channels, use_conv=True, out_channels=out_channels)]) + else: + self.upsamplers = None + + def set_attention_slice(self, slice_size): + if slice_size is not None and self.attn_num_head_channels % slice_size != 0: + raise ValueError( + f"Make sure slice_size {slice_size} is a divisor of " + f"the number of heads used in cross_attention {self.attn_num_head_channels}" + ) + if slice_size is not None and slice_size > self.attn_num_head_channels: + raise ValueError( + f"Chunk_size {slice_size} has to be smaller or equal to " + f"the number of heads used in cross_attention {self.attn_num_head_channels}" + ) + + for attn in self.attentions: + attn._set_attention_slice(slice_size) + + def forward(self, hidden_states, res_hidden_states_tuple, temb=None, encoder_hidden_states=None): + for resnet, attn in zip(self.resnets, self.attentions): + + # pop res hidden states + res_hidden_states = res_hidden_states_tuple[-1] + res_hidden_states_tuple = res_hidden_states_tuple[:-1] + hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1) + + hidden_states = resnet(hidden_states, temb) + hidden_states = attn(hidden_states, context=encoder_hidden_states) + + if self.upsamplers is not None: + for upsampler in self.upsamplers: + hidden_states = upsampler(hidden_states) + + return hidden_states + + +class UpBlock2D(nn.Module): + def __init__( + self, + in_channels: int, + prev_output_channel: int, + out_channels: int, + temb_channels: int, + dropout: float = 0.0, + num_layers: int = 1, + resnet_eps: float = 1e-6, + resnet_time_scale_shift: str = "default", + resnet_act_fn: str = "swish", + resnet_groups: int = 32, + resnet_pre_norm: bool = True, + output_scale_factor=1.0, + add_upsample=True, + ): + super().__init__() + resnets = [] + + for i in range(num_layers): + res_skip_channels = in_channels if (i == num_layers - 1) else out_channels + resnet_in_channels = prev_output_channel if i == 0 else out_channels + + resnets.append( + ResnetBlock2D( + in_channels=resnet_in_channels + res_skip_channels, + out_channels=out_channels, + temb_channels=temb_channels, + eps=resnet_eps, + groups=resnet_groups, + dropout=dropout, + time_embedding_norm=resnet_time_scale_shift, + non_linearity=resnet_act_fn, + output_scale_factor=output_scale_factor, + pre_norm=resnet_pre_norm, + ) + ) + + self.resnets = nn.ModuleList(resnets) + + if add_upsample: + self.upsamplers = nn.ModuleList([Upsample2D(out_channels, use_conv=True, out_channels=out_channels)]) + else: + self.upsamplers = None + + def forward(self, hidden_states, res_hidden_states_tuple, temb=None): + for resnet in self.resnets: + + # pop res hidden states + res_hidden_states = res_hidden_states_tuple[-1] + res_hidden_states_tuple = res_hidden_states_tuple[:-1] + hidden_states = torch.cat([hidden_states, 
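            # concatenate the matching skip connection from the down path along the channel axis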
res_hidden_states], dim=1) + + hidden_states = resnet(hidden_states, temb) + + if self.upsamplers is not None: + for upsampler in self.upsamplers: + hidden_states = upsampler(hidden_states) + + return hidden_states + + +class UpDecoderBlock2D(nn.Module): + def __init__( + self, + in_channels: int, + out_channels: int, + dropout: float = 0.0, + num_layers: int = 1, + resnet_eps: float = 1e-6, + resnet_time_scale_shift: str = "default", + resnet_act_fn: str = "swish", + resnet_groups: int = 32, + resnet_pre_norm: bool = True, + output_scale_factor=1.0, + add_upsample=True, + ): + super().__init__() + resnets = [] + + for i in range(num_layers): + input_channels = in_channels if i == 0 else out_channels + + resnets.append( + ResnetBlock2D( + in_channels=input_channels, + out_channels=out_channels, + temb_channels=None, + eps=resnet_eps, + groups=resnet_groups, + dropout=dropout, + time_embedding_norm=resnet_time_scale_shift, + non_linearity=resnet_act_fn, + output_scale_factor=output_scale_factor, + pre_norm=resnet_pre_norm, + ) + ) + + self.resnets = nn.ModuleList(resnets) + + if add_upsample: + self.upsamplers = nn.ModuleList([Upsample2D(out_channels, use_conv=True, out_channels=out_channels)]) + else: + self.upsamplers = None + + def forward(self, hidden_states): + for resnet in self.resnets: + hidden_states = resnet(hidden_states, temb=None) + + if self.upsamplers is not None: + for upsampler in self.upsamplers: + hidden_states = upsampler(hidden_states) + + return hidden_states + + +class AttnUpDecoderBlock2D(nn.Module): + def __init__( + self, + in_channels: int, + out_channels: int, + dropout: float = 0.0, + num_layers: int = 1, + resnet_eps: float = 1e-6, + resnet_time_scale_shift: str = "default", + resnet_act_fn: str = "swish", + resnet_groups: int = 32, + resnet_pre_norm: bool = True, + attn_num_head_channels=1, + output_scale_factor=1.0, + add_upsample=True, + ): + super().__init__() + resnets = [] + attentions = [] + + for i in range(num_layers): + input_channels = in_channels if i == 0 else out_channels + + resnets.append( + ResnetBlock2D( + in_channels=input_channels, + out_channels=out_channels, + temb_channels=None, + eps=resnet_eps, + groups=resnet_groups, + dropout=dropout, + time_embedding_norm=resnet_time_scale_shift, + non_linearity=resnet_act_fn, + output_scale_factor=output_scale_factor, + pre_norm=resnet_pre_norm, + ) + ) + attentions.append( + AttentionBlock( + out_channels, + num_head_channels=attn_num_head_channels, + rescale_output_factor=output_scale_factor, + eps=resnet_eps, + num_groups=resnet_groups, + ) + ) + + self.attentions = nn.ModuleList(attentions) + self.resnets = nn.ModuleList(resnets) + + if add_upsample: + self.upsamplers = nn.ModuleList([Upsample2D(out_channels, use_conv=True, out_channels=out_channels)]) + else: + self.upsamplers = None + + def forward(self, hidden_states): + for resnet, attn in zip(self.resnets, self.attentions): + hidden_states = resnet(hidden_states, temb=None) + hidden_states = attn(hidden_states) + + if self.upsamplers is not None: + for upsampler in self.upsamplers: + hidden_states = upsampler(hidden_states) + + return hidden_states + + +class AttnSkipUpBlock2D(nn.Module): + def __init__( + self, + in_channels: int, + prev_output_channel: int, + out_channels: int, + temb_channels: int, + dropout: float = 0.0, + num_layers: int = 1, + resnet_eps: float = 1e-6, + resnet_time_scale_shift: str = "default", + resnet_act_fn: str = "swish", + resnet_pre_norm: bool = True, + attn_num_head_channels=1, + attention_type="default", + 
output_scale_factor=np.sqrt(2.0), + upsample_padding=1, + add_upsample=True, + ): + super().__init__() + self.attentions = nn.ModuleList([]) + self.resnets = nn.ModuleList([]) + + self.attention_type = attention_type + + for i in range(num_layers): + res_skip_channels = in_channels if (i == num_layers - 1) else out_channels + resnet_in_channels = prev_output_channel if i == 0 else out_channels + + self.resnets.append( + ResnetBlock2D( + in_channels=resnet_in_channels + res_skip_channels, + out_channels=out_channels, + temb_channels=temb_channels, + eps=resnet_eps, + groups=min(resnet_in_channels + res_skip_channels // 4, 32), + groups_out=min(out_channels // 4, 32), + dropout=dropout, + time_embedding_norm=resnet_time_scale_shift, + non_linearity=resnet_act_fn, + output_scale_factor=output_scale_factor, + pre_norm=resnet_pre_norm, + ) + ) + + self.attentions.append( + AttentionBlock( + out_channels, + num_head_channels=attn_num_head_channels, + rescale_output_factor=output_scale_factor, + eps=resnet_eps, + ) + ) + + self.upsampler = FirUpsample2D(in_channels, out_channels=out_channels) + if add_upsample: + self.resnet_up = ResnetBlock2D( + in_channels=out_channels, + out_channels=out_channels, + temb_channels=temb_channels, + eps=resnet_eps, + groups=min(out_channels // 4, 32), + groups_out=min(out_channels // 4, 32), + dropout=dropout, + time_embedding_norm=resnet_time_scale_shift, + non_linearity=resnet_act_fn, + output_scale_factor=output_scale_factor, + pre_norm=resnet_pre_norm, + use_nin_shortcut=True, + up=True, + kernel="fir", + ) + self.skip_conv = nn.Conv2d(out_channels, 3, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) + self.skip_norm = torch.nn.GroupNorm( + num_groups=min(out_channels // 4, 32), num_channels=out_channels, eps=resnet_eps, affine=True + ) + self.act = nn.SiLU() + else: + self.resnet_up = None + self.skip_conv = None + self.skip_norm = None + self.act = None + + def forward(self, hidden_states, res_hidden_states_tuple, temb=None, skip_sample=None): + for resnet in self.resnets: + # pop res hidden states + res_hidden_states = res_hidden_states_tuple[-1] + res_hidden_states_tuple = res_hidden_states_tuple[:-1] + hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1) + + hidden_states = resnet(hidden_states, temb) + + hidden_states = self.attentions[0](hidden_states) + + if skip_sample is not None: + skip_sample = self.upsampler(skip_sample) + else: + skip_sample = 0 + + if self.resnet_up is not None: + skip_sample_states = self.skip_norm(hidden_states) + skip_sample_states = self.act(skip_sample_states) + skip_sample_states = self.skip_conv(skip_sample_states) + + skip_sample = skip_sample + skip_sample_states + + hidden_states = self.resnet_up(hidden_states, temb) + + return hidden_states, skip_sample + + +class SkipUpBlock2D(nn.Module): + def __init__( + self, + in_channels: int, + prev_output_channel: int, + out_channels: int, + temb_channels: int, + dropout: float = 0.0, + num_layers: int = 1, + resnet_eps: float = 1e-6, + resnet_time_scale_shift: str = "default", + resnet_act_fn: str = "swish", + resnet_pre_norm: bool = True, + output_scale_factor=np.sqrt(2.0), + add_upsample=True, + upsample_padding=1, + ): + super().__init__() + self.resnets = nn.ModuleList([]) + + for i in range(num_layers): + res_skip_channels = in_channels if (i == num_layers - 1) else out_channels + resnet_in_channels = prev_output_channel if i == 0 else out_channels + + self.resnets.append( + ResnetBlock2D( + in_channels=resnet_in_channels + res_skip_channels, + 
out_channels=out_channels, + temb_channels=temb_channels, + eps=resnet_eps, + groups=min((resnet_in_channels + res_skip_channels) // 4, 32), + groups_out=min(out_channels // 4, 32), + dropout=dropout, + time_embedding_norm=resnet_time_scale_shift, + non_linearity=resnet_act_fn, + output_scale_factor=output_scale_factor, + pre_norm=resnet_pre_norm, + ) + ) + + self.upsampler = FirUpsample2D(in_channels, out_channels=out_channels) + if add_upsample: + self.resnet_up = ResnetBlock2D( + in_channels=out_channels, + out_channels=out_channels, + temb_channels=temb_channels, + eps=resnet_eps, + groups=min(out_channels // 4, 32), + groups_out=min(out_channels // 4, 32), + dropout=dropout, + time_embedding_norm=resnet_time_scale_shift, + non_linearity=resnet_act_fn, + output_scale_factor=output_scale_factor, + pre_norm=resnet_pre_norm, + use_nin_shortcut=True, + up=True, + kernel="fir", + ) + self.skip_conv = nn.Conv2d(out_channels, 3, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) + self.skip_norm = torch.nn.GroupNorm( + num_groups=min(out_channels // 4, 32), num_channels=out_channels, eps=resnet_eps, affine=True + ) + self.act = nn.SiLU() + else: + self.resnet_up = None + self.skip_conv = None + self.skip_norm = None + self.act = None + + def forward(self, hidden_states, res_hidden_states_tuple, temb=None, skip_sample=None): + for resnet in self.resnets: + # pop res hidden states + res_hidden_states = res_hidden_states_tuple[-1] + res_hidden_states_tuple = res_hidden_states_tuple[:-1] + hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1) + + hidden_states = resnet(hidden_states, temb) + + if skip_sample is not None: + skip_sample = self.upsampler(skip_sample) + else: + skip_sample = 0 + + if self.resnet_up is not None: + skip_sample_states = self.skip_norm(hidden_states) + skip_sample_states = self.act(skip_sample_states) + skip_sample_states = self.skip_conv(skip_sample_states) + + skip_sample = skip_sample + skip_sample_states + + hidden_states = self.resnet_up(hidden_states, temb) + + return hidden_states, skip_sample diff --git a/diffusers/models/vae.py b/diffusers/models/vae.py new file mode 100644 index 000000000..82748cb5b --- /dev/null +++ b/diffusers/models/vae.py @@ -0,0 +1,581 @@ +from dataclasses import dataclass +from typing import Optional, Tuple, Union + +import numpy as np +import torch +import torch.nn as nn + +from ..configuration_utils import ConfigMixin, register_to_config +from ..modeling_utils import ModelMixin +from ..utils import BaseOutput +from .unet_blocks import UNetMidBlock2D, get_down_block, get_up_block + + +@dataclass +class DecoderOutput(BaseOutput): + """ + Output of decoding method. + + Args: + sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`): + Decoded output sample of the model. Output of the last layer of the model. + """ + + sample: torch.FloatTensor + + +@dataclass +class VQEncoderOutput(BaseOutput): + """ + Output of VQModel encoding method. + + Args: + latents (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`): + Encoded output sample of the model. Output of the last layer of the model. + """ + + latents: torch.FloatTensor + + +@dataclass +class AutoencoderKLOutput(BaseOutput): + """ + Output of AutoencoderKL encoding method. + + Args: + latent_dist (`DiagonalGaussianDistribution`): + Encoded outputs of `Encoder` represented as the mean and logvar of `DiagonalGaussianDistribution`. + `DiagonalGaussianDistribution` allows for sampling latents from the distribution. 
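# A minimal sketch (not part of the patch) of the skip-connection bookkeeping used by the up
# blocks above: the down pass stores every intermediate activation in a tuple, and each up
# block pops entries from the end of that tuple and concatenates them with the current hidden
# states along the channel dimension before running a resnet. Shapes here are made up.
import torch

res_hidden_states_tuple = (torch.randn(1, 64, 32, 32), torch.randn(1, 64, 32, 32))  # saved on the way down
hidden_states = torch.randn(1, 64, 32, 32)                                          # current decoder activation

res_hidden_states = res_hidden_states_tuple[-1]         # pop the most recent skip tensor,
res_hidden_states_tuple = res_hidden_states_tuple[:-1]  # exactly as UpBlock2D.forward does
hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)
print(hidden_states.shape)  # torch.Size([1, 128, 32, 32]) -> fed to a ResnetBlock2D(in_channels=128, ...)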
+ """ + + latent_dist: "DiagonalGaussianDistribution" + + +class Encoder(nn.Module): + def __init__( + self, + in_channels=3, + out_channels=3, + down_block_types=("DownEncoderBlock2D",), + block_out_channels=(64,), + layers_per_block=2, + act_fn="silu", + double_z=True, + ): + super().__init__() + self.layers_per_block = layers_per_block + + self.conv_in = torch.nn.Conv2d(in_channels, block_out_channels[0], kernel_size=3, stride=1, padding=1) + + self.mid_block = None + self.down_blocks = nn.ModuleList([]) + + # down + output_channel = block_out_channels[0] + for i, down_block_type in enumerate(down_block_types): + input_channel = output_channel + output_channel = block_out_channels[i] + is_final_block = i == len(block_out_channels) - 1 + + down_block = get_down_block( + down_block_type, + num_layers=self.layers_per_block, + in_channels=input_channel, + out_channels=output_channel, + add_downsample=not is_final_block, + resnet_eps=1e-6, + downsample_padding=0, + resnet_act_fn=act_fn, + attn_num_head_channels=None, + temb_channels=None, + ) + self.down_blocks.append(down_block) + + # mid + self.mid_block = UNetMidBlock2D( + in_channels=block_out_channels[-1], + resnet_eps=1e-6, + resnet_act_fn=act_fn, + output_scale_factor=1, + resnet_time_scale_shift="default", + attn_num_head_channels=None, + resnet_groups=32, + temb_channels=None, + ) + + # out + num_groups_out = 32 + self.conv_norm_out = nn.GroupNorm(num_channels=block_out_channels[-1], num_groups=num_groups_out, eps=1e-6) + self.conv_act = nn.SiLU() + + conv_out_channels = 2 * out_channels if double_z else out_channels + self.conv_out = nn.Conv2d(block_out_channels[-1], conv_out_channels, 3, padding=1) + + def forward(self, x): + sample = x + sample = self.conv_in(sample) + + # down + for down_block in self.down_blocks: + sample = down_block(sample) + + # middle + sample = self.mid_block(sample) + + # post-process + sample = self.conv_norm_out(sample) + sample = self.conv_act(sample) + sample = self.conv_out(sample) + + return sample + + +class Decoder(nn.Module): + def __init__( + self, + in_channels=3, + out_channels=3, + up_block_types=("UpDecoderBlock2D",), + block_out_channels=(64,), + layers_per_block=2, + act_fn="silu", + ): + super().__init__() + self.layers_per_block = layers_per_block + + self.conv_in = nn.Conv2d(in_channels, block_out_channels[-1], kernel_size=3, stride=1, padding=1) + + self.mid_block = None + self.up_blocks = nn.ModuleList([]) + + # mid + self.mid_block = UNetMidBlock2D( + in_channels=block_out_channels[-1], + resnet_eps=1e-6, + resnet_act_fn=act_fn, + output_scale_factor=1, + resnet_time_scale_shift="default", + attn_num_head_channels=None, + resnet_groups=32, + temb_channels=None, + ) + + # up + reversed_block_out_channels = list(reversed(block_out_channels)) + output_channel = reversed_block_out_channels[0] + for i, up_block_type in enumerate(up_block_types): + prev_output_channel = output_channel + output_channel = reversed_block_out_channels[i] + + is_final_block = i == len(block_out_channels) - 1 + + up_block = get_up_block( + up_block_type, + num_layers=self.layers_per_block + 1, + in_channels=prev_output_channel, + out_channels=output_channel, + prev_output_channel=None, + add_upsample=not is_final_block, + resnet_eps=1e-6, + resnet_act_fn=act_fn, + attn_num_head_channels=None, + temb_channels=None, + ) + self.up_blocks.append(up_block) + prev_output_channel = output_channel + + # out + num_groups_out = 32 + self.conv_norm_out = nn.GroupNorm(num_channels=block_out_channels[0], 
num_groups=num_groups_out, eps=1e-6) + self.conv_act = nn.SiLU() + self.conv_out = nn.Conv2d(block_out_channels[0], out_channels, 3, padding=1) + + def forward(self, z): + sample = z + sample = self.conv_in(sample) + + # middle + sample = self.mid_block(sample) + + # up + for up_block in self.up_blocks: + sample = up_block(sample) + + # post-process + sample = self.conv_norm_out(sample) + sample = self.conv_act(sample) + sample = self.conv_out(sample) + + return sample + + +class VectorQuantizer(nn.Module): + """ + Improved version over VectorQuantizer, can be used as a drop-in replacement. Mostly avoids costly matrix + multiplications and allows for post-hoc remapping of indices. + """ + + # NOTE: due to a bug the beta term was applied to the wrong term. for + # backwards compatibility we use the buggy version by default, but you can + # specify legacy=False to fix it. + def __init__(self, n_e, e_dim, beta, remap=None, unknown_index="random", sane_index_shape=False, legacy=True): + super().__init__() + self.n_e = n_e + self.e_dim = e_dim + self.beta = beta + self.legacy = legacy + + self.embedding = nn.Embedding(self.n_e, self.e_dim) + self.embedding.weight.data.uniform_(-1.0 / self.n_e, 1.0 / self.n_e) + + self.remap = remap + if self.remap is not None: + self.register_buffer("used", torch.tensor(np.load(self.remap))) + self.re_embed = self.used.shape[0] + self.unknown_index = unknown_index # "random" or "extra" or integer + if self.unknown_index == "extra": + self.unknown_index = self.re_embed + self.re_embed = self.re_embed + 1 + print( + f"Remapping {self.n_e} indices to {self.re_embed} indices. " + f"Using {self.unknown_index} for unknown indices." + ) + else: + self.re_embed = n_e + + self.sane_index_shape = sane_index_shape + + def remap_to_used(self, inds): + ishape = inds.shape + assert len(ishape) > 1 + inds = inds.reshape(ishape[0], -1) + used = self.used.to(inds) + match = (inds[:, :, None] == used[None, None, ...]).long() + new = match.argmax(-1) + unknown = match.sum(2) < 1 + if self.unknown_index == "random": + new[unknown] = torch.randint(0, self.re_embed, size=new[unknown].shape).to(device=new.device) + else: + new[unknown] = self.unknown_index + return new.reshape(ishape) + + def unmap_to_all(self, inds): + ishape = inds.shape + assert len(ishape) > 1 + inds = inds.reshape(ishape[0], -1) + used = self.used.to(inds) + if self.re_embed > self.used.shape[0]: # extra token + inds[inds >= self.used.shape[0]] = 0 # simply set to zero + back = torch.gather(used[None, :][inds.shape[0] * [0], :], 1, inds) + return back.reshape(ishape) + + def forward(self, z): + # reshape z -> (batch, height, width, channel) and flatten + z = z.permute(0, 2, 3, 1).contiguous() + z_flattened = z.view(-1, self.e_dim) + # distances from z to embeddings e_j (z - e)^2 = z^2 + e^2 - 2 e * z + + d = ( + torch.sum(z_flattened**2, dim=1, keepdim=True) + + torch.sum(self.embedding.weight**2, dim=1) + - 2 * torch.einsum("bd,dn->bn", z_flattened, self.embedding.weight.t()) + ) + + min_encoding_indices = torch.argmin(d, dim=1) + z_q = self.embedding(min_encoding_indices).view(z.shape) + perplexity = None + min_encodings = None + + # compute loss for embedding + if not self.legacy: + loss = self.beta * torch.mean((z_q.detach() - z) ** 2) + torch.mean((z_q - z.detach()) ** 2) + else: + loss = torch.mean((z_q.detach() - z) ** 2) + self.beta * torch.mean((z_q - z.detach()) ** 2) + + # preserve gradients + z_q = z + (z_q - z).detach() + + # reshape back to match original input shape + z_q = z_q.permute(0, 3, 1, 
2).contiguous() + + if self.remap is not None: + min_encoding_indices = min_encoding_indices.reshape(z.shape[0], -1) # add batch axis + min_encoding_indices = self.remap_to_used(min_encoding_indices) + min_encoding_indices = min_encoding_indices.reshape(-1, 1) # flatten + + if self.sane_index_shape: + min_encoding_indices = min_encoding_indices.reshape(z_q.shape[0], z_q.shape[2], z_q.shape[3]) + + return z_q, loss, (perplexity, min_encodings, min_encoding_indices) + + def get_codebook_entry(self, indices, shape): + # shape specifying (batch, height, width, channel) + if self.remap is not None: + indices = indices.reshape(shape[0], -1) # add batch axis + indices = self.unmap_to_all(indices) + indices = indices.reshape(-1) # flatten again + + # get quantized latent vectors + z_q = self.embedding(indices) + + if shape is not None: + z_q = z_q.view(shape) + # reshape back to match original input shape + z_q = z_q.permute(0, 3, 1, 2).contiguous() + + return z_q + + +class DiagonalGaussianDistribution(object): + def __init__(self, parameters, deterministic=False): + self.parameters = parameters + self.mean, self.logvar = torch.chunk(parameters, 2, dim=1) + self.logvar = torch.clamp(self.logvar, -30.0, 20.0) + self.deterministic = deterministic + self.std = torch.exp(0.5 * self.logvar) + self.var = torch.exp(self.logvar) + if self.deterministic: + self.var = self.std = torch.zeros_like(self.mean).to(device=self.parameters.device) + + def sample(self, generator: Optional[torch.Generator] = None) -> torch.FloatTensor: + device = self.parameters.device + sample_device = "cpu" if device.type == "mps" else device + sample = torch.randn(self.mean.shape, generator=generator, device=sample_device).to(device) + x = self.mean + self.std * sample + return x + + def kl(self, other=None): + if self.deterministic: + return torch.Tensor([0.0]) + else: + if other is None: + return 0.5 * torch.sum(torch.pow(self.mean, 2) + self.var - 1.0 - self.logvar, dim=[1, 2, 3]) + else: + return 0.5 * torch.sum( + torch.pow(self.mean - other.mean, 2) / other.var + + self.var / other.var + - 1.0 + - self.logvar + + other.logvar, + dim=[1, 2, 3], + ) + + def nll(self, sample, dims=[1, 2, 3]): + if self.deterministic: + return torch.Tensor([0.0]) + logtwopi = np.log(2.0 * np.pi) + return 0.5 * torch.sum(logtwopi + self.logvar + torch.pow(sample - self.mean, 2) / self.var, dim=dims) + + def mode(self): + return self.mean + + +class VQModel(ModelMixin, ConfigMixin): + r"""VQ-VAE model from the paper Neural Discrete Representation Learning by Aaron van den Oord, Oriol Vinyals and Koray + Kavukcuoglu. + + This model inherits from [`ModelMixin`]. Check the superclass documentation for the generic methods the library + implements for all the model (such as downloading or saving, etc.) + + Parameters: + in_channels (int, *optional*, defaults to 3): Number of channels in the input image. + out_channels (int, *optional*, defaults to 3): Number of channels in the output. + down_block_types (`Tuple[str]`, *optional*, defaults to : + obj:`("DownEncoderBlock2D",)`): Tuple of downsample block types. + up_block_types (`Tuple[str]`, *optional*, defaults to : + obj:`("UpDecoderBlock2D",)`): Tuple of upsample block types. + block_out_channels (`Tuple[int]`, *optional*, defaults to : + obj:`(64,)`): Tuple of block output channels. + act_fn (`str`, *optional*, defaults to `"silu"`): The activation function to use. + latent_channels (`int`, *optional*, defaults to `3`): Number of channels in the latent space. 
+ sample_size (`int`, *optional*, defaults to `32`): TODO + num_vq_embeddings (`int`, *optional*, defaults to `256`): Number of codebook vectors in the VQ-VAE. + """ + + @register_to_config + def __init__( + self, + in_channels: int = 3, + out_channels: int = 3, + down_block_types: Tuple[str] = ("DownEncoderBlock2D",), + up_block_types: Tuple[str] = ("UpDecoderBlock2D",), + block_out_channels: Tuple[int] = (64,), + layers_per_block: int = 1, + act_fn: str = "silu", + latent_channels: int = 3, + sample_size: int = 32, + num_vq_embeddings: int = 256, + ): + super().__init__() + + # pass init params to Encoder + self.encoder = Encoder( + in_channels=in_channels, + out_channels=latent_channels, + down_block_types=down_block_types, + block_out_channels=block_out_channels, + layers_per_block=layers_per_block, + act_fn=act_fn, + double_z=False, + ) + + self.quant_conv = torch.nn.Conv2d(latent_channels, latent_channels, 1) + self.quantize = VectorQuantizer( + num_vq_embeddings, latent_channels, beta=0.25, remap=None, sane_index_shape=False + ) + self.post_quant_conv = torch.nn.Conv2d(latent_channels, latent_channels, 1) + + # pass init params to Decoder + self.decoder = Decoder( + in_channels=latent_channels, + out_channels=out_channels, + up_block_types=up_block_types, + block_out_channels=block_out_channels, + layers_per_block=layers_per_block, + act_fn=act_fn, + ) + + def encode(self, x: torch.FloatTensor, return_dict: bool = True) -> VQEncoderOutput: + h = self.encoder(x) + h = self.quant_conv(h) + + if not return_dict: + return (h,) + + return VQEncoderOutput(latents=h) + + def decode( + self, h: torch.FloatTensor, force_not_quantize: bool = False, return_dict: bool = True + ) -> Union[DecoderOutput, torch.FloatTensor]: + # also go through quantization layer + if not force_not_quantize: + quant, emb_loss, info = self.quantize(h) + else: + quant = h + quant = self.post_quant_conv(quant) + dec = self.decoder(quant) + + if not return_dict: + return (dec,) + + return DecoderOutput(sample=dec) + + def forward(self, sample: torch.FloatTensor, return_dict: bool = True) -> Union[DecoderOutput, torch.FloatTensor]: + r""" + Args: + sample (`torch.FloatTensor`): Input sample. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`DecoderOutput`] instead of a plain tuple. + """ + x = sample + h = self.encode(x).latents + dec = self.decode(h).sample + + if not return_dict: + return (dec,) + + return DecoderOutput(sample=dec) + + +class AutoencoderKL(ModelMixin, ConfigMixin): + r"""Variational Autoencoder (VAE) model with KL loss from the paper Auto-Encoding Variational Bayes by Diederik P. Kingma + and Max Welling. + + This model inherits from [`ModelMixin`]. Check the superclass documentation for the generic methods the library + implements for all the model (such as downloading or saving, etc.) + + Parameters: + in_channels (int, *optional*, defaults to 3): Number of channels in the input image. + out_channels (int, *optional*, defaults to 3): Number of channels in the output. + down_block_types (`Tuple[str]`, *optional*, defaults to : + obj:`("DownEncoderBlock2D",)`): Tuple of downsample block types. + up_block_types (`Tuple[str]`, *optional*, defaults to : + obj:`("UpDecoderBlock2D",)`): Tuple of upsample block types. + block_out_channels (`Tuple[int]`, *optional*, defaults to : + obj:`(64,)`): Tuple of block output channels. + act_fn (`str`, *optional*, defaults to `"silu"`): The activation function to use. 
+ latent_channels (`int`, *optional*, defaults to `4`): Number of channels in the latent space. + sample_size (`int`, *optional*, defaults to `32`): TODO + """ + + @register_to_config + def __init__( + self, + in_channels: int = 3, + out_channels: int = 3, + down_block_types: Tuple[str] = ("DownEncoderBlock2D",), + up_block_types: Tuple[str] = ("UpDecoderBlock2D",), + block_out_channels: Tuple[int] = (64,), + layers_per_block: int = 1, + act_fn: str = "silu", + latent_channels: int = 4, + sample_size: int = 32, + ): + super().__init__() + + # pass init params to Encoder + self.encoder = Encoder( + in_channels=in_channels, + out_channels=latent_channels, + down_block_types=down_block_types, + block_out_channels=block_out_channels, + layers_per_block=layers_per_block, + act_fn=act_fn, + double_z=True, + ) + + # pass init params to Decoder + self.decoder = Decoder( + in_channels=latent_channels, + out_channels=out_channels, + up_block_types=up_block_types, + block_out_channels=block_out_channels, + layers_per_block=layers_per_block, + act_fn=act_fn, + ) + + self.quant_conv = torch.nn.Conv2d(2 * latent_channels, 2 * latent_channels, 1) + self.post_quant_conv = torch.nn.Conv2d(latent_channels, latent_channels, 1) + + def encode(self, x: torch.FloatTensor, return_dict: bool = True) -> AutoencoderKLOutput: + h = self.encoder(x) + moments = self.quant_conv(h) + posterior = DiagonalGaussianDistribution(moments) + + if not return_dict: + return (posterior,) + + return AutoencoderKLOutput(latent_dist=posterior) + + def decode(self, z: torch.FloatTensor, return_dict: bool = True) -> Union[DecoderOutput, torch.FloatTensor]: + z = self.post_quant_conv(z) + dec = self.decoder(z) + + if not return_dict: + return (dec,) + + return DecoderOutput(sample=dec) + + def forward( + self, sample: torch.FloatTensor, sample_posterior: bool = False, return_dict: bool = True + ) -> Union[DecoderOutput, torch.FloatTensor]: + r""" + Args: + sample (`torch.FloatTensor`): Input sample. + sample_posterior (`bool`, *optional*, defaults to `False`): + Whether to sample from the posterior. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`DecoderOutput`] instead of a plain tuple. + """ + x = sample + posterior = self.encode(x).latent_dist + if sample_posterior: + z = posterior.sample() + else: + z = posterior.mode() + dec = self.decode(z).sample + + if not return_dict: + return (dec,) + + return DecoderOutput(sample=dec) diff --git a/diffusers/onnx_utils.py b/diffusers/onnx_utils.py new file mode 100644 index 000000000..e840565dd --- /dev/null +++ b/diffusers/onnx_utils.py @@ -0,0 +1,189 @@ +# coding=utf-8 +# Copyright 2022 The HuggingFace Inc. team. +# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
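# A minimal sketch (not part of the patch) of how the AutoencoderKL defined above is used:
# encode() returns an AutoencoderKLOutput whose latent_dist is a DiagonalGaussianDistribution,
# which can be sample()d or reduced to its mode(); decode() maps latents back to image space.
# The tiny block_out_channels config below is made up so the example runs quickly on CPU.
import torch
from diffusers.models.vae import AutoencoderKL  # module added by this patch

vae = AutoencoderKL(block_out_channels=(32,), latent_channels=4, layers_per_block=1)
images = torch.randn(1, 3, 64, 64)

posterior = vae.encode(images).latent_dist     # DiagonalGaussianDistribution(mean, logvar)
latents = posterior.sample()                   # or posterior.mode() for the deterministic mean
reconstruction = vae.decode(latents).sample    # DecoderOutput.sample
print(latents.shape, reconstruction.shape)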
+ + +import os +import shutil +from pathlib import Path +from typing import Optional, Union + +import numpy as np + +from huggingface_hub import hf_hub_download + +from .utils import is_onnx_available, logging + + +if is_onnx_available(): + import onnxruntime as ort + + +ONNX_WEIGHTS_NAME = "model.onnx" + + +logger = logging.get_logger(__name__) + + +class OnnxRuntimeModel: + base_model_prefix = "onnx_model" + + def __init__(self, model=None, **kwargs): + logger.info("`diffusers.OnnxRuntimeModel` is experimental and might change in the future.") + self.model = model + self.model_save_dir = kwargs.get("model_save_dir", None) + self.latest_model_name = kwargs.get("latest_model_name", "model.onnx") + + def __call__(self, **kwargs): + inputs = {k: np.array(v) for k, v in kwargs.items()} + return self.model.run(None, inputs) + + @staticmethod + def load_model(path: Union[str, Path], provider=None): + """ + Loads an ONNX Inference session with an ExecutionProvider. Default provider is `CPUExecutionProvider` + + Arguments: + path (`str` or `Path`): + Directory from which to load + provider(`str`, *optional*): + Onnxruntime execution provider to use for loading the model, defaults to `CPUExecutionProvider` + """ + if provider is None: + logger.info("No onnxruntime provider specified, using CPUExecutionProvider") + provider = "CPUExecutionProvider" + + return ort.InferenceSession(path, providers=[provider]) + + def _save_pretrained(self, save_directory: Union[str, Path], file_name: Optional[str] = None, **kwargs): + """ + Save a model and its configuration file to a directory, so that it can be re-loaded using the + [`~optimum.onnxruntime.modeling_ort.ORTModel.from_pretrained`] class method. It will always save the + latest_model_name. + + Arguments: + save_directory (`str` or `Path`): + Directory where to save the model file. + file_name(`str`, *optional*): + Overwrites the default model file name from `"model.onnx"` to `file_name`. This allows you to save the + model with a different name. + """ + model_file_name = file_name if file_name is not None else ONNX_WEIGHTS_NAME + + src_path = self.model_save_dir.joinpath(self.latest_model_name) + dst_path = Path(save_directory).joinpath(model_file_name) + if not src_path.samefile(dst_path): + shutil.copyfile(src_path, dst_path) + + def save_pretrained( + self, + save_directory: Union[str, os.PathLike], + **kwargs, + ): + """ + Save a model to a directory, so that it can be re-loaded using the [`~OnnxModel.from_pretrained`] class + method.: + + Arguments: + save_directory (`str` or `os.PathLike`): + Directory to which to save. Will be created if it doesn't exist. + """ + if os.path.isfile(save_directory): + logger.error(f"Provided path ({save_directory}) should be a directory, not a file") + return + + os.makedirs(save_directory, exist_ok=True) + + # saving model weights/files + self._save_pretrained(save_directory, **kwargs) + + @classmethod + def _from_pretrained( + cls, + model_id: Union[str, Path], + use_auth_token: Optional[Union[bool, str, None]] = None, + revision: Optional[Union[str, None]] = None, + force_download: bool = False, + cache_dir: Optional[str] = None, + file_name: Optional[str] = None, + provider: Optional[str] = None, + **kwargs, + ): + """ + Load a model from a directory or the HF Hub. + + Arguments: + model_id (`str` or `Path`): + Directory from which to load + use_auth_token (`str` or `bool`): + Is needed to load models from a private or gated repository + revision (`str`): + Revision is the specific model version to use. 
It can be a branch name, a tag name, or a commit id + cache_dir (`Union[str, Path]`, *optional*): + Path to a directory in which a downloaded pretrained model configuration should be cached if the + standard cache should not be used. + force_download (`bool`, *optional*, defaults to `False`): + Whether or not to force the (re-)download of the model weights and configuration files, overriding the + cached versions if they exist. + file_name(`str`): + Overwrites the default model file name from `"model.onnx"` to `file_name`. This allows you to load + different model files from the same repository or directory. + provider(`str`): + The ONNX runtime provider, e.g. `CPUExecutionProvider` or `CUDAExecutionProvider`. + kwargs (`Dict`, *optional*): + kwargs will be passed to the model during initialization + """ + model_file_name = file_name if file_name is not None else ONNX_WEIGHTS_NAME + # load model from local directory + if os.path.isdir(model_id): + model = OnnxRuntimeModel.load_model(os.path.join(model_id, model_file_name), provider=provider) + kwargs["model_save_dir"] = Path(model_id) + # load model from hub + else: + # download model + model_cache_path = hf_hub_download( + repo_id=model_id, + filename=model_file_name, + use_auth_token=use_auth_token, + revision=revision, + cache_dir=cache_dir, + force_download=force_download, + ) + kwargs["model_save_dir"] = Path(model_cache_path).parent + kwargs["latest_model_name"] = Path(model_cache_path).name + model = OnnxRuntimeModel.load_model(model_cache_path, provider=provider) + return cls(model=model, **kwargs) + + @classmethod + def from_pretrained( + cls, + model_id: Union[str, Path], + force_download: bool = True, + use_auth_token: Optional[str] = None, + cache_dir: Optional[str] = None, + **model_kwargs, + ): + revision = None + if len(str(model_id).split("@")) == 2: + model_id, revision = model_id.split("@") + + return cls._from_pretrained( + model_id=model_id, + revision=revision, + cache_dir=cache_dir, + force_download=force_download, + use_auth_token=use_auth_token, + **model_kwargs, + ) diff --git a/diffusers/optimization.py b/diffusers/optimization.py new file mode 100644 index 000000000..e7b836b4a --- /dev/null +++ b/diffusers/optimization.py @@ -0,0 +1,275 @@ +# coding=utf-8 +# Copyright 2022 The HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""PyTorch optimization for diffusion models.""" + +import math +from enum import Enum +from typing import Optional, Union + +from torch.optim import Optimizer +from torch.optim.lr_scheduler import LambdaLR + +from .utils import logging + + +logger = logging.get_logger(__name__) + + +class SchedulerType(Enum): + LINEAR = "linear" + COSINE = "cosine" + COSINE_WITH_RESTARTS = "cosine_with_restarts" + POLYNOMIAL = "polynomial" + CONSTANT = "constant" + CONSTANT_WITH_WARMUP = "constant_with_warmup" + + +def get_constant_schedule(optimizer: Optimizer, last_epoch: int = -1): + """ + Create a schedule with a constant learning rate, using the learning rate set in optimizer. 
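# A minimal sketch (not part of the patch) of the OnnxRuntimeModel wrapper defined above:
# from_pretrained() accepts either a local directory containing "model.onnx" or a hub repo id
# (optionally suffixed with "@revision"), and __call__ converts keyword arguments to numpy
# arrays and feeds them to the onnxruntime session. The path and input name are hypothetical.
import numpy as np
from diffusers.onnx_utils import OnnxRuntimeModel  # module added by this patch

model = OnnxRuntimeModel.from_pretrained("./onnx_export/unet", provider="CPUExecutionProvider")
outputs = model(sample=np.zeros((1, 4, 64, 64), dtype=np.float32))  # names must match the ONNX graph inputs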
+ + Args: + optimizer ([`~torch.optim.Optimizer`]): + The optimizer for which to schedule the learning rate. + last_epoch (`int`, *optional*, defaults to -1): + The index of the last epoch when resuming training. + + Return: + `torch.optim.lr_scheduler.LambdaLR` with the appropriate schedule. + """ + return LambdaLR(optimizer, lambda _: 1, last_epoch=last_epoch) + + +def get_constant_schedule_with_warmup(optimizer: Optimizer, num_warmup_steps: int, last_epoch: int = -1): + """ + Create a schedule with a constant learning rate preceded by a warmup period during which the learning rate + increases linearly between 0 and the initial lr set in the optimizer. + + Args: + optimizer ([`~torch.optim.Optimizer`]): + The optimizer for which to schedule the learning rate. + num_warmup_steps (`int`): + The number of steps for the warmup phase. + last_epoch (`int`, *optional*, defaults to -1): + The index of the last epoch when resuming training. + + Return: + `torch.optim.lr_scheduler.LambdaLR` with the appropriate schedule. + """ + + def lr_lambda(current_step: int): + if current_step < num_warmup_steps: + return float(current_step) / float(max(1.0, num_warmup_steps)) + return 1.0 + + return LambdaLR(optimizer, lr_lambda, last_epoch=last_epoch) + + +def get_linear_schedule_with_warmup(optimizer, num_warmup_steps, num_training_steps, last_epoch=-1): + """ + Create a schedule with a learning rate that decreases linearly from the initial lr set in the optimizer to 0, after + a warmup period during which it increases linearly from 0 to the initial lr set in the optimizer. + + Args: + optimizer ([`~torch.optim.Optimizer`]): + The optimizer for which to schedule the learning rate. + num_warmup_steps (`int`): + The number of steps for the warmup phase. + num_training_steps (`int`): + The total number of training steps. + last_epoch (`int`, *optional*, defaults to -1): + The index of the last epoch when resuming training. + + Return: + `torch.optim.lr_scheduler.LambdaLR` with the appropriate schedule. + """ + + def lr_lambda(current_step: int): + if current_step < num_warmup_steps: + return float(current_step) / float(max(1, num_warmup_steps)) + return max( + 0.0, float(num_training_steps - current_step) / float(max(1, num_training_steps - num_warmup_steps)) + ) + + return LambdaLR(optimizer, lr_lambda, last_epoch) + + +def get_cosine_schedule_with_warmup( + optimizer: Optimizer, num_warmup_steps: int, num_training_steps: int, num_cycles: float = 0.5, last_epoch: int = -1 +): + """ + Create a schedule with a learning rate that decreases following the values of the cosine function between the + initial lr set in the optimizer to 0, after a warmup period during which it increases linearly between 0 and the + initial lr set in the optimizer. + + Args: + optimizer ([`~torch.optim.Optimizer`]): + The optimizer for which to schedule the learning rate. + num_warmup_steps (`int`): + The number of steps for the warmup phase. + num_training_steps (`int`): + The total number of training steps. + num_cycles (`float`, *optional*, defaults to 0.5): + The number of waves in the cosine schedule (the defaults is to just decrease from the max value to 0 + following a half-cosine). + last_epoch (`int`, *optional*, defaults to -1): + The index of the last epoch when resuming training. + + Return: + `torch.optim.lr_scheduler.LambdaLR` with the appropriate schedule. 
+ """ + + def lr_lambda(current_step): + if current_step < num_warmup_steps: + return float(current_step) / float(max(1, num_warmup_steps)) + progress = float(current_step - num_warmup_steps) / float(max(1, num_training_steps - num_warmup_steps)) + return max(0.0, 0.5 * (1.0 + math.cos(math.pi * float(num_cycles) * 2.0 * progress))) + + return LambdaLR(optimizer, lr_lambda, last_epoch) + + +def get_cosine_with_hard_restarts_schedule_with_warmup( + optimizer: Optimizer, num_warmup_steps: int, num_training_steps: int, num_cycles: int = 1, last_epoch: int = -1 +): + """ + Create a schedule with a learning rate that decreases following the values of the cosine function between the + initial lr set in the optimizer to 0, with several hard restarts, after a warmup period during which it increases + linearly between 0 and the initial lr set in the optimizer. + + Args: + optimizer ([`~torch.optim.Optimizer`]): + The optimizer for which to schedule the learning rate. + num_warmup_steps (`int`): + The number of steps for the warmup phase. + num_training_steps (`int`): + The total number of training steps. + num_cycles (`int`, *optional*, defaults to 1): + The number of hard restarts to use. + last_epoch (`int`, *optional*, defaults to -1): + The index of the last epoch when resuming training. + + Return: + `torch.optim.lr_scheduler.LambdaLR` with the appropriate schedule. + """ + + def lr_lambda(current_step): + if current_step < num_warmup_steps: + return float(current_step) / float(max(1, num_warmup_steps)) + progress = float(current_step - num_warmup_steps) / float(max(1, num_training_steps - num_warmup_steps)) + if progress >= 1.0: + return 0.0 + return max(0.0, 0.5 * (1.0 + math.cos(math.pi * ((float(num_cycles) * progress) % 1.0)))) + + return LambdaLR(optimizer, lr_lambda, last_epoch) + + +def get_polynomial_decay_schedule_with_warmup( + optimizer, num_warmup_steps, num_training_steps, lr_end=1e-7, power=1.0, last_epoch=-1 +): + """ + Create a schedule with a learning rate that decreases as a polynomial decay from the initial lr set in the + optimizer to end lr defined by *lr_end*, after a warmup period during which it increases linearly from 0 to the + initial lr set in the optimizer. + + Args: + optimizer ([`~torch.optim.Optimizer`]): + The optimizer for which to schedule the learning rate. + num_warmup_steps (`int`): + The number of steps for the warmup phase. + num_training_steps (`int`): + The total number of training steps. + lr_end (`float`, *optional*, defaults to 1e-7): + The end LR. + power (`float`, *optional*, defaults to 1.0): + Power factor. + last_epoch (`int`, *optional*, defaults to -1): + The index of the last epoch when resuming training. + + Note: *power* defaults to 1.0 as in the fairseq implementation, which in turn is based on the original BERT + implementation at + https://github.com/google-research/bert/blob/f39e881b169b9d53bea03d2d341b31707a6c052b/optimization.py#L37 + + Return: + `torch.optim.lr_scheduler.LambdaLR` with the appropriate schedule. 
+ + """ + + lr_init = optimizer.defaults["lr"] + if not (lr_init > lr_end): + raise ValueError(f"lr_end ({lr_end}) must be be smaller than initial lr ({lr_init})") + + def lr_lambda(current_step: int): + if current_step < num_warmup_steps: + return float(current_step) / float(max(1, num_warmup_steps)) + elif current_step > num_training_steps: + return lr_end / lr_init # as LambdaLR multiplies by lr_init + else: + lr_range = lr_init - lr_end + decay_steps = num_training_steps - num_warmup_steps + pct_remaining = 1 - (current_step - num_warmup_steps) / decay_steps + decay = lr_range * pct_remaining**power + lr_end + return decay / lr_init # as LambdaLR multiplies by lr_init + + return LambdaLR(optimizer, lr_lambda, last_epoch) + + +TYPE_TO_SCHEDULER_FUNCTION = { + SchedulerType.LINEAR: get_linear_schedule_with_warmup, + SchedulerType.COSINE: get_cosine_schedule_with_warmup, + SchedulerType.COSINE_WITH_RESTARTS: get_cosine_with_hard_restarts_schedule_with_warmup, + SchedulerType.POLYNOMIAL: get_polynomial_decay_schedule_with_warmup, + SchedulerType.CONSTANT: get_constant_schedule, + SchedulerType.CONSTANT_WITH_WARMUP: get_constant_schedule_with_warmup, +} + + +def get_scheduler( + name: Union[str, SchedulerType], + optimizer: Optimizer, + num_warmup_steps: Optional[int] = None, + num_training_steps: Optional[int] = None, +): + """ + Unified API to get any scheduler from its name. + + Args: + name (`str` or `SchedulerType`): + The name of the scheduler to use. + optimizer (`torch.optim.Optimizer`): + The optimizer that will be used during training. + num_warmup_steps (`int`, *optional*): + The number of warmup steps to do. This is not required by all schedulers (hence the argument being + optional), the function will raise an error if it's unset and the scheduler type requires it. + num_training_steps (`int``, *optional*): + The number of training steps to do. This is not required by all schedulers (hence the argument being + optional), the function will raise an error if it's unset and the scheduler type requires it. + """ + name = SchedulerType(name) + schedule_func = TYPE_TO_SCHEDULER_FUNCTION[name] + if name == SchedulerType.CONSTANT: + return schedule_func(optimizer) + + # All other schedulers require `num_warmup_steps` + if num_warmup_steps is None: + raise ValueError(f"{name} requires `num_warmup_steps`, please provide that argument.") + + if name == SchedulerType.CONSTANT_WITH_WARMUP: + return schedule_func(optimizer, num_warmup_steps=num_warmup_steps) + + # All other schedulers require `num_training_steps` + if num_training_steps is None: + raise ValueError(f"{name} requires `num_training_steps`, please provide that argument.") + + return schedule_func(optimizer, num_warmup_steps=num_warmup_steps, num_training_steps=num_training_steps) diff --git a/diffusers/pipeline_utils.py b/diffusers/pipeline_utils.py new file mode 100644 index 000000000..84ee9e20f --- /dev/null +++ b/diffusers/pipeline_utils.py @@ -0,0 +1,417 @@ +# coding=utf-8 +# Copyright 2022 The HuggingFace Inc. team. +# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
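# A minimal sketch (not part of the patch) of the scheduler helpers defined above: get_scheduler()
# dispatches on the SchedulerType name, and every warmup variant multiplies the optimizer's base
# learning rate by the factor computed in lr_lambda. The toy model and step counts are made up.
import torch
from diffusers.optimization import get_scheduler

model = torch.nn.Linear(4, 4)
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)

lr_scheduler = get_scheduler(
    "cosine",                 # any SchedulerType value: "linear", "constant_with_warmup", "polynomial", ...
    optimizer=optimizer,
    num_warmup_steps=100,     # lr ramps linearly from 0 to 1e-4 over the first 100 steps
    num_training_steps=1000,  # then follows a half-cosine decay down to 0 at step 1000
)

for _ in range(1000):
    optimizer.step()
    lr_scheduler.step()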
+# See the License for the specific language governing permissions and +# limitations under the License. + +import importlib +import inspect +import os +from dataclasses import dataclass +from typing import List, Optional, Union + +import numpy as np +import torch + +import diffusers +import PIL +from huggingface_hub import snapshot_download +from PIL import Image +from tqdm.auto import tqdm + +from .configuration_utils import ConfigMixin +from .utils import DIFFUSERS_CACHE, BaseOutput, logging + + +INDEX_FILE = "diffusion_pytorch_model.bin" + + +logger = logging.get_logger(__name__) + + +LOADABLE_CLASSES = { + "diffusers": { + "ModelMixin": ["save_pretrained", "from_pretrained"], + "SchedulerMixin": ["save_config", "from_config"], + "DiffusionPipeline": ["save_pretrained", "from_pretrained"], + "OnnxRuntimeModel": ["save_pretrained", "from_pretrained"], + }, + "transformers": { + "PreTrainedTokenizer": ["save_pretrained", "from_pretrained"], + "PreTrainedTokenizerFast": ["save_pretrained", "from_pretrained"], + "PreTrainedModel": ["save_pretrained", "from_pretrained"], + "FeatureExtractionMixin": ["save_pretrained", "from_pretrained"], + }, +} + +ALL_IMPORTABLE_CLASSES = {} +for library in LOADABLE_CLASSES: + ALL_IMPORTABLE_CLASSES.update(LOADABLE_CLASSES[library]) + + +@dataclass +class ImagePipelineOutput(BaseOutput): + """ + Output class for image pipelines. + + Args: + images (`List[PIL.Image.Image]` or `np.ndarray`) + List of denoised PIL images of length `batch_size` or numpy array of shape `(batch_size, height, width, + num_channels)`. PIL images or numpy array present the denoised images of the diffusion pipeline. + """ + + images: Union[List[PIL.Image.Image], np.ndarray] + + +class DiffusionPipeline(ConfigMixin): + r""" + Base class for all models. + + [`DiffusionPipeline`] takes care of storing all components (models, schedulers, processors) for diffusion pipelines + and handles methods for loading, downloading and saving models as well as a few methods common to all pipelines to: + + - move all PyTorch modules to the device of your choice + - enabling/disabling the progress bar for the denoising iteration + + Class attributes: + + - **config_name** ([`str`]) -- name of the config file that will store the class and module names of all + compenents of the diffusion pipeline. + """ + config_name = "model_index.json" + + def register_modules(self, **kwargs): + # import it here to avoid circular import + from diffusers import pipelines + + for name, module in kwargs.items(): + # retrive library + library = module.__module__.split(".")[0] + + # check if the module is a pipeline module + pipeline_dir = module.__module__.split(".")[-2] + path = module.__module__.split(".") + is_pipeline_module = pipeline_dir in path and hasattr(pipelines, pipeline_dir) + + # if library is not in LOADABLE_CLASSES, then it is a custom module. + # Or if it's a pipeline module, then the module is inside the pipeline + # folder so we set the library to module name. + if library not in LOADABLE_CLASSES or is_pipeline_module: + library = pipeline_dir + + # retrive class_name + class_name = module.__class__.__name__ + + register_dict = {name: (library, class_name)} + + # save model index config + self.register_to_config(**register_dict) + + # set models + setattr(self, name, module) + + def save_pretrained(self, save_directory: Union[str, os.PathLike]): + """ + Save all variables of the pipeline that can be saved and loaded as well as the pipelines configuration file to + a directory. 
A pipeline variable can be saved and loaded if its class implements both a save and loading + method. The pipeline can easily be re-loaded using the `[`~DiffusionPipeline.from_pretrained`]` class method. + + Arguments: + save_directory (`str` or `os.PathLike`): + Directory to which to save. Will be created if it doesn't exist. + """ + self.save_config(save_directory) + + model_index_dict = dict(self.config) + model_index_dict.pop("_class_name") + model_index_dict.pop("_diffusers_version") + model_index_dict.pop("_module", None) + + for pipeline_component_name in model_index_dict.keys(): + sub_model = getattr(self, pipeline_component_name) + model_cls = sub_model.__class__ + + save_method_name = None + # search for the model's base class in LOADABLE_CLASSES + for library_name, library_classes in LOADABLE_CLASSES.items(): + library = importlib.import_module(library_name) + for base_class, save_load_methods in library_classes.items(): + class_candidate = getattr(library, base_class) + if issubclass(model_cls, class_candidate): + # if we found a suitable base class in LOADABLE_CLASSES then grab its save method + save_method_name = save_load_methods[0] + break + if save_method_name is not None: + break + + save_method = getattr(sub_model, save_method_name) + save_method(os.path.join(save_directory, pipeline_component_name)) + + def to(self, torch_device: Optional[Union[str, torch.device]] = None): + if torch_device is None: + return self + + module_names, _ = self.extract_init_dict(dict(self.config)) + for name in module_names.keys(): + module = getattr(self, name) + if isinstance(module, torch.nn.Module): + module.to(torch_device) + return self + + @property + def device(self) -> torch.device: + r""" + Returns: + `torch.device`: The torch device on which the pipeline is located. + """ + module_names, _ = self.extract_init_dict(dict(self.config)) + for name in module_names.keys(): + module = getattr(self, name) + if isinstance(module, torch.nn.Module): + return module.device + return torch.device("cpu") + + @classmethod + def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.PathLike]], **kwargs): + r""" + Instantiate a PyTorch diffusion pipeline from pre-trained pipeline weights. + + The pipeline is set in evaluation mode by default using `model.eval()` (Dropout modules are deactivated). + + The warning *Weights from XXX not initialized from pretrained model* means that the weights of XXX do not come + pretrained with the rest of the model. It is up to you to train those weights with a downstream fine-tuning + task. + + The warning *Weights from XXX not used in YYY* means that the layer XXX is not used by YYY, therefore those + weights are discarded. + + Parameters: + pretrained_model_name_or_path (`str` or `os.PathLike`, *optional*): + Can be either: + + - A string, the *repo id* of a pretrained pipeline hosted inside a model repo on + https://huggingface.co/ Valid repo ids have to be located under a user or organization name, like + `CompVis/ldm-text2im-large-256`. + - A path to a *directory* containing pipeline weights saved using + [`~DiffusionPipeline.save_pretrained`], e.g., `./my_pipeline_directory/`. + torch_dtype (`str` or `torch.dtype`, *optional*): + Override the default `torch.dtype` and load the model under this dtype. If `"auto"` is passed the dtype + will be automatically derived from the model's weights. 
+ force_download (`bool`, *optional*, defaults to `False`): + Whether or not to force the (re-)download of the model weights and configuration files, overriding the + cached versions if they exist. + resume_download (`bool`, *optional*, defaults to `False`): + Whether or not to delete incompletely received files. Will attempt to resume the download if such a + file exists. + proxies (`Dict[str, str]`, *optional*): + A dictionary of proxy servers to use by protocol or endpoint, e.g., `{'http': 'foo.bar:3128', + 'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request. + output_loading_info(`bool`, *optional*, defaults to `False`): + Whether ot not to also return a dictionary containing missing keys, unexpected keys and error messages. + local_files_only(`bool`, *optional*, defaults to `False`): + Whether or not to only look at local files (i.e., do not try to download the model). + use_auth_token (`str` or *bool*, *optional*): + The token to use as HTTP bearer authorization for remote files. If `True`, will use the token generated + when running `huggingface-cli login` (stored in `~/.huggingface`). + revision (`str`, *optional*, defaults to `"main"`): + The specific model version to use. It can be a branch name, a tag name, or a commit id, since we use a + git-based system for storing models and other artifacts on huggingface.co, so `revision` can be any + identifier allowed by git. + mirror (`str`, *optional*): + Mirror source to accelerate downloads in China. If you are from China and have an accessibility + problem, you can set this option to resolve it. Note that we do not guarantee the timeliness or safety. + Please refer to the mirror site for more information. specify the folder name here. + + kwargs (remaining dictionary of keyword arguments, *optional*): + Can be used to overwrite load - and saveable variables - *i.e.* the pipeline components - of the + speficic pipeline class. The overritten components are then directly passed to the pipelines `__init__` + method. See example below for more information. + + + + Passing `use_auth_token=True`` is required when you want to use a private model, *e.g.* + `"CompVis/stable-diffusion-v1-4"` + + + + + + Activate the special ["offline-mode"](https://huggingface.co/diffusers/installation.html#offline-mode) to use + this method in a firewalled environment. + + + + Examples: + + ```py + >>> from diffusers import DiffusionPipeline + + >>> # Download pipeline from huggingface.co and cache. + >>> pipeline = DiffusionPipeline.from_pretrained("CompVis/ldm-text2im-large-256") + + >>> # Download pipeline that requires an authorization token + >>> # For more information on access tokens, please refer to this section + >>> # of the documentation](https://huggingface.co/docs/hub/security-tokens) + >>> pipeline = DiffusionPipeline.from_pretrained("CompVis/stable-diffusion-v1-4", use_auth_token=True) + + >>> # Download pipeline, but overwrite scheduler + >>> from diffusers import LMSDiscreteScheduler + + >>> scheduler = LMSDiscreteScheduler(beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear") + >>> pipeline = DiffusionPipeline.from_pretrained( + ... "CompVis/stable-diffusion-v1-4", scheduler=scheduler, use_auth_token=True + ... 
) + ``` + """ + cache_dir = kwargs.pop("cache_dir", DIFFUSERS_CACHE) + resume_download = kwargs.pop("resume_download", False) + proxies = kwargs.pop("proxies", None) + local_files_only = kwargs.pop("local_files_only", False) + use_auth_token = kwargs.pop("use_auth_token", None) + revision = kwargs.pop("revision", None) + torch_dtype = kwargs.pop("torch_dtype", None) + provider = kwargs.pop("provider", None) + + # 1. Download the checkpoints and configs + # use snapshot download here to get it working from from_pretrained + if not os.path.isdir(pretrained_model_name_or_path): + cached_folder = snapshot_download( + pretrained_model_name_or_path, + cache_dir=cache_dir, + resume_download=resume_download, + proxies=proxies, + local_files_only=local_files_only, + use_auth_token=use_auth_token, + revision=revision, + ) + else: + cached_folder = pretrained_model_name_or_path + + config_dict = cls.get_config_dict(cached_folder) + + # 2. Load the pipeline class, if using custom module then load it from the hub + # if we load from explicit class, let's use it + if cls != DiffusionPipeline: + pipeline_class = cls + else: + diffusers_module = importlib.import_module(cls.__module__.split(".")[0]) + pipeline_class = getattr(diffusers_module, config_dict["_class_name"]) + + # some modules can be passed directly to the init + # in this case they are already instantiated in `kwargs` + # extract them here + expected_modules = set(inspect.signature(pipeline_class.__init__).parameters.keys()) + passed_class_obj = {k: kwargs.pop(k) for k in expected_modules if k in kwargs} + + init_dict, _ = pipeline_class.extract_init_dict(config_dict, **kwargs) + + init_kwargs = {} + + # import it here to avoid circular import + from diffusers import pipelines + + # 3. Load each module in the pipeline + for name, (library_name, class_name) in init_dict.items(): + is_pipeline_module = hasattr(pipelines, library_name) + loaded_sub_model = None + + # if the model is in a pipeline module, then we load it from the pipeline + if name in passed_class_obj: + # 1. check that passed_class_obj has correct parent class + if not is_pipeline_module: + library = importlib.import_module(library_name) + class_obj = getattr(library, class_name) + importable_classes = LOADABLE_CLASSES[library_name] + class_candidates = {c: getattr(library, c) for c in importable_classes.keys()} + + expected_class_obj = None + for class_name, class_candidate in class_candidates.items(): + if issubclass(class_obj, class_candidate): + expected_class_obj = class_candidate + + if not issubclass(passed_class_obj[name].__class__, expected_class_obj): + raise ValueError( + f"{passed_class_obj[name]} is of type: {type(passed_class_obj[name])}, but should be" + f" {expected_class_obj}" + ) + else: + logger.warn( + f"You have passed a non-standard module {passed_class_obj[name]}. We cannot verify whether it" + " has the correct type" + ) + + # set passed class object + loaded_sub_model = passed_class_obj[name] + elif is_pipeline_module: + pipeline_module = getattr(pipelines, library_name) + class_obj = getattr(pipeline_module, class_name) + importable_classes = ALL_IMPORTABLE_CLASSES + class_candidates = {c: class_obj for c in importable_classes.keys()} + else: + # else we just import it from the library. 
+ library = importlib.import_module(library_name) + class_obj = getattr(library, class_name) + importable_classes = LOADABLE_CLASSES[library_name] + class_candidates = {c: getattr(library, c) for c in importable_classes.keys()} + + if loaded_sub_model is None: + load_method_name = None + for class_name, class_candidate in class_candidates.items(): + if issubclass(class_obj, class_candidate): + load_method_name = importable_classes[class_name][1] + + load_method = getattr(class_obj, load_method_name) + + loading_kwargs = {} + if issubclass(class_obj, torch.nn.Module): + loading_kwargs["torch_dtype"] = torch_dtype + if issubclass(class_obj, diffusers.OnnxRuntimeModel): + loading_kwargs["provider"] = provider + + # check if the module is in a subdirectory + if os.path.isdir(os.path.join(cached_folder, name)): + loaded_sub_model = load_method(os.path.join(cached_folder, name), **loading_kwargs) + else: + # else load from the root directory + loaded_sub_model = load_method(cached_folder, **loading_kwargs) + + init_kwargs[name] = loaded_sub_model # UNet(...), # DiffusionSchedule(...) + + # 4. Instantiate the pipeline + model = pipeline_class(**init_kwargs) + return model + + @staticmethod + def numpy_to_pil(images): + """ + Convert a numpy image or a batch of images to a PIL image. + """ + if images.ndim == 3: + images = images[None, ...] + images = (images * 255).round().astype("uint8") + pil_images = [Image.fromarray(image) for image in images] + + return pil_images + + def progress_bar(self, iterable): + if not hasattr(self, "_progress_bar_config"): + self._progress_bar_config = {} + elif not isinstance(self._progress_bar_config, dict): + raise ValueError( + f"`self._progress_bar_config` should be of type `dict`, but is {type(self._progress_bar_config)}." + ) + + return tqdm(iterable, **self._progress_bar_config) + + def set_progress_bar_config(self, **kwargs): + self._progress_bar_config = kwargs diff --git a/diffusers/pipelines/__init__.py b/diffusers/pipelines/__init__.py new file mode 100644 index 000000000..3e2aeb4fb --- /dev/null +++ b/diffusers/pipelines/__init__.py @@ -0,0 +1,19 @@ +from ..utils import is_onnx_available, is_transformers_available +from .ddim import DDIMPipeline +from .ddpm import DDPMPipeline +from .latent_diffusion_uncond import LDMPipeline +from .pndm import PNDMPipeline +from .score_sde_ve import ScoreSdeVePipeline +from .stochastic_karras_ve import KarrasVePipeline + + +if is_transformers_available(): + from .latent_diffusion import LDMTextToImagePipeline + from .stable_diffusion import ( + StableDiffusionImg2ImgPipeline, + StableDiffusionInpaintPipeline, + StableDiffusionPipeline, + ) + +if is_transformers_available() and is_onnx_available(): + from .stable_diffusion import StableDiffusionOnnxPipeline diff --git a/diffusers/pipelines/ddim/__init__.py b/diffusers/pipelines/ddim/__init__.py new file mode 100644 index 000000000..8fd31868a --- /dev/null +++ b/diffusers/pipelines/ddim/__init__.py @@ -0,0 +1,2 @@ +# flake8: noqa +from .pipeline_ddim import DDIMPipeline diff --git a/diffusers/pipelines/ddim/pipeline_ddim.py b/diffusers/pipelines/ddim/pipeline_ddim.py new file mode 100644 index 000000000..33f6064db --- /dev/null +++ b/diffusers/pipelines/ddim/pipeline_ddim.py @@ -0,0 +1,117 @@ +# Copyright 2022 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and + +# limitations under the License. + + +import warnings +from typing import Optional, Tuple, Union + +import torch + +from ...pipeline_utils import DiffusionPipeline, ImagePipelineOutput + + +class DDIMPipeline(DiffusionPipeline): + r""" + This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the + library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.) + + Parameters: + unet ([`UNet2DModel`]): U-Net architecture to denoise the encoded image. + scheduler ([`SchedulerMixin`]): + A scheduler to be used in combination with `unet` to denoise the encoded image. Can be one of + [`DDPMScheduler`], or [`DDIMScheduler`]. + """ + + def __init__(self, unet, scheduler): + super().__init__() + scheduler = scheduler.set_format("pt") + self.register_modules(unet=unet, scheduler=scheduler) + + @torch.no_grad() + def __call__( + self, + batch_size: int = 1, + generator: Optional[torch.Generator] = None, + eta: float = 0.0, + num_inference_steps: int = 50, + output_type: Optional[str] = "pil", + return_dict: bool = True, + **kwargs, + ) -> Union[ImagePipelineOutput, Tuple]: + r""" + Args: + batch_size (`int`, *optional*, defaults to 1): + The number of images to generate. + generator (`torch.Generator`, *optional*): + A [torch generator](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make generation + deterministic. + eta (`float`, *optional*, defaults to 0.0): + The eta parameter which controls the scale of the variance (0 is DDIM and 1 is one type of DDPM). + num_inference_steps (`int`, *optional*, defaults to 50): + The number of denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. + output_type (`str`, *optional*, defaults to `"pil"`): + The output format of the generate image. Choose between + [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `nd.array`. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~pipeline_utils.ImagePipelineOutput`] instead of a plain tuple. + + Returns: + [`~pipeline_utils.ImagePipelineOutput`] or `tuple`: [`~pipelines.utils.ImagePipelineOutput`] if + `return_dict` is True, otherwise a `tuple. When returning a tuple, the first element is a list with the + generated images. + """ + + if "torch_device" in kwargs: + device = kwargs.pop("torch_device") + warnings.warn( + "`torch_device` is deprecated as an input argument to `__call__` and will be removed in v0.3.0." + " Consider using `pipe.to(torch_device)` instead." 
+ ) + + # Set device as before (to be removed in 0.3.0) + if device is None: + device = "cuda" if torch.cuda.is_available() else "cpu" + self.to(device) + + # eta corresponds to η in paper and should be between [0, 1] + + # Sample gaussian noise to begin loop + image = torch.randn( + (batch_size, self.unet.in_channels, self.unet.sample_size, self.unet.sample_size), + generator=generator, + ) + image = image.to(self.device) + + # set step values + self.scheduler.set_timesteps(num_inference_steps) + + for t in self.progress_bar(self.scheduler.timesteps): + # 1. predict noise model_output + model_output = self.unet(image, t).sample + + # 2. predict previous mean of image x_t-1 and add variance depending on eta + # do x_t -> x_t-1 + image = self.scheduler.step(model_output, t, image, eta).prev_sample + + image = (image / 2 + 0.5).clamp(0, 1) + image = image.cpu().permute(0, 2, 3, 1).numpy() + if output_type == "pil": + image = self.numpy_to_pil(image) + + if not return_dict: + return (image,) + + return ImagePipelineOutput(images=image) diff --git a/diffusers/pipelines/ddpm/__init__.py b/diffusers/pipelines/ddpm/__init__.py new file mode 100644 index 000000000..8889bdae1 --- /dev/null +++ b/diffusers/pipelines/ddpm/__init__.py @@ -0,0 +1,2 @@ +# flake8: noqa +from .pipeline_ddpm import DDPMPipeline diff --git a/diffusers/pipelines/ddpm/pipeline_ddpm.py b/diffusers/pipelines/ddpm/pipeline_ddpm.py new file mode 100644 index 000000000..71103bbe4 --- /dev/null +++ b/diffusers/pipelines/ddpm/pipeline_ddpm.py @@ -0,0 +1,106 @@ +# Copyright 2022 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and + +# limitations under the License. + + +import warnings +from typing import Optional, Tuple, Union + +import torch + +from ...pipeline_utils import DiffusionPipeline, ImagePipelineOutput + + +class DDPMPipeline(DiffusionPipeline): + r""" + This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the + library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.) + + Parameters: + unet ([`UNet2DModel`]): U-Net architecture to denoise the encoded image. + scheduler ([`SchedulerMixin`]): + A scheduler to be used in combination with `unet` to denoise the encoded image. Can be one of + [`DDPMScheduler`], or [`DDIMScheduler`]. + """ + + def __init__(self, unet, scheduler): + super().__init__() + scheduler = scheduler.set_format("pt") + self.register_modules(unet=unet, scheduler=scheduler) + + @torch.no_grad() + def __call__( + self, + batch_size: int = 1, + generator: Optional[torch.Generator] = None, + output_type: Optional[str] = "pil", + return_dict: bool = True, + **kwargs, + ) -> Union[ImagePipelineOutput, Tuple]: + r""" + Args: + batch_size (`int`, *optional*, defaults to 1): + The number of images to generate. + generator (`torch.Generator`, *optional*): + A [torch generator](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make generation + deterministic. 
+ output_type (`str`, *optional*, defaults to `"pil"`): + The output format of the generate image. Choose between + [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `nd.array`. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~pipeline_utils.ImagePipelineOutput`] instead of a plain tuple. + + Returns: + [`~pipeline_utils.ImagePipelineOutput`] or `tuple`: [`~pipelines.utils.ImagePipelineOutput`] if + `return_dict` is True, otherwise a `tuple. When returning a tuple, the first element is a list with the + generated images. + """ + if "torch_device" in kwargs: + device = kwargs.pop("torch_device") + warnings.warn( + "`torch_device` is deprecated as an input argument to `__call__` and will be removed in v0.3.0." + " Consider using `pipe.to(torch_device)` instead." + ) + + # Set device as before (to be removed in 0.3.0) + if device is None: + device = "cuda" if torch.cuda.is_available() else "cpu" + self.to(device) + + # Sample gaussian noise to begin loop + image = torch.randn( + (batch_size, self.unet.in_channels, self.unet.sample_size, self.unet.sample_size), + generator=generator, + ) + image = image.to(self.device) + + # set step values + self.scheduler.set_timesteps(1000) + + for t in self.progress_bar(self.scheduler.timesteps): + # 1. predict noise model_output + model_output = self.unet(image, t).sample + + # 2. compute previous image: x_t -> t_t-1 + image = self.scheduler.step(model_output, t, image, generator=generator).prev_sample + + image = (image / 2 + 0.5).clamp(0, 1) + image = image.cpu().permute(0, 2, 3, 1).numpy() + if output_type == "pil": + image = self.numpy_to_pil(image) + + if not return_dict: + return (image,) + + return ImagePipelineOutput(images=image) diff --git a/diffusers/pipelines/latent_diffusion/__init__.py b/diffusers/pipelines/latent_diffusion/__init__.py new file mode 100644 index 000000000..c481b38cf --- /dev/null +++ b/diffusers/pipelines/latent_diffusion/__init__.py @@ -0,0 +1,6 @@ +# flake8: noqa +from ...utils import is_transformers_available + + +if is_transformers_available(): + from .pipeline_latent_diffusion import LDMBertModel, LDMTextToImagePipeline diff --git a/diffusers/pipelines/latent_diffusion/pipeline_latent_diffusion.py b/diffusers/pipelines/latent_diffusion/pipeline_latent_diffusion.py new file mode 100644 index 000000000..b39840f24 --- /dev/null +++ b/diffusers/pipelines/latent_diffusion/pipeline_latent_diffusion.py @@ -0,0 +1,705 @@ +import inspect +import warnings +from typing import List, Optional, Tuple, Union + +import torch +import torch.nn as nn +import torch.utils.checkpoint + +from transformers.activations import ACT2FN +from transformers.configuration_utils import PretrainedConfig +from transformers.modeling_outputs import BaseModelOutput +from transformers.modeling_utils import PreTrainedModel +from transformers.tokenization_utils import PreTrainedTokenizer +from transformers.utils import logging + +from ...models import AutoencoderKL, UNet2DConditionModel, UNet2DModel, VQModel +from ...pipeline_utils import DiffusionPipeline, ImagePipelineOutput +from ...schedulers import DDIMScheduler, LMSDiscreteScheduler, PNDMScheduler + + +class LDMTextToImagePipeline(DiffusionPipeline): + r""" + This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the + library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.) 
+ + Parameters: + vqvae ([`VQModel`]): + Vector-quantized (VQ) Model to encode and decode images to and from latent representations. + bert ([`LDMBertModel`]): + Text-encoder model based on [BERT](ttps://huggingface.co/docs/transformers/model_doc/bert) architecture. + tokenizer (`transformers.BertTokenizer`): + Tokenizer of class + [BertTokenizer](https://huggingface.co/docs/transformers/model_doc/bert#transformers.BertTokenizer). + unet ([`UNet2DConditionModel`]): Conditional U-Net architecture to denoise the encoded image latents. + scheduler ([`SchedulerMixin`]): + A scheduler to be used in combination with `unet` to denoise the encoded image latens. Can be one of + [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`]. + """ + + def __init__( + self, + vqvae: Union[VQModel, AutoencoderKL], + bert: PreTrainedModel, + tokenizer: PreTrainedTokenizer, + unet: Union[UNet2DModel, UNet2DConditionModel], + scheduler: Union[DDIMScheduler, PNDMScheduler, LMSDiscreteScheduler], + ): + super().__init__() + scheduler = scheduler.set_format("pt") + self.register_modules(vqvae=vqvae, bert=bert, tokenizer=tokenizer, unet=unet, scheduler=scheduler) + + @torch.no_grad() + def __call__( + self, + prompt: Union[str, List[str]], + height: Optional[int] = 256, + width: Optional[int] = 256, + num_inference_steps: Optional[int] = 50, + guidance_scale: Optional[float] = 1.0, + eta: Optional[float] = 0.0, + generator: Optional[torch.Generator] = None, + output_type: Optional[str] = "pil", + return_dict: bool = True, + **kwargs, + ) -> Union[Tuple, ImagePipelineOutput]: + r""" + Args: + prompt (`str` or `List[str]`): + The prompt or prompts to guide the image generation. + height (`int`, *optional*, defaults to 256): + The height in pixels of the generated image. + width (`int`, *optional*, defaults to 256): + The width in pixels of the generated image. + num_inference_steps (`int`, *optional*, defaults to 50): + The number of denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. + guidance_scale (`float`, *optional*, defaults to 1.0): + Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598). + `guidance_scale` is defined as `w` of equation 2. of [Imagen + Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale > + 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt` at + the, usually at the expense of lower image quality. + generator (`torch.Generator`, *optional*): + A [torch generator](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make generation + deterministic. + output_type (`str`, *optional*, defaults to `"pil"`): + The output format of the generate image. Choose between + [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `nd.array`. + return_dict (`bool`, *optional*): + Whether or not to return a [`~pipeline_utils.ImagePipelineOutput`] instead of a plain tuple. + + Returns: + [`~pipeline_utils.ImagePipelineOutput`] or `tuple`: [`~pipelines.utils.ImagePipelineOutput`] if + `return_dict` is True, otherwise a `tuple. When returning a tuple, the first element is a list with the + generated images. + """ + if "torch_device" in kwargs: + device = kwargs.pop("torch_device") + warnings.warn( + "`torch_device` is deprecated as an input argument to `__call__` and will be removed in v0.3.0." + " Consider using `pipe.to(torch_device)` instead." 
+ ) + + # Set device as before (to be removed in 0.3.0) + if device is None: + device = "cuda" if torch.cuda.is_available() else "cpu" + self.to(device) + + if isinstance(prompt, str): + batch_size = 1 + elif isinstance(prompt, list): + batch_size = len(prompt) + else: + raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") + + if height % 8 != 0 or width % 8 != 0: + raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.") + + # get unconditional embeddings for classifier free guidance + if guidance_scale != 1.0: + uncond_input = self.tokenizer([""] * batch_size, padding="max_length", max_length=77, return_tensors="pt") + uncond_embeddings = self.bert(uncond_input.input_ids.to(self.device))[0] + + # get prompt text embeddings + text_input = self.tokenizer(prompt, padding="max_length", max_length=77, return_tensors="pt") + text_embeddings = self.bert(text_input.input_ids.to(self.device))[0] + + latents = torch.randn( + (batch_size, self.unet.in_channels, height // 8, width // 8), + generator=generator, + ) + latents = latents.to(self.device) + + self.scheduler.set_timesteps(num_inference_steps) + + # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature + accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys()) + + extra_kwargs = {} + if accepts_eta: + extra_kwargs["eta"] = eta + + for t in self.progress_bar(self.scheduler.timesteps): + if guidance_scale == 1.0: + # guidance_scale of 1 means no guidance + latents_input = latents + context = text_embeddings + else: + # For classifier free guidance, we need to do two forward passes. + # Here we concatenate the unconditional and text embeddings into a single batch + # to avoid doing two forward passes + latents_input = torch.cat([latents] * 2) + context = torch.cat([uncond_embeddings, text_embeddings]) + + # predict the noise residual + noise_pred = self.unet(latents_input, t, encoder_hidden_states=context).sample + # perform guidance + if guidance_scale != 1.0: + noise_pred_uncond, noise_prediction_text = noise_pred.chunk(2) + noise_pred = noise_pred_uncond + guidance_scale * (noise_prediction_text - noise_pred_uncond) + + # compute the previous noisy sample x_t -> x_t-1 + latents = self.scheduler.step(noise_pred, t, latents, **extra_kwargs).prev_sample + + # scale and decode the image latents with vae + latents = 1 / 0.18215 * latents + image = self.vqvae.decode(latents).sample + + image = (image / 2 + 0.5).clamp(0, 1) + image = image.cpu().permute(0, 2, 3, 1).numpy() + if output_type == "pil": + image = self.numpy_to_pil(image) + + if not return_dict: + return (image,) + + return ImagePipelineOutput(images=image) + + +################################################################################ +# Code for the text transformer model +################################################################################ +""" PyTorch LDMBERT model.""" + + +logger = logging.get_logger(__name__) + +LDMBERT_PRETRAINED_MODEL_ARCHIVE_LIST = [ + "ldm-bert", + # See all LDMBert models at https://huggingface.co/models?filter=ldmbert +] + + +LDMBERT_PRETRAINED_CONFIG_ARCHIVE_MAP = { + "ldm-bert": "https://huggingface.co/ldm-bert/resolve/main/config.json", +} + + +""" LDMBERT model configuration""" + + +class LDMBertConfig(PretrainedConfig): + model_type = "ldmbert" + keys_to_ignore_at_inference = ["past_key_values"] + attribute_map = {"num_attention_heads": "encoder_attention_heads", "hidden_size": "d_model"} + + 
def __init__( + self, + vocab_size=30522, + max_position_embeddings=77, + encoder_layers=32, + encoder_ffn_dim=5120, + encoder_attention_heads=8, + head_dim=64, + encoder_layerdrop=0.0, + activation_function="gelu", + d_model=1280, + dropout=0.1, + attention_dropout=0.0, + activation_dropout=0.0, + init_std=0.02, + classifier_dropout=0.0, + scale_embedding=False, + use_cache=True, + pad_token_id=0, + **kwargs, + ): + self.vocab_size = vocab_size + self.max_position_embeddings = max_position_embeddings + self.d_model = d_model + self.encoder_ffn_dim = encoder_ffn_dim + self.encoder_layers = encoder_layers + self.encoder_attention_heads = encoder_attention_heads + self.head_dim = head_dim + self.dropout = dropout + self.attention_dropout = attention_dropout + self.activation_dropout = activation_dropout + self.activation_function = activation_function + self.init_std = init_std + self.encoder_layerdrop = encoder_layerdrop + self.classifier_dropout = classifier_dropout + self.use_cache = use_cache + self.num_hidden_layers = encoder_layers + self.scale_embedding = scale_embedding # scale factor will be sqrt(d_model) if True + + super().__init__(pad_token_id=pad_token_id, **kwargs) + + +def _expand_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int] = None): + """ + Expands attention_mask from `[bsz, seq_len]` to `[bsz, 1, tgt_seq_len, src_seq_len]`. + """ + bsz, src_len = mask.size() + tgt_len = tgt_len if tgt_len is not None else src_len + + expanded_mask = mask[:, None, None, :].expand(bsz, 1, tgt_len, src_len).to(dtype) + + inverted_mask = 1.0 - expanded_mask + + return inverted_mask.masked_fill(inverted_mask.to(torch.bool), torch.finfo(dtype).min) + + +# Copied from transformers.models.bart.modeling_bart.BartAttention with Bart->LDMBert +class LDMBertAttention(nn.Module): + """Multi-headed attention from 'Attention Is All You Need' paper""" + + def __init__( + self, + embed_dim: int, + num_heads: int, + head_dim: int, + dropout: float = 0.0, + is_decoder: bool = False, + bias: bool = False, + ): + super().__init__() + self.embed_dim = embed_dim + self.num_heads = num_heads + self.dropout = dropout + self.head_dim = head_dim + self.inner_dim = head_dim * num_heads + + self.scaling = self.head_dim**-0.5 + self.is_decoder = is_decoder + + self.k_proj = nn.Linear(embed_dim, self.inner_dim, bias=bias) + self.v_proj = nn.Linear(embed_dim, self.inner_dim, bias=bias) + self.q_proj = nn.Linear(embed_dim, self.inner_dim, bias=bias) + self.out_proj = nn.Linear(self.inner_dim, embed_dim) + + def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int): + return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous() + + def forward( + self, + hidden_states: torch.Tensor, + key_value_states: Optional[torch.Tensor] = None, + past_key_value: Optional[Tuple[torch.Tensor]] = None, + attention_mask: Optional[torch.Tensor] = None, + layer_head_mask: Optional[torch.Tensor] = None, + output_attentions: bool = False, + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + """Input shape: Batch x Time x Channel""" + + # if key_value_states are provided this layer is used as a cross-attention layer + # for the decoder + is_cross_attention = key_value_states is not None + + bsz, tgt_len, _ = hidden_states.size() + + # get query proj + query_states = self.q_proj(hidden_states) * self.scaling + # get key, value proj + if is_cross_attention and past_key_value is not None: + # reuse k,v, cross_attentions + key_states = past_key_value[0] + 
value_states = past_key_value[1] + elif is_cross_attention: + # cross_attentions + key_states = self._shape(self.k_proj(key_value_states), -1, bsz) + value_states = self._shape(self.v_proj(key_value_states), -1, bsz) + elif past_key_value is not None: + # reuse k, v, self_attention + key_states = self._shape(self.k_proj(hidden_states), -1, bsz) + value_states = self._shape(self.v_proj(hidden_states), -1, bsz) + key_states = torch.cat([past_key_value[0], key_states], dim=2) + value_states = torch.cat([past_key_value[1], value_states], dim=2) + else: + # self_attention + key_states = self._shape(self.k_proj(hidden_states), -1, bsz) + value_states = self._shape(self.v_proj(hidden_states), -1, bsz) + + if self.is_decoder: + # if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states. + # Further calls to cross_attention layer can then reuse all cross-attention + # key/value_states (first "if" case) + # if uni-directional self-attention (decoder) save Tuple(torch.Tensor, torch.Tensor) of + # all previous decoder key/value_states. Further calls to uni-directional self-attention + # can concat previous decoder key/value_states to current projected key/value_states (third "elif" case) + # if encoder bi-directional self-attention `past_key_value` is always `None` + past_key_value = (key_states, value_states) + + proj_shape = (bsz * self.num_heads, -1, self.head_dim) + query_states = self._shape(query_states, tgt_len, bsz).view(*proj_shape) + key_states = key_states.view(*proj_shape) + value_states = value_states.view(*proj_shape) + + src_len = key_states.size(1) + attn_weights = torch.bmm(query_states, key_states.transpose(1, 2)) + + if attn_weights.size() != (bsz * self.num_heads, tgt_len, src_len): + raise ValueError( + f"Attention weights should be of size {(bsz * self.num_heads, tgt_len, src_len)}, but is" + f" {attn_weights.size()}" + ) + + if attention_mask is not None: + if attention_mask.size() != (bsz, 1, tgt_len, src_len): + raise ValueError( + f"Attention mask should be of size {(bsz, 1, tgt_len, src_len)}, but is {attention_mask.size()}" + ) + attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len) + attention_mask + attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len) + + attn_weights = nn.functional.softmax(attn_weights, dim=-1) + + if layer_head_mask is not None: + if layer_head_mask.size() != (self.num_heads,): + raise ValueError( + f"Head mask for a single layer should be of size {(self.num_heads,)}, but is" + f" {layer_head_mask.size()}" + ) + attn_weights = layer_head_mask.view(1, -1, 1, 1) * attn_weights.view(bsz, self.num_heads, tgt_len, src_len) + attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len) + + if output_attentions: + # this operation is a bit awkward, but it's required to + # make sure that attn_weights keeps its gradient. 
+ # In order to do so, attn_weights have to be reshaped + # twice and have to be reused in the following + attn_weights_reshaped = attn_weights.view(bsz, self.num_heads, tgt_len, src_len) + attn_weights = attn_weights_reshaped.view(bsz * self.num_heads, tgt_len, src_len) + else: + attn_weights_reshaped = None + + attn_probs = nn.functional.dropout(attn_weights, p=self.dropout, training=self.training) + + attn_output = torch.bmm(attn_probs, value_states) + + if attn_output.size() != (bsz * self.num_heads, tgt_len, self.head_dim): + raise ValueError( + f"`attn_output` should be of size {(bsz, self.num_heads, tgt_len, self.head_dim)}, but is" + f" {attn_output.size()}" + ) + + attn_output = attn_output.view(bsz, self.num_heads, tgt_len, self.head_dim) + attn_output = attn_output.transpose(1, 2) + + # Use the `embed_dim` from the config (stored in the class) rather than `hidden_state` because `attn_output` can be + # partitioned aross GPUs when using tensor-parallelism. + attn_output = attn_output.reshape(bsz, tgt_len, self.inner_dim) + + attn_output = self.out_proj(attn_output) + + return attn_output, attn_weights_reshaped, past_key_value + + +class LDMBertEncoderLayer(nn.Module): + def __init__(self, config: LDMBertConfig): + super().__init__() + self.embed_dim = config.d_model + self.self_attn = LDMBertAttention( + embed_dim=self.embed_dim, + num_heads=config.encoder_attention_heads, + head_dim=config.head_dim, + dropout=config.attention_dropout, + ) + self.self_attn_layer_norm = nn.LayerNorm(self.embed_dim) + self.dropout = config.dropout + self.activation_fn = ACT2FN[config.activation_function] + self.activation_dropout = config.activation_dropout + self.fc1 = nn.Linear(self.embed_dim, config.encoder_ffn_dim) + self.fc2 = nn.Linear(config.encoder_ffn_dim, self.embed_dim) + self.final_layer_norm = nn.LayerNorm(self.embed_dim) + + def forward( + self, + hidden_states: torch.FloatTensor, + attention_mask: torch.FloatTensor, + layer_head_mask: torch.FloatTensor, + output_attentions: Optional[bool] = False, + ) -> Tuple[torch.FloatTensor, Optional[torch.FloatTensor]]: + """ + Args: + hidden_states (`torch.FloatTensor`): input to the layer of shape `(seq_len, batch, embed_dim)` + attention_mask (`torch.FloatTensor`): attention mask of size + `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values. + layer_head_mask (`torch.FloatTensor`): mask for attention heads in a given layer of size + `(encoder_attention_heads,)`. + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under + returned tensors for more detail. 
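The attention block above projects `embed_dim` to `num_heads * head_dim` (which, unlike vanilla BERT, is not forced to equal `embed_dim`) and runs the batched matmul over a flattened `(bsz * num_heads)` leading dimension. A minimal shape-bookkeeping sketch with made-up toy dimensions, not the real LDMBert configuration:

```py
import torch

bsz, tgt_len, embed_dim, num_heads, head_dim = 2, 7, 32, 4, 16
inner_dim = num_heads * head_dim  # 64, deliberately different from embed_dim

q_proj = torch.nn.Linear(embed_dim, inner_dim, bias=False)
k_proj = torch.nn.Linear(embed_dim, inner_dim, bias=False)
v_proj = torch.nn.Linear(embed_dim, inner_dim, bias=False)
out_proj = torch.nn.Linear(inner_dim, embed_dim)

hidden = torch.randn(bsz, tgt_len, embed_dim)

def split_heads(x):
    # (bsz, seq, inner_dim) -> (bsz * num_heads, seq, head_dim)
    return x.view(bsz, -1, num_heads, head_dim).transpose(1, 2).reshape(bsz * num_heads, -1, head_dim)

q = split_heads(q_proj(hidden)) * head_dim**-0.5
k = split_heads(k_proj(hidden))
v = split_heads(v_proj(hidden))

attn = torch.softmax(torch.bmm(q, k.transpose(1, 2)), dim=-1)   # (bsz*heads, tgt, tgt)
out = torch.bmm(attn, v).view(bsz, num_heads, tgt_len, head_dim) # re-split batch and heads
out = out_proj(out.transpose(1, 2).reshape(bsz, tgt_len, inner_dim))
print(out.shape)  # torch.Size([2, 7, 32])
```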
+ """ + residual = hidden_states + hidden_states = self.self_attn_layer_norm(hidden_states) + hidden_states, attn_weights, _ = self.self_attn( + hidden_states=hidden_states, + attention_mask=attention_mask, + layer_head_mask=layer_head_mask, + output_attentions=output_attentions, + ) + hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training) + hidden_states = residual + hidden_states + + residual = hidden_states + hidden_states = self.final_layer_norm(hidden_states) + hidden_states = self.activation_fn(self.fc1(hidden_states)) + hidden_states = nn.functional.dropout(hidden_states, p=self.activation_dropout, training=self.training) + hidden_states = self.fc2(hidden_states) + hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training) + hidden_states = residual + hidden_states + + if hidden_states.dtype == torch.float16 and ( + torch.isinf(hidden_states).any() or torch.isnan(hidden_states).any() + ): + clamp_value = torch.finfo(hidden_states.dtype).max - 1000 + hidden_states = torch.clamp(hidden_states, min=-clamp_value, max=clamp_value) + + outputs = (hidden_states,) + + if output_attentions: + outputs += (attn_weights,) + + return outputs + + +# Copied from transformers.models.bart.modeling_bart.BartPretrainedModel with Bart->LDMBert +class LDMBertPreTrainedModel(PreTrainedModel): + config_class = LDMBertConfig + base_model_prefix = "model" + supports_gradient_checkpointing = True + _keys_to_ignore_on_load_unexpected = [r"encoder\.version", r"decoder\.version"] + + def _init_weights(self, module): + std = self.config.init_std + if isinstance(module, nn.Linear): + module.weight.data.normal_(mean=0.0, std=std) + if module.bias is not None: + module.bias.data.zero_() + elif isinstance(module, nn.Embedding): + module.weight.data.normal_(mean=0.0, std=std) + if module.padding_idx is not None: + module.weight.data[module.padding_idx].zero_() + + def _set_gradient_checkpointing(self, module, value=False): + if isinstance(module, (LDMBertEncoder,)): + module.gradient_checkpointing = value + + @property + def dummy_inputs(self): + pad_token = self.config.pad_token_id + input_ids = torch.tensor([[0, 6, 10, 4, 2], [0, 8, 12, 2, pad_token]], device=self.device) + dummy_inputs = { + "attention_mask": input_ids.ne(pad_token), + "input_ids": input_ids, + } + return dummy_inputs + + +class LDMBertEncoder(LDMBertPreTrainedModel): + """ + Transformer encoder consisting of *config.encoder_layers* self attention layers. Each layer is a + [`LDMBertEncoderLayer`]. 
+ + Args: + config: LDMBertConfig + embed_tokens (nn.Embedding): output embedding + """ + + def __init__(self, config: LDMBertConfig): + super().__init__(config) + + self.dropout = config.dropout + + embed_dim = config.d_model + self.padding_idx = config.pad_token_id + self.max_source_positions = config.max_position_embeddings + + self.embed_tokens = nn.Embedding(config.vocab_size, embed_dim) + self.embed_positions = nn.Embedding(config.max_position_embeddings, embed_dim) + self.layers = nn.ModuleList([LDMBertEncoderLayer(config) for _ in range(config.encoder_layers)]) + self.layer_norm = nn.LayerNorm(embed_dim) + + self.gradient_checkpointing = False + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self): + return self.embed_tokens + + def set_input_embeddings(self, value): + self.embed_tokens = value + + def forward( + self, + input_ids: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + head_mask: Optional[torch.Tensor] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple, BaseModelOutput]: + r""" + Args: + input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`): + Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you + provide it. + + Indices can be obtained using [`BartTokenizer`]. See [`PreTrainedTokenizer.encode`] and + [`PreTrainedTokenizer.__call__`] for details. + + [What are input IDs?](../glossary#input-ids) + attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*): + Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: + + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + + [What are attention masks?](../glossary#attention-mask) + head_mask (`torch.Tensor` of shape `(encoder_layers, encoder_attention_heads)`, *optional*): + Mask to nullify selected heads of the attention modules. Mask values selected in `[0, 1]`: + + - 1 indicates the head is **not masked**, + - 0 indicates the head is **masked**. + + inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*): + Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. + This is useful if you want more control over how to convert `input_ids` indices into associated vectors + than the model's internal embedding lookup matrix. + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under + returned tensors for more detail. + output_hidden_states (`bool`, *optional*): + Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors + for more detail. + return_dict (`bool`, *optional*): + Whether or not to return a [`~utils.BaseModelOutput`] instead of a plain tuple. 
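The `attention_mask` documented just above is a 0/1 padding mask; internally it is turned into an additive bias with large negative values at padded positions (see the `_expand_mask` helper earlier in this file). A self-contained illustration that re-implements the helper locally rather than importing it from the patched package:

```py
import torch

def expand_mask(mask, dtype=torch.float32):
    # (bsz, src_len) 0/1 mask -> (bsz, 1, src_len, src_len) additive attention bias
    bsz, src_len = mask.size()
    expanded = mask[:, None, None, :].expand(bsz, 1, src_len, src_len).to(dtype)
    inverted = 1.0 - expanded
    return inverted.masked_fill(inverted.bool(), torch.finfo(dtype).min)

mask = torch.tensor([[1, 1, 1, 0, 0]])   # last two tokens are padding
bias = expand_mask(mask)
print(bias[0, 0, 0])                      # approx [0, 0, 0, -3.4e38, -3.4e38]
```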
+ """ + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + # retrieve input_ids and inputs_embeds + if input_ids is not None and inputs_embeds is not None: + raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time") + elif input_ids is not None: + input_shape = input_ids.size() + input_ids = input_ids.view(-1, input_shape[-1]) + elif inputs_embeds is not None: + input_shape = inputs_embeds.size()[:-1] + else: + raise ValueError("You have to specify either input_ids or inputs_embeds") + + if inputs_embeds is None: + inputs_embeds = self.embed_tokens(input_ids) + + seq_len = input_shape[1] + if position_ids is None: + position_ids = torch.arange(seq_len, dtype=torch.long, device=inputs_embeds.device).expand((1, -1)) + embed_pos = self.embed_positions(position_ids) + + hidden_states = inputs_embeds + embed_pos + hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training) + + # expand attention_mask + if attention_mask is not None: + # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] + attention_mask = _expand_mask(attention_mask, inputs_embeds.dtype) + + encoder_states = () if output_hidden_states else None + all_attentions = () if output_attentions else None + + # check if head_mask has a correct number of layers specified if desired + if head_mask is not None: + if head_mask.size()[0] != (len(self.layers)): + raise ValueError( + f"The head_mask should be specified for {len(self.layers)} layers, but it is for" + f" {head_mask.size()[0]}." 
+ ) + + for idx, encoder_layer in enumerate(self.layers): + if output_hidden_states: + encoder_states = encoder_states + (hidden_states,) + if self.gradient_checkpointing and self.training: + + def create_custom_forward(module): + def custom_forward(*inputs): + return module(*inputs, output_attentions) + + return custom_forward + + layer_outputs = torch.utils.checkpoint.checkpoint( + create_custom_forward(encoder_layer), + hidden_states, + attention_mask, + (head_mask[idx] if head_mask is not None else None), + ) + else: + layer_outputs = encoder_layer( + hidden_states, + attention_mask, + layer_head_mask=(head_mask[idx] if head_mask is not None else None), + output_attentions=output_attentions, + ) + + hidden_states = layer_outputs[0] + + if output_attentions: + all_attentions = all_attentions + (layer_outputs[1],) + + hidden_states = self.layer_norm(hidden_states) + + if output_hidden_states: + encoder_states = encoder_states + (hidden_states,) + + if not return_dict: + return tuple(v for v in [hidden_states, encoder_states, all_attentions] if v is not None) + return BaseModelOutput( + last_hidden_state=hidden_states, hidden_states=encoder_states, attentions=all_attentions + ) + + +class LDMBertModel(LDMBertPreTrainedModel): + def __init__(self, config: LDMBertConfig): + super().__init__(config) + self.model = LDMBertEncoder(config) + self.to_logits = nn.Linear(config.hidden_size, config.vocab_size) + + def forward( + self, + input_ids=None, + attention_mask=None, + position_ids=None, + head_mask=None, + inputs_embeds=None, + output_attentions=None, + output_hidden_states=None, + return_dict=None, + ): + + outputs = self.model( + input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + head_mask=head_mask, + inputs_embeds=inputs_embeds, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + return outputs diff --git a/diffusers/pipelines/latent_diffusion_uncond/__init__.py b/diffusers/pipelines/latent_diffusion_uncond/__init__.py new file mode 100644 index 000000000..0826ca753 --- /dev/null +++ b/diffusers/pipelines/latent_diffusion_uncond/__init__.py @@ -0,0 +1,2 @@ +# flake8: noqa +from .pipeline_latent_diffusion_uncond import LDMPipeline diff --git a/diffusers/pipelines/latent_diffusion_uncond/pipeline_latent_diffusion_uncond.py b/diffusers/pipelines/latent_diffusion_uncond/pipeline_latent_diffusion_uncond.py new file mode 100644 index 000000000..4979d88fe --- /dev/null +++ b/diffusers/pipelines/latent_diffusion_uncond/pipeline_latent_diffusion_uncond.py @@ -0,0 +1,108 @@ +import inspect +import warnings +from typing import Optional, Tuple, Union + +import torch + +from ...models import UNet2DModel, VQModel +from ...pipeline_utils import DiffusionPipeline, ImagePipelineOutput +from ...schedulers import DDIMScheduler + + +class LDMPipeline(DiffusionPipeline): + r""" + This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the + library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.) + + Parameters: + vqvae ([`VQModel`]): + Vector-quantized (VQ) Model to encode and decode images to and from latent representations. + unet ([`UNet2DModel`]): U-Net architecture to denoise the encoded image latents. + scheduler ([`SchedulerMixin`]): + [`DDIMScheduler`] is to be used in combination with `unet` to denoise the encoded image latens. 
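A short usage sketch for this unconditional latent-diffusion pipeline. The checkpoint id is an assumption (any Hub repo whose `model_index.json` resolves to `LDMPipeline` should work), and the step count is reduced for a quick smoke test:

```py
import torch
from diffusers import LDMPipeline

pipe = LDMPipeline.from_pretrained("CompVis/ldm-celebahq-256")  # assumed checkpoint id
pipe = pipe.to("cuda" if torch.cuda.is_available() else "cpu")

# far fewer steps than the default, just to exercise the denoise + VQ-decode path
images = pipe(batch_size=1, num_inference_steps=20, eta=0.0).images
images[0].save("ldm_uncond_sample.png")
```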
+ """ + + def __init__(self, vqvae: VQModel, unet: UNet2DModel, scheduler: DDIMScheduler): + super().__init__() + scheduler = scheduler.set_format("pt") + self.register_modules(vqvae=vqvae, unet=unet, scheduler=scheduler) + + @torch.no_grad() + def __call__( + self, + batch_size: int = 1, + generator: Optional[torch.Generator] = None, + eta: float = 0.0, + num_inference_steps: int = 50, + output_type: Optional[str] = "pil", + return_dict: bool = True, + **kwargs, + ) -> Union[Tuple, ImagePipelineOutput]: + + r""" + Args: + batch_size (`int`, *optional*, defaults to 1): + Number of images to generate. + generator (`torch.Generator`, *optional*): + A [torch generator](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make generation + deterministic. + num_inference_steps (`int`, *optional*, defaults to 50): + The number of denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. + output_type (`str`, *optional*, defaults to `"pil"`): + The output format of the generate image. Choose between + [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `nd.array`. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~pipeline_utils.ImagePipelineOutput`] instead of a plain tuple. + + Returns: + [`~pipeline_utils.ImagePipelineOutput`] or `tuple`: [`~pipelines.utils.ImagePipelineOutput`] if + `return_dict` is True, otherwise a `tuple. When returning a tuple, the first element is a list with the + generated images. + """ + + if "torch_device" in kwargs: + device = kwargs.pop("torch_device") + warnings.warn( + "`torch_device` is deprecated as an input argument to `__call__` and will be removed in v0.3.0." + " Consider using `pipe.to(torch_device)` instead." 
+ ) + + # Set device as before (to be removed in 0.3.0) + if device is None: + device = "cuda" if torch.cuda.is_available() else "cpu" + self.to(device) + + latents = torch.randn( + (batch_size, self.unet.in_channels, self.unet.sample_size, self.unet.sample_size), + generator=generator, + ) + latents = latents.to(self.device) + + self.scheduler.set_timesteps(num_inference_steps) + + # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature + accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys()) + + extra_kwargs = {} + if accepts_eta: + extra_kwargs["eta"] = eta + + for t in self.progress_bar(self.scheduler.timesteps): + # predict the noise residual + noise_prediction = self.unet(latents, t).sample + # compute the previous noisy sample x_t -> x_t-1 + latents = self.scheduler.step(noise_prediction, t, latents, **extra_kwargs).prev_sample + + # decode the image latents with the VAE + image = self.vqvae.decode(latents).sample + + image = (image / 2 + 0.5).clamp(0, 1) + image = image.cpu().permute(0, 2, 3, 1).numpy() + if output_type == "pil": + image = self.numpy_to_pil(image) + + if not return_dict: + return (image,) + + return ImagePipelineOutput(images=image) diff --git a/diffusers/pipelines/pndm/__init__.py b/diffusers/pipelines/pndm/__init__.py new file mode 100644 index 000000000..6fc46aaab --- /dev/null +++ b/diffusers/pipelines/pndm/__init__.py @@ -0,0 +1,2 @@ +# flake8: noqa +from .pipeline_pndm import PNDMPipeline diff --git a/diffusers/pipelines/pndm/pipeline_pndm.py b/diffusers/pipelines/pndm/pipeline_pndm.py new file mode 100644 index 000000000..f3dff1a9a --- /dev/null +++ b/diffusers/pipelines/pndm/pipeline_pndm.py @@ -0,0 +1,111 @@ +# Copyright 2022 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and + +# limitations under the License. + + +import warnings +from typing import Optional, Tuple, Union + +import torch + +from ...models import UNet2DModel +from ...pipeline_utils import DiffusionPipeline, ImagePipelineOutput +from ...schedulers import PNDMScheduler + + +class PNDMPipeline(DiffusionPipeline): + r""" + This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the + library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.) + + Parameters: + unet (`UNet2DModel`): U-Net architecture to denoise the encoded image latents. + scheduler ([`SchedulerMixin`]): + The `PNDMScheduler` to be used in combination with `unet` to denoise the encoded image. 
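Because the pipeline only composes a `unet` and a `scheduler` via `register_modules`, the same trained UNet can be re-wrapped with a different scheduler. A hedged sketch that borrows the UNet from a DDPM checkpoint (the repo id is an assumption) and samples with PNDM instead:

```py
import torch
from diffusers import DDPMPipeline, PNDMPipeline, PNDMScheduler

# assumed checkpoint id; any DDPM-style repo exposing a `unet` component should work
ddpm = DDPMPipeline.from_pretrained("google/ddpm-cifar10-32")

# reuse the trained UNet, but sample with the faster PNDM scheduler
pndm = PNDMPipeline(unet=ddpm.unet, scheduler=PNDMScheduler())
pndm.to("cuda" if torch.cuda.is_available() else "cpu")

image = pndm(batch_size=1, num_inference_steps=50).images[0]
image.save("pndm_cifar10.png")
```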
+ """ + + unet: UNet2DModel + scheduler: PNDMScheduler + + def __init__(self, unet: UNet2DModel, scheduler: PNDMScheduler): + super().__init__() + scheduler = scheduler.set_format("pt") + self.register_modules(unet=unet, scheduler=scheduler) + + @torch.no_grad() + def __call__( + self, + batch_size: int = 1, + num_inference_steps: int = 50, + generator: Optional[torch.Generator] = None, + output_type: Optional[str] = "pil", + return_dict: bool = True, + **kwargs, + ) -> Union[ImagePipelineOutput, Tuple]: + r""" + Args: + batch_size (`int`, `optional`, defaults to 1): The number of images to generate. + num_inference_steps (`int`, `optional`, defaults to 50): + The number of denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. + generator (`torch.Generator`, `optional`): A [torch + generator](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make generation + deterministic. + output_type (`str`, `optional`, defaults to `"pil"`): The output format of the generate image. Choose + between [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `nd.array`. + return_dict (`bool`, `optional`, defaults to `True`): Whether or not to return a + [`~pipeline_utils.ImagePipelineOutput`] instead of a plain tuple. + + Returns: + [`~pipeline_utils.ImagePipelineOutput`] or `tuple`: [`~pipelines.utils.ImagePipelineOutput`] if + `return_dict` is True, otherwise a `tuple. When returning a tuple, the first element is a list with the + generated images. + """ + # For more information on the sampling method you can take a look at Algorithm 2 of + # the official paper: https://arxiv.org/pdf/2202.09778.pdf + + if "torch_device" in kwargs: + device = kwargs.pop("torch_device") + warnings.warn( + "`torch_device` is deprecated as an input argument to `__call__` and will be removed in v0.3.0." + " Consider using `pipe.to(torch_device)` instead." 
+ ) + + # Set device as before (to be removed in 0.3.0) + if device is None: + device = "cuda" if torch.cuda.is_available() else "cpu" + self.to(device) + + # Sample gaussian noise to begin loop + image = torch.randn( + (batch_size, self.unet.in_channels, self.unet.sample_size, self.unet.sample_size), + generator=generator, + ) + image = image.to(self.device) + + self.scheduler.set_timesteps(num_inference_steps) + for t in self.progress_bar(self.scheduler.timesteps): + model_output = self.unet(image, t).sample + + image = self.scheduler.step(model_output, t, image).prev_sample + + image = (image / 2 + 0.5).clamp(0, 1) + image = image.cpu().permute(0, 2, 3, 1).numpy() + if output_type == "pil": + image = self.numpy_to_pil(image) + + if not return_dict: + return (image,) + + return ImagePipelineOutput(images=image) diff --git a/diffusers/pipelines/score_sde_ve/__init__.py b/diffusers/pipelines/score_sde_ve/__init__.py new file mode 100644 index 000000000..000d61f6e --- /dev/null +++ b/diffusers/pipelines/score_sde_ve/__init__.py @@ -0,0 +1,2 @@ +# flake8: noqa +from .pipeline_score_sde_ve import ScoreSdeVePipeline diff --git a/diffusers/pipelines/score_sde_ve/pipeline_score_sde_ve.py b/diffusers/pipelines/score_sde_ve/pipeline_score_sde_ve.py new file mode 100644 index 000000000..604e2b54c --- /dev/null +++ b/diffusers/pipelines/score_sde_ve/pipeline_score_sde_ve.py @@ -0,0 +1,101 @@ +#!/usr/bin/env python3 +import warnings +from typing import Optional, Tuple, Union + +import torch + +from ...models import UNet2DModel +from ...pipeline_utils import DiffusionPipeline, ImagePipelineOutput +from ...schedulers import ScoreSdeVeScheduler + + +class ScoreSdeVePipeline(DiffusionPipeline): + r""" + Parameters: + This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the + library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.) + unet ([`UNet2DModel`]): U-Net architecture to denoise the encoded image. scheduler ([`SchedulerMixin`]): + The [`ScoreSdeVeScheduler`] scheduler to be used in combination with `unet` to denoise the encoded image. + """ + unet: UNet2DModel + scheduler: ScoreSdeVeScheduler + + def __init__(self, unet: UNet2DModel, scheduler: DiffusionPipeline): + super().__init__() + self.register_modules(unet=unet, scheduler=scheduler) + + @torch.no_grad() + def __call__( + self, + batch_size: int = 1, + num_inference_steps: int = 2000, + generator: Optional[torch.Generator] = None, + output_type: Optional[str] = "pil", + return_dict: bool = True, + **kwargs, + ) -> Union[ImagePipelineOutput, Tuple]: + r""" + Args: + batch_size (`int`, *optional*, defaults to 1): + The number of images to generate. + generator (`torch.Generator`, *optional*): + A [torch generator](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make generation + deterministic. + output_type (`str`, *optional*, defaults to `"pil"`): + The output format of the generate image. Choose between + [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `nd.array`. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~pipeline_utils.ImagePipelineOutput`] instead of a plain tuple. + + Returns: + [`~pipeline_utils.ImagePipelineOutput`] or `tuple`: [`~pipelines.utils.ImagePipelineOutput`] if + `return_dict` is True, otherwise a `tuple. When returning a tuple, the first element is a list with the + generated images. 
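A usage sketch for the predictor-corrector sampler above, also exercising `set_progress_bar_config` and `numpy_to_pil` from the base class; the checkpoint id is an assumption and the step count is reduced from the 2000-step default:

```py
import torch
from diffusers import ScoreSdeVePipeline

pipe = ScoreSdeVePipeline.from_pretrained("google/ncsnpp-church-256")  # assumed checkpoint id
pipe.to("cuda" if torch.cuda.is_available() else "cpu")
pipe.set_progress_bar_config(desc="SDE-VE sampling")  # kwargs are forwarded to tqdm

# request raw arrays instead of PIL images, then convert manually
out = pipe(batch_size=1, num_inference_steps=200, output_type="numpy", return_dict=False)
images = pipe.numpy_to_pil(out[0])
images[0].save("sde_ve_sample.png")
```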
+ """ + + if "torch_device" in kwargs: + device = kwargs.pop("torch_device") + warnings.warn( + "`torch_device` is deprecated as an input argument to `__call__` and will be removed in v0.3.0." + " Consider using `pipe.to(torch_device)` instead." + ) + + # Set device as before (to be removed in 0.3.0) + if device is None: + device = "cuda" if torch.cuda.is_available() else "cpu" + self.to(device) + + img_size = self.unet.config.sample_size + shape = (batch_size, 3, img_size, img_size) + + model = self.unet + + sample = torch.randn(*shape, generator=generator) * self.scheduler.config.sigma_max + sample = sample.to(self.device) + + self.scheduler.set_timesteps(num_inference_steps) + self.scheduler.set_sigmas(num_inference_steps) + + for i, t in enumerate(self.progress_bar(self.scheduler.timesteps)): + sigma_t = self.scheduler.sigmas[i] * torch.ones(shape[0], device=self.device) + + # correction step + for _ in range(self.scheduler.correct_steps): + model_output = self.unet(sample, sigma_t).sample + sample = self.scheduler.step_correct(model_output, sample, generator=generator).prev_sample + + # prediction step + model_output = model(sample, sigma_t).sample + output = self.scheduler.step_pred(model_output, t, sample, generator=generator) + + sample, sample_mean = output.prev_sample, output.prev_sample_mean + + sample = sample_mean.clamp(0, 1) + sample = sample.cpu().permute(0, 2, 3, 1).numpy() + if output_type == "pil": + sample = self.numpy_to_pil(sample) + + if not return_dict: + return (sample,) + + return ImagePipelineOutput(images=sample) diff --git a/diffusers/pipelines/stable_diffusion/__init__.py b/diffusers/pipelines/stable_diffusion/__init__.py new file mode 100644 index 000000000..5ffda93f1 --- /dev/null +++ b/diffusers/pipelines/stable_diffusion/__init__.py @@ -0,0 +1,37 @@ +from dataclasses import dataclass +from typing import List, Union + +import numpy as np + +import PIL +from PIL import Image + +from ...utils import BaseOutput, is_onnx_available, is_transformers_available + + +@dataclass +class StableDiffusionPipelineOutput(BaseOutput): + """ + Output class for Stable Diffusion pipelines. + + Args: + images (`List[PIL.Image.Image]` or `np.ndarray`) + List of denoised PIL images of length `batch_size` or numpy array of shape `(batch_size, height, width, + num_channels)`. PIL images or numpy array present the denoised images of the diffusion pipeline. + nsfw_content_detected (`List[bool]`) + List of flags denoting whether the corresponding generated image likely represents "not-safe-for-work" + (nsfw) content. + """ + + images: Union[List[PIL.Image.Image], np.ndarray] + nsfw_content_detected: List[bool] + + +if is_transformers_available(): + from .pipeline_stable_diffusion import StableDiffusionPipeline + from .pipeline_stable_diffusion_img2img import StableDiffusionImg2ImgPipeline + from .pipeline_stable_diffusion_inpaint import StableDiffusionInpaintPipeline + from .safety_checker import StableDiffusionSafetyChecker + +if is_transformers_available() and is_onnx_available(): + from .pipeline_stable_diffusion_onnx import StableDiffusionOnnxPipeline diff --git a/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py b/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py new file mode 100644 index 000000000..e8e076137 --- /dev/null +++ b/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py @@ -0,0 +1,398 @@ +# Modification of the original file by O. Teytaud for facilitating genetic stable diffusion. 
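The `StableDiffusionPipelineOutput` defined above pairs each generated image with a safety flag; a hedged sketch of how a caller might consume both fields (the prompt is a placeholder, and the gated checkpoint assumes a prior `huggingface-cli login`):

```py
import torch
from diffusers import StableDiffusionPipeline

pipe = StableDiffusionPipeline.from_pretrained(
    "CompVis/stable-diffusion-v1-4", use_auth_token=True  # gated repo: token from `huggingface-cli login`
)
pipe.to("cuda" if torch.cuda.is_available() else "cpu")

out = pipe("a watercolor painting of a lighthouse", guidance_scale=7.5)
for i, (img, flagged) in enumerate(zip(out.images, out.nsfw_content_detected)):
    if not flagged:          # skip anything the safety checker objected to
        img.save(f"sd_sample_{i}.png")
```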
+ +import inspect +import os +import numpy as np +import random +import warnings +from typing import List, Optional, Union + +import torch + +from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer + +from ...models import AutoencoderKL, UNet2DConditionModel +from ...pipeline_utils import DiffusionPipeline +from ...schedulers import DDIMScheduler, LMSDiscreteScheduler, PNDMScheduler +from . import StableDiffusionPipelineOutput +from .safety_checker import StableDiffusionSafetyChecker + + +class StableDiffusionPipeline(DiffusionPipeline): + r""" + Pipeline for text-to-image generation using Stable Diffusion. + + This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the + library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.) + + Args: + vae ([`AutoencoderKL`]): + Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations. + text_encoder ([`CLIPTextModel`]): + Frozen text-encoder. Stable Diffusion uses the text portion of + [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel), specifically + the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant. + tokenizer (`CLIPTokenizer`): + Tokenizer of class + [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer). + unet ([`UNet2DConditionModel`]): Conditional U-Net architecture to denoise the encoded image latents. + scheduler ([`SchedulerMixin`]): + A scheduler to be used in combination with `unet` to denoise the encoded image latens. Can be one of + [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`]. + safety_checker ([`StableDiffusionSafetyChecker`]): + Classification module that estimates whether generated images could be considered offsensive or harmful. + Please, refer to the [model card](https://huggingface.co/CompVis/stable-diffusion-v1-4) for details. + feature_extractor ([`CLIPFeatureExtractor`]): + Model that extracts features from generated images to be used as inputs for the `safety_checker`. + """ + + def __init__( + self, + vae: AutoencoderKL, + text_encoder: CLIPTextModel, + tokenizer: CLIPTokenizer, + unet: UNet2DConditionModel, + scheduler: Union[DDIMScheduler, PNDMScheduler, LMSDiscreteScheduler], + safety_checker: StableDiffusionSafetyChecker, + feature_extractor: CLIPFeatureExtractor, + ): + super().__init__() + scheduler = scheduler.set_format("pt") + self.register_modules( + vae=vae, + text_encoder=text_encoder, + tokenizer=tokenizer, + unet=unet, + scheduler=scheduler, + safety_checker=safety_checker, + feature_extractor=feature_extractor, + ) + + def enable_attention_slicing(self, slice_size: Optional[Union[str, int]] = "auto"): + r""" + Enable sliced attention computation. + + When this option is enabled, the attention module will split the input tensor in slices, to compute attention + in several steps. This is useful to save some memory in exchange for a small speed decrease. + + Args: + slice_size (`str` or `int`, *optional*, defaults to `"auto"`): + When `"auto"`, halves the input to the attention heads, so attention will be computed in two steps. If + a number is provided, uses as many slices as `attention_head_dim // slice_size`. In this case, + `attention_head_dim` must be a multiple of `slice_size`. 
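A brief sketch of toggling sliced attention on a loaded pipeline; the checkpoint and prompt are placeholders, and `"auto"` computes attention in two slices as described above:

```py
import torch
from diffusers import StableDiffusionPipeline

pipe = StableDiffusionPipeline.from_pretrained("CompVis/stable-diffusion-v1-4", use_auth_token=True)
pipe.to("cuda" if torch.cuda.is_available() else "cpu")

pipe.enable_attention_slicing("auto")  # lower peak memory, slightly slower per step
image = pipe("an isometric render of a tiny island", num_inference_steps=30).images[0]
image.save("sd_sliced_attention.png")

pipe.disable_attention_slicing()       # back to single-pass attention
```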
+ """ + if slice_size == "auto": + # half the attention head size is usually a good trade-off between + # speed and memory + slice_size = self.unet.config.attention_head_dim // 2 + self.unet.set_attention_slice(slice_size) + + def disable_attention_slicing(self): + r""" + Disable sliced attention computation. If `enable_attention_slicing` was previously invoked, this method will go + back to computing attention in one step. + """ + # set slice_size = `None` to disable `attention slicing` + self.enable_attention_slicing(None) + +# def get_latent(self, image): +# return self.vae.encode(image) + + @torch.no_grad() + def __call__( + self, + prompt: Union[str, List[str]], + height: Optional[int] = 512, + width: Optional[int] = 512, + num_inference_steps: Optional[int] = 50, + guidance_scale: Optional[float] = 7.5, + eta: Optional[float] = 0.0, + generator: Optional[torch.Generator] = None, + latents: Optional[torch.FloatTensor] = None, + output_type: Optional[str] = "pil", + return_dict: bool = True, + **kwargs, + ): + r""" + Function invoked when calling the pipeline for generation. + + Args: + prompt (`str` or `List[str]`): + The prompt or prompts to guide the image generation. + height (`int`, *optional*, defaults to 512): + The height in pixels of the generated image. + width (`int`, *optional*, defaults to 512): + The width in pixels of the generated image. + num_inference_steps (`int`, *optional*, defaults to 50): + The number of denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. + guidance_scale (`float`, *optional*, defaults to 7.5): + Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598). + `guidance_scale` is defined as `w` of equation 2. of [Imagen + Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale > + 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`, + usually at the expense of lower image quality. + eta (`float`, *optional*, defaults to 0.0): + Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to + [`schedulers.DDIMScheduler`], will be ignored for others. + generator (`torch.Generator`, *optional*): + A [torch generator](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make generation + deterministic. + latents (`torch.FloatTensor`, *optional*): + Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image + generation. Can be used to tweak the same generation with different prompts. If not provided, a latents + tensor will ge generated by sampling using the supplied random `generator`. + output_type (`str`, *optional*, defaults to `"pil"`): + The output format of the generate image. Choose between + [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `nd.array`. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a + plain tuple. + + Returns: + [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`: + [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] if `return_dict` is True, otherwise a `tuple. 
+ When returning a tuple, the first element is a list with the generated images, and the second element is a + list of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work" + (nsfw) content, according to the `safety_checker`. + """ + + if "torch_device" in kwargs: + device = kwargs.pop("torch_device") + warnings.warn( + "`torch_device` is deprecated as an input argument to `__call__` and will be removed in v0.3.0." + " Consider using `pipe.to(torch_device)` instead." + ) + + # Set device as before (to be removed in 0.3.0) + if device is None: + device = "cuda" if torch.cuda.is_available() else "cpu" + self.to(device) + + if isinstance(prompt, str): + batch_size = 1 + elif isinstance(prompt, list): + batch_size = len(prompt) + else: + raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") + + if height % 8 != 0 or width % 8 != 0: + raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.") + + # get prompt text embeddings + text_input = self.tokenizer( + prompt, + padding="max_length", + max_length=self.tokenizer.model_max_length, + truncation=True, + return_tensors="pt", + ) + text_embeddings = self.text_encoder(text_input.input_ids.to(self.device))[0] + + # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) + # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1` + # corresponds to doing no classifier free guidance. + do_classifier_free_guidance = guidance_scale > 1.0 + # get unconditional embeddings for classifier free guidance + if do_classifier_free_guidance: + max_length = text_input.input_ids.shape[-1] + uncond_input = self.tokenizer( + [""] * batch_size, padding="max_length", max_length=max_length, return_tensors="pt" + ) + uncond_embeddings = self.text_encoder(uncond_input.input_ids.to(self.device))[0] + + # For classifier free guidance, we need to do two forward passes. + # Here we concatenate the unconditional and text embeddings into a single batch + # to avoid doing two forward passes + text_embeddings = torch.cat([uncond_embeddings, text_embeddings]) + + # get the initial random noise unless the user supplied it + + # Unlike in other pipelines, latents need to be generated in the target device + # for 1-to-1 results reproducibility with the CompVis implementation. + # However this currently doesn't work in `mps`. 
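+        # Overview of the evolutionary latent selection below: if the environment variable
+        # "forcedlatent" holds a flattened latent (as a Python list literal), it replaces the
+        # random draw. Otherwise, once some latents have been ranked "good" and more than ten
+        # ranked "bad" (environment variables "good" and "bad"), a scikit-learn classifier
+        # (decision tree, logistic regression or a small MLP, chosen via "skl") is fit on the
+        # flattened latents, and a nevergrad optimizer ("ngoptim", within "budget" evaluations)
+        # searches for a small perturbation of the current latent (step size "epsilon") that
+        # minimizes the predicted probability of being bad. The result is rescaled to the
+        # sphere of radius sqrt(d) and used as the latent for the denoising loop.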
+ latents_device = "cpu" if self.device.type == "mps" else self.device + latents_shape = (batch_size, self.unet.in_channels, height // 8, width // 8) + latents_intermediate_shape = (batch_size, self.unet.in_channels, height // 8, width // 8) + speedup = 1 + if latents is None: + latents = torch.randn( + latents_intermediate_shape, + generator=generator, + device=latents_device, + ) + if len(os.environ["forcedlatent"]) > 10: + stri = os.environ["forcedlatent"] + print(f"we get a forcing for the latent z: {stri[:20]}.") + if len(eval(stri)) == 1: + stri = str(eval(stri)[0]) + speedup = 1 + latents = np.array(list(eval(stri))).flatten() + #latents = latents + np.exp(0.1 * np.random.randn()) * np.random.rand(len(latents)) + #latents = np.sqrt(len(latents) / np.sum(latents ** 2)) * latents + #latents = np.sqrt(len(latents)) * latents / np.sqrt(np.sum(latents ** 2)) + print(f"As an array, this is {latents[:10]}") + print(f"immediately after loading latent ==> {sum(latents.flatten()**2) / len(latents.flatten())}") + latents = torch.from_numpy(latents.reshape((1,4,64,64))).float().to(latents_device) + os.environ["forcedlatent"] = "" + good = eval(os.environ["good"]) + bad = eval(os.environ["bad"]) + print(f"{len(good)} good and {len(bad)} bad") + i_believe_in_evolution = len(good) > 0 and len(bad) > 10 + print(f"I believe in evolution = {i_believe_in_evolution}") + if i_believe_in_evolution: + from sklearn import tree + from sklearn.neural_network import MLPClassifier + #from sklearn.neighbors import NearestCentroid + from sklearn.linear_model import LogisticRegression + #z = (np.random.randn(4*64*64)) + z = latents.cpu().numpy().flatten() + if os.environ.get("skl", "tree") == "tree": + clf = tree.DecisionTreeClassifier()#min_samples_split=0.1) + elif os.environ.get("skl", "tree") == "logit": + clf = LogisticRegression() + else: + clf = MLPClassifier(solver='lbfgs', alpha=1e-5, hidden_layer_sizes=(5, 2), random_state=1) + #clf = NearestCentroid() + + + + X=good + bad + Y = [1] * len(good) + [0] * len(bad) + clf = clf.fit(X,Y) + epsilon = 0.0001 # for astronauts + epsilon = 1.0 + + def loss(x): + return clf.predict_proba([x])[0][0] # for astronauts + #return clf.predict_proba([(1-epsilon)*z+epsilon*x])[0][0] # for astronauts + #return clf.predict_proba([z+epsilon*x])[0][0] + + + budget = int(os.environ.get("budget", "300")) + if i_believe_in_evolution and budget > 20: + import nevergrad as ng + #nevergrad_optimizer = ng.optimizers.RandomSearch(len(z), budget) + #nevergrad_optimizer = ng.optimizers.RandomSearch(len(z), budget) + optim_class = ng.optimizers.registry[os.environ.get("ngoptim", "DiscreteLenglerOnePlusOne")] + #nevergrad_optimizer = ng.optimizers.DiscreteLenglerOnePlusOne(len(z), budget) + nevergrad_optimizer = optim_class(len(z), budget) + #nevergrad_optimizer = ng.optimizers.DiscreteOnePlusOne(len(z), budget) +# for k in range(5): +# z1 = np.array(random.choice(good)) +# z2 = np.array(random.choice(good)) +# z3 = np.array(random.choice(good)) +# z4 = np.array(random.choice(good)) +# z5 = np.array(random.choice(good)) +# #z = 0.99 * z1 + 0.01 * (z2+z3+z4+z5)/4. +# z = 0.2 * (z1 + z2 + z3 + z4 + z5) +# mu = int(os.environ.get("mu", "5")) +# parents = [z1, z2, z3, z4, z5] +# weights = [np.exp(np.random.randn() - i * float(os.environ.get("decay", "1."))) for i in range(5)] +# z = weights[0] * z1 +# for u in range(mu): +# if u > 0: +# z += weights[u] * parents[u] +# z = (1. 
/ sum(weights[:mu])) * z +# z = np.sqrt(len(z)) * z / np.linalg.norm(z) +# +# #for u in range(len(z)): +# # z[u] = random.choice([z1[u],z2[u],z3[u],z4[u],z5[u]]) +# nevergrad_optimizer.suggest + if len(os.environ["forcedlatent"]) > 0: + print("we get a forcing for the latent z.") + z0 = eval(os.environ["forcedlatent"]) + #nevergrad_optimizer.suggest(eval(os.environ["forcedlatent"])) + else: + z0 = z + for i in range(budget): + x = nevergrad_optimizer.ask() + z = z0 + float(os.environ.get("epsilon", "0.001")) * x.value + z = np.sqrt(len(z)) * z / np.linalg.norm(z) + l = loss(z) + nevergrad_optimizer.tell(x, l) + if np.log2(i+1) == int(np.log2(i+1)): + print(f"iteration {i} --> {l}") + print("var/variable = ", sum(z**2)/len(z)) + #z = (1.-epsilon) * z + epsilon * x / np.sqrt(np.sum(x ** 2)) + if l < 0.0000001 and os.environ.get("earlystop", "False") in ["true", "True"]: + print(f"we find proba(bad)={l}") + break + x = nevergrad_optimizer.recommend().value + z = z0 + float(os.environ.get("epsilon", "0.001")) * x + z = np.sqrt(len(z)) * z / np.linalg.norm(z) + latents = torch.from_numpy(z.reshape(latents_intermediate_shape)).float() #.half() + else: + if latents.shape != latents_intermediate_shape: + raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {latents_intermediate_shape}") + print(f"latent ==> {sum(latents.flatten()**2) / len(latents.flatten())}") + print(f"latent ==> {torch.max(latents)}") + print(f"latent ==> {torch.min(latents)}") + os.environ["latent_sd"] = str(list(latents.flatten().cpu().numpy())) + for i in [2, 3]: + latents = torch.repeat_interleave(latents, repeats=latents_shape[i] // latents_intermediate_shape[i], dim=i) #/ np.sqrt(np.sqrt(latents_shape[i] // latents_intermediate_shape[i])) + latents = latents.float().to(self.device) + + # set timesteps + accepts_offset = "offset" in set(inspect.signature(self.scheduler.set_timesteps).parameters.keys()) + extra_set_kwargs = {} + if accepts_offset: + extra_set_kwargs["offset"] = 1 + + self.scheduler.set_timesteps(num_inference_steps // speedup, **extra_set_kwargs) + + # if we use LMSDiscreteScheduler, let's make sure latents are mulitplied by sigmas + if isinstance(self.scheduler, LMSDiscreteScheduler): + latents = latents * self.scheduler.sigmas[0] + + # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature + # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers. 
+ # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502 + # and should be between [0, 1] + accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys()) + extra_step_kwargs = {} + if accepts_eta: + extra_step_kwargs["eta"] = eta + + for i, t in enumerate(self.progress_bar(self.scheduler.timesteps)): + # expand the latents if we are doing classifier free guidance + latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents + if isinstance(self.scheduler, LMSDiscreteScheduler): + sigma = self.scheduler.sigmas[i] + # the model input needs to be scaled to match the continuous ODE formulation in K-LMS + latent_model_input = latent_model_input / ((sigma**2 + 1) ** 0.5) + + # predict the noise residual + print(f"text_embeddings.shape={text_embeddings.shape}") + noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=text_embeddings).sample + + # perform guidance + if do_classifier_free_guidance: + noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) + noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) + + # compute the previous noisy sample x_t -> x_t-1 + if isinstance(self.scheduler, LMSDiscreteScheduler): + latents = self.scheduler.step(noise_pred, i, latents, **extra_step_kwargs).prev_sample + else: + latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample + + # scale and decode the image latents with vae + #os.environ["latent_sd"] = str(list(latents.flatten().cpu().detach().numpy())) + latents = 1 / 0.18215 * latents + image = self.vae.decode(latents).sample + + image = (image / 2 + 0.5).clamp(0, 1) + image = image.cpu().permute(0, 2, 3, 1).numpy() + + # run safety checker + safety_cheker_input = self.feature_extractor(self.numpy_to_pil(image), return_tensors="pt").to(self.device) + image, has_nsfw_concept = self.safety_checker(images=image, clip_input=safety_cheker_input.pixel_values) + + if output_type == "pil": + image = self.numpy_to_pil(image) + + if not return_dict: + return (image, has_nsfw_concept) + + return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept) diff --git a/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py b/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py new file mode 100644 index 000000000..475ceef4f --- /dev/null +++ b/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py @@ -0,0 +1,291 @@ +import inspect +from typing import List, Optional, Union + +import numpy as np +import torch + +import PIL +from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer + +from ...models import AutoencoderKL, UNet2DConditionModel +from ...pipeline_utils import DiffusionPipeline +from ...schedulers import DDIMScheduler, LMSDiscreteScheduler, PNDMScheduler +from . import StableDiffusionPipelineOutput +from .safety_checker import StableDiffusionSafetyChecker + + +def preprocess(image): + w, h = image.size + w, h = map(lambda x: x - x % 32, (w, h)) # resize to integer multiple of 32 + image = image.resize((w, h), resample=PIL.Image.LANCZOS) + image = np.array(image).astype(np.float32) / 255.0 + image = image[None].transpose(0, 3, 1, 2) + image = torch.from_numpy(image) + return 2.0 * image - 1.0 + + +class StableDiffusionImg2ImgPipeline(DiffusionPipeline): + r""" + Pipeline for text-guided image to image generation using Stable Diffusion. + + This model inherits from [`DiffusionPipeline`]. 
Check the superclass documentation for the generic methods the + library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.) + + Args: + vae ([`AutoencoderKL`]): + Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations. + text_encoder ([`CLIPTextModel`]): + Frozen text-encoder. Stable Diffusion uses the text portion of + [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel), specifically + the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant. + tokenizer (`CLIPTokenizer`): + Tokenizer of class + [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer). + unet ([`UNet2DConditionModel`]): Conditional U-Net architecture to denoise the encoded image latents. + scheduler ([`SchedulerMixin`]): + A scheduler to be used in combination with `unet` to denoise the encoded image latens. Can be one of + [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`]. + safety_checker ([`StableDiffusionSafetyChecker`]): + Classification module that estimates whether generated images could be considered offsensive or harmful. + Please, refer to the [model card](https://huggingface.co/CompVis/stable-diffusion-v1-4) for details. + feature_extractor ([`CLIPFeatureExtractor`]): + Model that extracts features from generated images to be used as inputs for the `safety_checker`. + """ + + def __init__( + self, + vae: AutoencoderKL, + text_encoder: CLIPTextModel, + tokenizer: CLIPTokenizer, + unet: UNet2DConditionModel, + scheduler: Union[DDIMScheduler, PNDMScheduler, LMSDiscreteScheduler], + safety_checker: StableDiffusionSafetyChecker, + feature_extractor: CLIPFeatureExtractor, + ): + super().__init__() + scheduler = scheduler.set_format("pt") + self.register_modules( + vae=vae, + text_encoder=text_encoder, + tokenizer=tokenizer, + unet=unet, + scheduler=scheduler, + safety_checker=safety_checker, + feature_extractor=feature_extractor, + ) + + def enable_attention_slicing(self, slice_size: Optional[Union[str, int]] = "auto"): + r""" + Enable sliced attention computation. + + When this option is enabled, the attention module will split the input tensor in slices, to compute attention + in several steps. This is useful to save some memory in exchange for a small speed decrease. + + Args: + slice_size (`str` or `int`, *optional*, defaults to `"auto"`): + When `"auto"`, halves the input to the attention heads, so attention will be computed in two steps. If + a number is provided, uses as many slices as `attention_head_dim // slice_size`. In this case, + `attention_head_dim` must be a multiple of `slice_size`. + """ + if slice_size == "auto": + # half the attention head size is usually a good trade-off between + # speed and memory + slice_size = self.unet.config.attention_head_dim // 2 + self.unet.set_attention_slice(slice_size) + + def disable_attention_slicing(self): + r""" + Disable sliced attention computation. If `enable_attention_slicing` was previously invoked, this method will go + back to computing attention in one step. 
+ """ + # set slice_size = `None` to disable `set_attention_slice` + self.enable_attention_slice(None) + + @torch.no_grad() + def __call__( + self, + prompt: Union[str, List[str]], + init_image: Union[torch.FloatTensor, PIL.Image.Image], + strength: float = 0.8, + num_inference_steps: Optional[int] = 50, + guidance_scale: Optional[float] = 7.5, + eta: Optional[float] = 0.0, + generator: Optional[torch.Generator] = None, + output_type: Optional[str] = "pil", + return_dict: bool = True, + ): + r""" + Function invoked when calling the pipeline for generation. + + Args: + prompt (`str` or `List[str]`): + The prompt or prompts to guide the image generation. + init_image (`torch.FloatTensor` or `PIL.Image.Image`): + `Image`, or tensor representing an image batch, that will be used as the starting point for the + process. + strength (`float`, *optional*, defaults to 0.8): + Conceptually, indicates how much to transform the reference `init_image`. Must be between 0 and 1. + `init_image` will be used as a starting point, adding more noise to it the larger the `strength`. The + number of denoising steps depends on the amount of noise initially added. When `strength` is 1, added + noise will be maximum and the denoising process will run for the full number of iterations specified in + `num_inference_steps`. A value of 1, therefore, essentially ignores `init_image`. + num_inference_steps (`int`, *optional*, defaults to 50): + The number of denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. This parameter will be modulated by `strength`. + guidance_scale (`float`, *optional*, defaults to 7.5): + Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598). + `guidance_scale` is defined as `w` of equation 2. of [Imagen + Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale > + 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`, + usually at the expense of lower image quality. + eta (`float`, *optional*, defaults to 0.0): + Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to + [`schedulers.DDIMScheduler`], will be ignored for others. + generator (`torch.Generator`, *optional*): + A [torch generator](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make generation + deterministic. + output_type (`str`, *optional*, defaults to `"pil"`): + The output format of the generate image. Choose between + [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `nd.array`. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a + plain tuple. + + Returns: + [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`: + [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] if `return_dict` is True, otherwise a `tuple. + When returning a tuple, the first element is a list with the generated images, and the second element is a + list of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work" + (nsfw) content, according to the `safety_checker`. 
+ """ + if isinstance(prompt, str): + batch_size = 1 + elif isinstance(prompt, list): + batch_size = len(prompt) + else: + raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") + + if strength < 0 or strength > 1: + raise ValueError(f"The value of strength should in [0.0, 1.0] but is {strength}") + + # set timesteps + accepts_offset = "offset" in set(inspect.signature(self.scheduler.set_timesteps).parameters.keys()) + extra_set_kwargs = {} + offset = 0 + if accepts_offset: + offset = 1 + extra_set_kwargs["offset"] = 1 + + self.scheduler.set_timesteps(num_inference_steps, **extra_set_kwargs) + + if not isinstance(init_image, torch.FloatTensor): + init_image = preprocess(init_image) + + # encode the init image into latents and scale the latents + init_latent_dist = self.vae.encode(init_image.to(self.device)).latent_dist + init_latents = init_latent_dist.sample(generator=generator) + init_latents = 0.18215 * init_latents + + # expand init_latents for batch_size + init_latents = torch.cat([init_latents] * batch_size) + + # get the original timestep using init_timestep + init_timestep = int(num_inference_steps * strength) + offset + init_timestep = min(init_timestep, num_inference_steps) + if isinstance(self.scheduler, LMSDiscreteScheduler): + timesteps = torch.tensor( + [num_inference_steps - init_timestep] * batch_size, dtype=torch.long, device=self.device + ) + else: + timesteps = self.scheduler.timesteps[-init_timestep] + timesteps = torch.tensor([timesteps] * batch_size, dtype=torch.long, device=self.device) + + # add noise to latents using the timesteps + noise = torch.randn(init_latents.shape, generator=generator, device=self.device) + init_latents = self.scheduler.add_noise(init_latents, noise, timesteps).to(self.device) + + # get prompt text embeddings + text_input = self.tokenizer( + prompt, + padding="max_length", + max_length=self.tokenizer.model_max_length, + truncation=True, + return_tensors="pt", + ) + text_embeddings = self.text_encoder(text_input.input_ids.to(self.device))[0] + + # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) + # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1` + # corresponds to doing no classifier free guidance. + do_classifier_free_guidance = guidance_scale > 1.0 + # get unconditional embeddings for classifier free guidance + if do_classifier_free_guidance: + max_length = text_input.input_ids.shape[-1] + uncond_input = self.tokenizer( + [""] * batch_size, padding="max_length", max_length=max_length, return_tensors="pt" + ) + uncond_embeddings = self.text_encoder(uncond_input.input_ids.to(self.device))[0] + + # For classifier free guidance, we need to do two forward passes. + # Here we concatenate the unconditional and text embeddings into a single batch + # to avoid doing two forward passes + text_embeddings = torch.cat([uncond_embeddings, text_embeddings]) + + # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature + # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers. 
+ # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502 + # and should be between [0, 1] + accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys()) + extra_step_kwargs = {} + if accepts_eta: + extra_step_kwargs["eta"] = eta + + latents = init_latents + + t_start = max(num_inference_steps - init_timestep + offset, 0) + for i, t in enumerate(self.progress_bar(self.scheduler.timesteps[t_start:])): + t_index = t_start + i + + # expand the latents if we are doing classifier free guidance + latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents + + # if we use LMSDiscreteScheduler, let's make sure latents are mulitplied by sigmas + if isinstance(self.scheduler, LMSDiscreteScheduler): + sigma = self.scheduler.sigmas[t_index] + # the model input needs to be scaled to match the continuous ODE formulation in K-LMS + latent_model_input = latent_model_input / ((sigma**2 + 1) ** 0.5) + latent_model_input = latent_model_input.to(self.unet.dtype) + t = t.to(self.unet.dtype) + + # predict the noise residual + noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=text_embeddings).sample + + # perform guidance + if do_classifier_free_guidance: + noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) + noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) + + # compute the previous noisy sample x_t -> x_t-1 + if isinstance(self.scheduler, LMSDiscreteScheduler): + latents = self.scheduler.step(noise_pred, t_index, latents, **extra_step_kwargs).prev_sample + else: + latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample + + # scale and decode the image latents with vae + latents = 1 / 0.18215 * latents + image = self.vae.decode(latents.to(self.vae.dtype)).sample + + image = (image / 2 + 0.5).clamp(0, 1) + image = image.cpu().permute(0, 2, 3, 1).numpy() + + # run safety checker + safety_cheker_input = self.feature_extractor(self.numpy_to_pil(image), return_tensors="pt").to(self.device) + image, has_nsfw_concept = self.safety_checker(images=image, clip_input=safety_cheker_input.pixel_values) + + if output_type == "pil": + image = self.numpy_to_pil(image) + + if not return_dict: + return (image, has_nsfw_concept) + + return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept) diff --git a/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py b/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py new file mode 100644 index 000000000..05ea84ae0 --- /dev/null +++ b/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py @@ -0,0 +1,309 @@ +import inspect +from typing import List, Optional, Union + +import numpy as np +import torch + +import PIL +from tqdm.auto import tqdm +from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer + +from ...models import AutoencoderKL, UNet2DConditionModel +from ...pipeline_utils import DiffusionPipeline +from ...schedulers import DDIMScheduler, PNDMScheduler +from ...utils import logging +from . 
import StableDiffusionPipelineOutput +from .safety_checker import StableDiffusionSafetyChecker + + +logger = logging.get_logger(__name__) + + +def preprocess_image(image): + w, h = image.size + w, h = map(lambda x: x - x % 32, (w, h)) # resize to integer multiple of 32 + image = image.resize((w, h), resample=PIL.Image.LANCZOS) + image = np.array(image).astype(np.float32) / 255.0 + image = image[None].transpose(0, 3, 1, 2) + image = torch.from_numpy(image) + return 2.0 * image - 1.0 + + +def preprocess_mask(mask): + mask = mask.convert("L") + w, h = mask.size + w, h = map(lambda x: x - x % 32, (w, h)) # resize to integer multiple of 32 + mask = mask.resize((w // 8, h // 8), resample=PIL.Image.NEAREST) + mask = np.array(mask).astype(np.float32) / 255.0 + mask = np.tile(mask, (4, 1, 1)) + mask = mask[None].transpose(0, 1, 2, 3) # what does this step do? + mask = 1 - mask # repaint white, keep black + mask = torch.from_numpy(mask) + return mask + + +class StableDiffusionInpaintPipeline(DiffusionPipeline): + r""" + Pipeline for text-guided image inpainting using Stable Diffusion. *This is an experimental feature*. + + This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the + library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.) + + Args: + vae ([`AutoencoderKL`]): + Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations. + text_encoder ([`CLIPTextModel`]): + Frozen text-encoder. Stable Diffusion uses the text portion of + [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel), specifically + the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant. + tokenizer (`CLIPTokenizer`): + Tokenizer of class + [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer). + unet ([`UNet2DConditionModel`]): Conditional U-Net architecture to denoise the encoded image latents. + scheduler ([`SchedulerMixin`]): + A scheduler to be used in combination with `unet` to denoise the encoded image latens. Can be one of + [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`]. + safety_checker ([`StableDiffusionSafetyChecker`]): + Classification module that estimates whether generated images could be considered offsensive or harmful. + Please, refer to the [model card](https://huggingface.co/CompVis/stable-diffusion-v1-4) for details. + feature_extractor ([`CLIPFeatureExtractor`]): + Model that extracts features from generated images to be used as inputs for the `safety_checker`. + """ + + def __init__( + self, + vae: AutoencoderKL, + text_encoder: CLIPTextModel, + tokenizer: CLIPTokenizer, + unet: UNet2DConditionModel, + scheduler: Union[DDIMScheduler, PNDMScheduler], + safety_checker: StableDiffusionSafetyChecker, + feature_extractor: CLIPFeatureExtractor, + ): + super().__init__() + scheduler = scheduler.set_format("pt") + logger.info("`StableDiffusionInpaintPipeline` is experimental and will very likely change in the future.") + self.register_modules( + vae=vae, + text_encoder=text_encoder, + tokenizer=tokenizer, + unet=unet, + scheduler=scheduler, + safety_checker=safety_checker, + feature_extractor=feature_extractor, + ) + + def enable_attention_slicing(self, slice_size: Optional[Union[str, int]] = "auto"): + r""" + Enable sliced attention computation. 
+
+        When this option is enabled, the attention module will split the input tensor in slices, to compute attention
+        in several steps. This is useful to save some memory in exchange for a small speed decrease.
+
+        Args:
+            slice_size (`str` or `int`, *optional*, defaults to `"auto"`):
+                When `"auto"`, halves the input to the attention heads, so attention will be computed in two steps. If
+                a number is provided, uses as many slices as `attention_head_dim // slice_size`. In this case,
+                `attention_head_dim` must be a multiple of `slice_size`.
+        """
+        if slice_size == "auto":
+            # half the attention head size is usually a good trade-off between
+            # speed and memory
+            slice_size = self.unet.config.attention_head_dim // 2
+        self.unet.set_attention_slice(slice_size)
+
+    def disable_attention_slicing(self):
+        r"""
+        Disable sliced attention computation. If `enable_attention_slicing` was previously invoked, this method will go
+        back to computing attention in one step.
+        """
+        # set slice_size = `None` to disable `attention slicing`
+        self.enable_attention_slicing(None)
+
+    @torch.no_grad()
+    def __call__(
+        self,
+        prompt: Union[str, List[str]],
+        init_image: Union[torch.FloatTensor, PIL.Image.Image],
+        mask_image: Union[torch.FloatTensor, PIL.Image.Image],
+        strength: float = 0.8,
+        num_inference_steps: Optional[int] = 50,
+        guidance_scale: Optional[float] = 7.5,
+        eta: Optional[float] = 0.0,
+        generator: Optional[torch.Generator] = None,
+        output_type: Optional[str] = "pil",
+        return_dict: bool = True,
+    ):
+        r"""
+        Function invoked when calling the pipeline for generation.
+
+        Args:
+            prompt (`str` or `List[str]`):
+                The prompt or prompts to guide the image generation.
+            init_image (`torch.FloatTensor` or `PIL.Image.Image`):
+                `Image`, or tensor representing an image batch, that will be used as the starting point for the
+                process. This is the image whose masked region will be inpainted.
+            mask_image (`torch.FloatTensor` or `PIL.Image.Image`):
+                `Image`, or tensor representing an image batch, to mask `init_image`. White pixels in the mask will be
+                replaced by noise and therefore repainted, while black pixels will be preserved. The mask image will be
+                converted to a single channel (luminance) before use.
+            strength (`float`, *optional*, defaults to 0.8):
+                Conceptually, indicates how much to inpaint the masked area. Must be between 0 and 1. When `strength`
+                is 1, the denoising process will be run on the masked area for the full number of iterations specified
+                in `num_inference_steps`. `init_image` will be used as a reference for the masked area, adding more
+                noise to that region the larger the `strength`. If `strength` is 0, no inpainting will occur.
+            num_inference_steps (`int`, *optional*, defaults to 50):
+                The reference number of denoising steps. More denoising steps usually lead to a higher quality image at
+                the expense of slower inference. This parameter will be modulated by `strength`, as explained above.
+            guidance_scale (`float`, *optional*, defaults to 7.5):
+                Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
+                `guidance_scale` is defined as `w` of equation 2. of [Imagen
+                Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
+                1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
+                usually at the expense of lower image quality.
+ eta (`float`, *optional*, defaults to 0.0): + Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to + [`schedulers.DDIMScheduler`], will be ignored for others. + generator (`torch.Generator`, *optional*): + A [torch generator](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make generation + deterministic. + output_type (`str`, *optional*, defaults to `"pil"`): + The output format of the generate image. Choose between + [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `nd.array`. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a + plain tuple. + + Returns: + [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`: + [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] if `return_dict` is True, otherwise a `tuple. + When returning a tuple, the first element is a list with the generated images, and the second element is a + list of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work" + (nsfw) content, according to the `safety_checker`. + """ + if isinstance(prompt, str): + batch_size = 1 + elif isinstance(prompt, list): + batch_size = len(prompt) + else: + raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") + + if strength < 0 or strength > 1: + raise ValueError(f"The value of strength should in [0.0, 1.0] but is {strength}") + + # set timesteps + accepts_offset = "offset" in set(inspect.signature(self.scheduler.set_timesteps).parameters.keys()) + extra_set_kwargs = {} + offset = 0 + if accepts_offset: + offset = 1 + extra_set_kwargs["offset"] = 1 + + self.scheduler.set_timesteps(num_inference_steps, **extra_set_kwargs) + + # preprocess image + init_image = preprocess_image(init_image).to(self.device) + + # encode the init image into latents and scale the latents + init_latent_dist = self.vae.encode(init_image.to(self.device)).latent_dist + init_latents = init_latent_dist.sample(generator=generator) + + init_latents = 0.18215 * init_latents + + # Expand init_latents for batch_size + init_latents = torch.cat([init_latents] * batch_size) + init_latents_orig = init_latents + + # preprocess mask + mask = preprocess_mask(mask_image).to(self.device) + mask = torch.cat([mask] * batch_size) + + # check sizes + if not mask.shape == init_latents.shape: + raise ValueError("The mask and init_image should be the same size!") + + # get the original timestep using init_timestep + init_timestep = int(num_inference_steps * strength) + offset + init_timestep = min(init_timestep, num_inference_steps) + timesteps = self.scheduler.timesteps[-init_timestep] + timesteps = torch.tensor([timesteps] * batch_size, dtype=torch.long, device=self.device) + + # add noise to latents using the timesteps + noise = torch.randn(init_latents.shape, generator=generator, device=self.device) + init_latents = self.scheduler.add_noise(init_latents, noise, timesteps) + + # get prompt text embeddings + text_input = self.tokenizer( + prompt, + padding="max_length", + max_length=self.tokenizer.model_max_length, + truncation=True, + return_tensors="pt", + ) + text_embeddings = self.text_encoder(text_input.input_ids.to(self.device))[0] + + # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) + # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . 
`guidance_scale = 1` + # corresponds to doing no classifier free guidance. + do_classifier_free_guidance = guidance_scale > 1.0 + # get unconditional embeddings for classifier free guidance + if do_classifier_free_guidance: + max_length = text_input.input_ids.shape[-1] + uncond_input = self.tokenizer( + [""] * batch_size, padding="max_length", max_length=max_length, return_tensors="pt" + ) + uncond_embeddings = self.text_encoder(uncond_input.input_ids.to(self.device))[0] + + # For classifier free guidance, we need to do two forward passes. + # Here we concatenate the unconditional and text embeddings into a single batch + # to avoid doing two forward passes + text_embeddings = torch.cat([uncond_embeddings, text_embeddings]) + + # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature + # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers. + # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502 + # and should be between [0, 1] + accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys()) + extra_step_kwargs = {} + if accepts_eta: + extra_step_kwargs["eta"] = eta + + latents = init_latents + t_start = max(num_inference_steps - init_timestep + offset, 0) + for i, t in tqdm(enumerate(self.scheduler.timesteps[t_start:])): + # expand the latents if we are doing classifier free guidance + latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents + + # predict the noise residual + noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=text_embeddings).sample + + # perform guidance + if do_classifier_free_guidance: + noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) + noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) + + # compute the previous noisy sample x_t -> x_t-1 + latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample + + # masking + init_latents_proper = self.scheduler.add_noise(init_latents_orig, noise, t) + latents = (init_latents_proper * mask) + (latents * (1 - mask)) + + # scale and decode the image latents with vae + latents = 1 / 0.18215 * latents + image = self.vae.decode(latents).sample + + image = (image / 2 + 0.5).clamp(0, 1) + image = image.cpu().permute(0, 2, 3, 1).numpy() + + # run safety checker + safety_cheker_input = self.feature_extractor(self.numpy_to_pil(image), return_tensors="pt").to(self.device) + image, has_nsfw_concept = self.safety_checker(images=image, clip_input=safety_cheker_input.pixel_values) + + if output_type == "pil": + image = self.numpy_to_pil(image) + + if not return_dict: + return (image, has_nsfw_concept) + + return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept) diff --git a/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_onnx.py b/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_onnx.py new file mode 100644 index 000000000..7ff3ff22f --- /dev/null +++ b/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_onnx.py @@ -0,0 +1,165 @@ +import inspect +from typing import List, Optional, Union + +import numpy as np + +from transformers import CLIPFeatureExtractor, CLIPTokenizer + +from ...onnx_utils import OnnxRuntimeModel +from ...pipeline_utils import DiffusionPipeline +from ...schedulers import DDIMScheduler, LMSDiscreteScheduler, PNDMScheduler +from . 
import StableDiffusionPipelineOutput + + +class StableDiffusionOnnxPipeline(DiffusionPipeline): + vae_decoder: OnnxRuntimeModel + text_encoder: OnnxRuntimeModel + tokenizer: CLIPTokenizer + unet: OnnxRuntimeModel + scheduler: Union[DDIMScheduler, PNDMScheduler, LMSDiscreteScheduler] + safety_checker: OnnxRuntimeModel + feature_extractor: CLIPFeatureExtractor + + def __init__( + self, + vae_decoder: OnnxRuntimeModel, + text_encoder: OnnxRuntimeModel, + tokenizer: CLIPTokenizer, + unet: OnnxRuntimeModel, + scheduler: Union[DDIMScheduler, PNDMScheduler, LMSDiscreteScheduler], + safety_checker: OnnxRuntimeModel, + feature_extractor: CLIPFeatureExtractor, + ): + super().__init__() + scheduler = scheduler.set_format("np") + self.register_modules( + vae_decoder=vae_decoder, + text_encoder=text_encoder, + tokenizer=tokenizer, + unet=unet, + scheduler=scheduler, + safety_checker=safety_checker, + feature_extractor=feature_extractor, + ) + + def __call__( + self, + prompt: Union[str, List[str]], + height: Optional[int] = 512, + width: Optional[int] = 512, + num_inference_steps: Optional[int] = 50, + guidance_scale: Optional[float] = 7.5, + eta: Optional[float] = 0.0, + latents: Optional[np.ndarray] = None, + output_type: Optional[str] = "pil", + return_dict: bool = True, + **kwargs, + ): + if isinstance(prompt, str): + batch_size = 1 + elif isinstance(prompt, list): + batch_size = len(prompt) + else: + raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") + + if height % 8 != 0 or width % 8 != 0: + raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.") + + # get prompt text embeddings + text_input = self.tokenizer( + prompt, + padding="max_length", + max_length=self.tokenizer.model_max_length, + truncation=True, + return_tensors="np", + ) + text_embeddings = self.text_encoder(input_ids=text_input.input_ids.astype(np.int32))[0] + + # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) + # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1` + # corresponds to doing no classifier free guidance. + do_classifier_free_guidance = guidance_scale > 1.0 + # get unconditional embeddings for classifier free guidance + if do_classifier_free_guidance: + max_length = text_input.input_ids.shape[-1] + uncond_input = self.tokenizer( + [""] * batch_size, padding="max_length", max_length=max_length, return_tensors="np" + ) + uncond_embeddings = self.text_encoder(input_ids=uncond_input.input_ids.astype(np.int32))[0] + + # For classifier free guidance, we need to do two forward passes. 
+ # Here we concatenate the unconditional and text embeddings into a single batch + # to avoid doing two forward passes + text_embeddings = np.concatenate([uncond_embeddings, text_embeddings]) + + # get the initial random noise unless the user supplied it + latents_shape = (batch_size, 4, height // 8, width // 8) + if latents is None: + latents = np.random.randn(*latents_shape).astype(np.float32) + elif latents.shape != latents_shape: + raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {latents_shape}") + + # set timesteps + accepts_offset = "offset" in set(inspect.signature(self.scheduler.set_timesteps).parameters.keys()) + extra_set_kwargs = {} + if accepts_offset: + extra_set_kwargs["offset"] = 1 + + self.scheduler.set_timesteps(num_inference_steps, **extra_set_kwargs) + + # if we use LMSDiscreteScheduler, let's make sure latents are mulitplied by sigmas + if isinstance(self.scheduler, LMSDiscreteScheduler): + latents = latents * self.scheduler.sigmas[0] + + # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature + # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers. + # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502 + # and should be between [0, 1] + accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys()) + extra_step_kwargs = {} + if accepts_eta: + extra_step_kwargs["eta"] = eta + + for i, t in enumerate(self.progress_bar(self.scheduler.timesteps)): + # expand the latents if we are doing classifier free guidance + latent_model_input = np.concatenate([latents] * 2) if do_classifier_free_guidance else latents + if isinstance(self.scheduler, LMSDiscreteScheduler): + sigma = self.scheduler.sigmas[i] + # the model input needs to be scaled to match the continuous ODE formulation in K-LMS + latent_model_input = latent_model_input / ((sigma**2 + 1) ** 0.5) + + # predict the noise residual + noise_pred = self.unet( + sample=latent_model_input, timestep=np.array([t]), encoder_hidden_states=text_embeddings + ) + noise_pred = noise_pred[0] + + # perform guidance + if do_classifier_free_guidance: + noise_pred_uncond, noise_pred_text = np.split(noise_pred, 2) + noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) + + # compute the previous noisy sample x_t -> x_t-1 + if isinstance(self.scheduler, LMSDiscreteScheduler): + latents = self.scheduler.step(noise_pred, i, latents, **extra_step_kwargs).prev_sample + else: + latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample + + # scale and decode the image latents with vae + latents = 1 / 0.18215 * latents + image = self.vae_decoder(latent_sample=latents)[0] + + image = np.clip(image / 2 + 0.5, 0, 1) + image = image.transpose((0, 2, 3, 1)) + + # run safety checker + safety_checker_input = self.feature_extractor(self.numpy_to_pil(image), return_tensors="np") + image, has_nsfw_concept = self.safety_checker(clip_input=safety_checker_input.pixel_values, images=image) + + if output_type == "pil": + image = self.numpy_to_pil(image) + + if not return_dict: + return (image, has_nsfw_concept) + + return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept) diff --git a/diffusers/pipelines/stable_diffusion/safety_checker.py b/diffusers/pipelines/stable_diffusion/safety_checker.py new file mode 100644 index 000000000..3ebc05c91 --- /dev/null +++ b/diffusers/pipelines/stable_diffusion/safety_checker.py @@ -0,0 +1,106 @@ 
+import numpy as np +import torch +import torch.nn as nn + +from transformers import CLIPConfig, CLIPVisionModel, PreTrainedModel + +from ...utils import logging + + +logger = logging.get_logger(__name__) + + +def cosine_distance(image_embeds, text_embeds): + normalized_image_embeds = nn.functional.normalize(image_embeds) + normalized_text_embeds = nn.functional.normalize(text_embeds) + return torch.mm(normalized_image_embeds, normalized_text_embeds.t()) + + +class StableDiffusionSafetyChecker(PreTrainedModel): + config_class = CLIPConfig + + def __init__(self, config: CLIPConfig): + super().__init__(config) + + self.vision_model = CLIPVisionModel(config.vision_config) + self.visual_projection = nn.Linear(config.vision_config.hidden_size, config.projection_dim, bias=False) + + self.concept_embeds = nn.Parameter(torch.ones(17, config.projection_dim), requires_grad=False) + self.special_care_embeds = nn.Parameter(torch.ones(3, config.projection_dim), requires_grad=False) + + self.register_buffer("concept_embeds_weights", torch.ones(17)) + self.register_buffer("special_care_embeds_weights", torch.ones(3)) + + @torch.no_grad() + def forward(self, clip_input, images): + pooled_output = self.vision_model(clip_input)[1] # pooled_output + image_embeds = self.visual_projection(pooled_output) + + special_cos_dist = cosine_distance(image_embeds, self.special_care_embeds).cpu().numpy() + cos_dist = cosine_distance(image_embeds, self.concept_embeds).cpu().numpy() + + result = [] + batch_size = image_embeds.shape[0] + for i in range(batch_size): + result_img = {"special_scores": {}, "special_care": [], "concept_scores": {}, "bad_concepts": []} + + # increase this value to create a stronger `nfsw` filter + # at the cost of increasing the possibility of filtering benign images + adjustment = 0.0 + + for concet_idx in range(len(special_cos_dist[0])): + concept_cos = special_cos_dist[i][concet_idx] + concept_threshold = self.special_care_embeds_weights[concet_idx].item() + result_img["special_scores"][concet_idx] = round(concept_cos - concept_threshold + adjustment, 3) + if result_img["special_scores"][concet_idx] > 0: + result_img["special_care"].append({concet_idx, result_img["special_scores"][concet_idx]}) + adjustment = 0.01 + + for concet_idx in range(len(cos_dist[0])): + concept_cos = cos_dist[i][concet_idx] + concept_threshold = self.concept_embeds_weights[concet_idx].item() + result_img["concept_scores"][concet_idx] = round(concept_cos - concept_threshold + adjustment, 3) + if result_img["concept_scores"][concet_idx] > 0: + result_img["bad_concepts"].append(concet_idx) + + result.append(result_img) + + has_nsfw_concepts = [len(res["bad_concepts"]) > 0 for res in result] + + #for idx, has_nsfw_concept in enumerate(has_nsfw_concepts): + # if has_nsfw_concept: + # images[idx] = np.zeros(images[idx].shape) # black image +# +# if any(has_nsfw_concepts): +# logger.warning( +# "Potential NSFW content was detected in one or more images. A black image will be returned instead." +# " Try again with a different prompt and/or seed." 
+# ) + + return images, has_nsfw_concepts + + @torch.inference_mode() + def forward_onnx(self, clip_input: torch.FloatTensor, images: torch.FloatTensor): + pooled_output = self.vision_model(clip_input)[1] # pooled_output + image_embeds = self.visual_projection(pooled_output) + + special_cos_dist = cosine_distance(image_embeds, self.special_care_embeds) + cos_dist = cosine_distance(image_embeds, self.concept_embeds) + + # increase this value to create a stronger `nsfw` filter + # at the cost of increasing the possibility of filtering benign images + adjustment = 0.0 + + special_scores = special_cos_dist - self.special_care_embeds_weights + adjustment + # special_scores = special_scores.round(decimals=3) + special_care = torch.any(special_scores > 0, dim=1) + special_adjustment = special_care * 0.01 + special_adjustment = special_adjustment.unsqueeze(1).expand(-1, cos_dist.shape[1]) + + concept_scores = (cos_dist - self.concept_embeds_weights) + special_adjustment + # concept_scores = concept_scores.round(decimals=3) + has_nsfw_concepts = torch.any(concept_scores > 0, dim=1) + + images[has_nsfw_concepts] = 0.0 # black image + + return images, has_nsfw_concepts diff --git a/diffusers/pipelines/stochastic_karras_ve/__init__.py b/diffusers/pipelines/stochastic_karras_ve/__init__.py new file mode 100644 index 000000000..db2582043 --- /dev/null +++ b/diffusers/pipelines/stochastic_karras_ve/__init__.py @@ -0,0 +1,2 @@ +# flake8: noqa +from .pipeline_stochastic_karras_ve import KarrasVePipeline diff --git a/diffusers/pipelines/stochastic_karras_ve/pipeline_stochastic_karras_ve.py b/diffusers/pipelines/stochastic_karras_ve/pipeline_stochastic_karras_ve.py new file mode 100644 index 000000000..15266544d --- /dev/null +++ b/diffusers/pipelines/stochastic_karras_ve/pipeline_stochastic_karras_ve.py @@ -0,0 +1,129 @@ +#!/usr/bin/env python3 +import warnings +from typing import Optional, Tuple, Union + +import torch + +from ...models import UNet2DModel +from ...pipeline_utils import DiffusionPipeline, ImagePipelineOutput +from ...schedulers import KarrasVeScheduler + + +class KarrasVePipeline(DiffusionPipeline): + r""" + Stochastic sampling from Karras et al. [1] tailored to the Variance-Expanding (VE) models [2]. Use Algorithm 2 and + the VE column of Table 1 from [1] for reference. + + [1] Karras, Tero, et al. "Elucidating the Design Space of Diffusion-Based Generative Models." + https://arxiv.org/abs/2206.00364 [2] Song, Yang, et al. "Score-based generative modeling through stochastic + differential equations." https://arxiv.org/abs/2011.13456 + + Parameters: + unet ([`UNet2DModel`]): U-Net architecture to denoise the encoded image. + scheduler ([`KarrasVeScheduler`]): + Scheduler for the diffusion process to be used in combination with `unet` to denoise the encoded image. + """ + + # add type hints for linting + unet: UNet2DModel + scheduler: KarrasVeScheduler + + def __init__(self, unet: UNet2DModel, scheduler: KarrasVeScheduler): + super().__init__() + scheduler = scheduler.set_format("pt") + self.register_modules(unet=unet, scheduler=scheduler) + + @torch.no_grad() + def __call__( + self, + batch_size: int = 1, + num_inference_steps: int = 50, + generator: Optional[torch.Generator] = None, + output_type: Optional[str] = "pil", + return_dict: bool = True, + **kwargs, + ) -> Union[Tuple, ImagePipelineOutput]: + r""" + Args: + batch_size (`int`, *optional*, defaults to 1): + The number of images to generate. 
+ generator (`torch.Generator`, *optional*): + A [torch generator](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make generation + deterministic. + num_inference_steps (`int`, *optional*, defaults to 50): + The number of denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. + output_type (`str`, *optional*, defaults to `"pil"`): + The output format of the generate image. Choose between + [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `nd.array`. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~pipeline_utils.ImagePipelineOutput`] instead of a plain tuple. + + Returns: + [`~pipeline_utils.ImagePipelineOutput`] or `tuple`: [`~pipelines.utils.ImagePipelineOutput`] if + `return_dict` is True, otherwise a `tuple. When returning a tuple, the first element is a list with the + generated images. + """ + if "torch_device" in kwargs: + device = kwargs.pop("torch_device") + warnings.warn( + "`torch_device` is deprecated as an input argument to `__call__` and will be removed in v0.3.0." + " Consider using `pipe.to(torch_device)` instead." + ) + + # Set device as before (to be removed in 0.3.0) + if device is None: + device = "cuda" if torch.cuda.is_available() else "cpu" + self.to(device) + + img_size = self.unet.config.sample_size + shape = (batch_size, 3, img_size, img_size) + + model = self.unet + + # sample x_0 ~ N(0, sigma_0^2 * I) + sample = torch.randn(*shape) * self.scheduler.config.sigma_max + sample = sample.to(self.device) + + self.scheduler.set_timesteps(num_inference_steps) + + for t in self.progress_bar(self.scheduler.timesteps): + # here sigma_t == t_i from the paper + sigma = self.scheduler.schedule[t] + sigma_prev = self.scheduler.schedule[t - 1] if t > 0 else 0 + + # 1. Select temporarily increased noise level sigma_hat + # 2. Add new noise to move from sample_i to sample_hat + sample_hat, sigma_hat = self.scheduler.add_noise_to_input(sample, sigma, generator=generator) + + # 3. Predict the noise residual given the noise magnitude `sigma_hat` + # The model inputs and output are adjusted by following eq. (213) in [1]. + model_output = (sigma_hat / 2) * model((sample_hat + 1) / 2, sigma_hat / 2).sample + + # 4. Evaluate dx/dt at sigma_hat + # 5. Take Euler step from sigma to sigma_prev + step_output = self.scheduler.step(model_output, sigma_hat, sigma_prev, sample_hat) + + if sigma_prev != 0: + # 6. Apply 2nd order correction + # The model inputs and output are adjusted by following eq. (213) in [1]. + model_output = (sigma_prev / 2) * model((step_output.prev_sample + 1) / 2, sigma_prev / 2).sample + step_output = self.scheduler.step_correct( + model_output, + sigma_hat, + sigma_prev, + sample_hat, + step_output.prev_sample, + step_output["derivative"], + ) + sample = step_output.prev_sample + + sample = (sample / 2 + 0.5).clamp(0, 1) + image = sample.cpu().permute(0, 2, 3, 1).numpy() + if output_type == "pil": + image = self.numpy_to_pil(sample) + + if not return_dict: + return (image,) + + return ImagePipelineOutput(images=image) diff --git a/diffusers/schedulers/__init__.py b/diffusers/schedulers/__init__.py new file mode 100644 index 000000000..20c25f351 --- /dev/null +++ b/diffusers/schedulers/__init__.py @@ -0,0 +1,28 @@ +# Copyright 2022 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from ..utils import is_scipy_available +from .scheduling_ddim import DDIMScheduler +from .scheduling_ddpm import DDPMScheduler +from .scheduling_karras_ve import KarrasVeScheduler +from .scheduling_pndm import PNDMScheduler +from .scheduling_sde_ve import ScoreSdeVeScheduler +from .scheduling_sde_vp import ScoreSdeVpScheduler +from .scheduling_utils import SchedulerMixin + + +if is_scipy_available(): + from .scheduling_lms_discrete import LMSDiscreteScheduler +else: + from ..utils.dummy_scipy_objects import * # noqa F403 diff --git a/diffusers/schedulers/scheduling_ddim.py b/diffusers/schedulers/scheduling_ddim.py new file mode 100644 index 000000000..894d63bf2 --- /dev/null +++ b/diffusers/schedulers/scheduling_ddim.py @@ -0,0 +1,261 @@ +# Copyright 2022 Stanford University Team and The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# DISCLAIMER: This code is strongly influenced by https://github.com/pesser/pytorch_diffusion +# and https://github.com/hojonathanho/diffusion + +import math +from typing import Optional, Tuple, Union + +import numpy as np +import torch + +from ..configuration_utils import ConfigMixin, register_to_config +from .scheduling_utils import SchedulerMixin, SchedulerOutput + + +def betas_for_alpha_bar(num_diffusion_timesteps, max_beta=0.999): + """ + Create a beta schedule that discretizes the given alpha_t_bar function, which defines the cumulative product of + (1-beta) over time from t = [0,1]. + + Contains a function alpha_bar that takes an argument t and transforms it to the cumulative product of (1-beta) up + to that part of the diffusion process. + + + Args: + num_diffusion_timesteps (`int`): the number of betas to produce. + max_beta (`float`): the maximum beta to use; use values lower than 1 to + prevent singularities. + + Returns: + betas (`np.ndarray`): the betas used by the scheduler to step the model outputs + """ + + def alpha_bar(time_step): + return math.cos((time_step + 0.008) / 1.008 * math.pi / 2) ** 2 + + betas = [] + for i in range(num_diffusion_timesteps): + t1 = i / num_diffusion_timesteps + t2 = (i + 1) / num_diffusion_timesteps + betas.append(min(1 - alpha_bar(t2) / alpha_bar(t1), max_beta)) + return np.array(betas, dtype=np.float32) + + +class DDIMScheduler(SchedulerMixin, ConfigMixin): + """ + Denoising diffusion implicit models is a scheduler that extends the denoising procedure introduced in denoising + diffusion probabilistic models (DDPMs) with non-Markovian guidance. 
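A quick sanity check on the `betas_for_alpha_bar` helper above (the `squaredcos_cap_v2` schedule): reusing the function exactly as written, the betas stay within (0, `max_beta`] and the implied `alphas_cumprod` falls from roughly 1 at the first step to roughly 0 at the last.

```python
import math
import numpy as np

def betas_for_alpha_bar(num_diffusion_timesteps, max_beta=0.999):
    # same construction as above: beta_i = 1 - alpha_bar(t_{i+1}) / alpha_bar(t_i), capped at max_beta
    def alpha_bar(time_step):
        return math.cos((time_step + 0.008) / 1.008 * math.pi / 2) ** 2

    betas = []
    for i in range(num_diffusion_timesteps):
        t1 = i / num_diffusion_timesteps
        t2 = (i + 1) / num_diffusion_timesteps
        betas.append(min(1 - alpha_bar(t2) / alpha_bar(t1), max_beta))
    return np.array(betas, dtype=np.float32)

betas = betas_for_alpha_bar(1000)
alphas_cumprod = np.cumprod(1.0 - betas)
print(betas.min(), betas.max())               # all betas lie in (0, 0.999]
print(alphas_cumprod[0], alphas_cumprod[-1])  # ~1.0 at the first step, ~0.0 at the last
```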
+
+    [`~ConfigMixin`] takes care of storing all config attributes that are passed in the scheduler's `__init__`
+    function, such as `num_train_timesteps`. They can be accessed via `scheduler.config.num_train_timesteps`.
+    [`~ConfigMixin`] also provides general loading and saving functionality via the [`~ConfigMixin.save_config`] and
+    [`~ConfigMixin.from_config`] functions.
+
+    For more details, see the original paper: https://arxiv.org/abs/2010.02502
+
+    Args:
+        num_train_timesteps (`int`): number of diffusion steps used to train the model.
+        beta_start (`float`): the starting `beta` value of inference.
+        beta_end (`float`): the final `beta` value.
+        beta_schedule (`str`):
+            the beta schedule, a mapping from a beta range to a sequence of betas for stepping the model. Choose from
+            `linear`, `scaled_linear`, or `squaredcos_cap_v2`.
+        trained_betas (`np.ndarray`, optional): TODO
+        timestep_values (`np.ndarray`, optional): TODO
+        clip_sample (`bool`, default `True`):
+            option to clip predicted sample between -1 and 1 for numerical stability.
+        set_alpha_to_one (`bool`, default `True`):
+            if alpha for final step is 1 or the final alpha of the "non-previous" one.
+        tensor_format (`str`): whether the scheduler expects pytorch or numpy arrays.
+
+    """
+
+    @register_to_config
+    def __init__(
+        self,
+        num_train_timesteps: int = 1000,
+        beta_start: float = 0.0001,
+        beta_end: float = 0.02,
+        beta_schedule: str = "linear",
+        trained_betas: Optional[np.ndarray] = None,
+        timestep_values: Optional[np.ndarray] = None,
+        clip_sample: bool = True,
+        set_alpha_to_one: bool = True,
+        tensor_format: str = "pt",
+    ):
+        if trained_betas is not None:
+            self.betas = np.asarray(trained_betas)
+        elif beta_schedule == "linear":
+            self.betas = np.linspace(beta_start, beta_end, num_train_timesteps, dtype=np.float32)
+        elif beta_schedule == "scaled_linear":
+            # this schedule is very specific to the latent diffusion model.
+            self.betas = np.linspace(beta_start**0.5, beta_end**0.5, num_train_timesteps, dtype=np.float32) ** 2
+        elif beta_schedule == "squaredcos_cap_v2":
+            # Glide cosine schedule
+            self.betas = betas_for_alpha_bar(num_train_timesteps)
+        else:
+            raise NotImplementedError(f"{beta_schedule} is not implemented for {self.__class__}")
+
+        self.alphas = 1.0 - self.betas
+        self.alphas_cumprod = np.cumprod(self.alphas, axis=0)
+
+        # At every step in ddim, we are looking into the previous alphas_cumprod
+        # For the final step, there is no previous alphas_cumprod because we are already at 0
+        # `set_alpha_to_one` decides whether we set this parameter simply to one or
+        # whether we use the final alpha of the "non-previous" one.
+        self.final_alpha_cumprod = np.array(1.0) if set_alpha_to_one else self.alphas_cumprod[0]
+
+        # setable values
+        self.num_inference_steps = None
+        self.timesteps = np.arange(0, num_train_timesteps)[::-1].copy()
+
+        self.tensor_format = tensor_format
+        self.set_format(tensor_format=tensor_format)
+
+    def _get_variance(self, timestep, prev_timestep):
+        alpha_prod_t = self.alphas_cumprod[timestep]
+        alpha_prod_t_prev = self.alphas_cumprod[prev_timestep] if prev_timestep >= 0 else self.final_alpha_cumprod
+        beta_prod_t = 1 - alpha_prod_t
+        beta_prod_t_prev = 1 - alpha_prod_t_prev
+
+        variance = (beta_prod_t_prev / beta_prod_t) * (1 - alpha_prod_t / alpha_prod_t_prev)
+
+        return variance
+
+    def set_timesteps(self, num_inference_steps: int, offset: int = 0):
+        """
+        Sets the discrete timesteps used for the diffusion chain. Supporting function to be run before inference.
+ + Args: + num_inference_steps (`int`): + the number of diffusion steps used when generating samples with a pre-trained model. + offset (`int`): TODO + """ + self.num_inference_steps = num_inference_steps + self.timesteps = np.arange( + 0, self.config.num_train_timesteps, self.config.num_train_timesteps // self.num_inference_steps + )[::-1].copy() + self.timesteps += offset + self.set_format(tensor_format=self.tensor_format) + + def step( + self, + model_output: Union[torch.FloatTensor, np.ndarray], + timestep: int, + sample: Union[torch.FloatTensor, np.ndarray], + eta: float = 0.0, + use_clipped_model_output: bool = False, + generator=None, + return_dict: bool = True, + ) -> Union[SchedulerOutput, Tuple]: + """ + Predict the sample at the previous timestep by reversing the SDE. Core function to propagate the diffusion + process from the learned model outputs (most often the predicted noise). + + Args: + model_output (`torch.FloatTensor` or `np.ndarray`): direct output from learned diffusion model. + timestep (`int`): current discrete timestep in the diffusion chain. + sample (`torch.FloatTensor` or `np.ndarray`): + current instance of sample being created by diffusion process. + eta (`float`): weight of noise for added noise in diffusion step. + use_clipped_model_output (`bool`): TODO + generator: random number generator. + return_dict (`bool`): option for returning tuple rather than SchedulerOutput class + + Returns: + [`~schedulers.scheduling_utils.SchedulerOutput`] or `tuple`: + [`~schedulers.scheduling_utils.SchedulerOutput`] if `return_dict` is True, otherwise a `tuple`. When + returning a tuple, the first element is the sample tensor. + + """ + if self.num_inference_steps is None: + raise ValueError( + "Number of inference steps is 'None', you need to run 'set_timesteps' after creating the scheduler" + ) + + # See formulas (12) and (16) of DDIM paper https://arxiv.org/pdf/2010.02502.pdf + # Ideally, read DDIM paper in-detail understanding + + # Notation ( -> + # - pred_noise_t -> e_theta(x_t, t) + # - pred_original_sample -> f_theta(x_t, t) or x_0 + # - std_dev_t -> sigma_t + # - eta -> η + # - pred_sample_direction -> "direction pointingc to x_t" + # - pred_prev_sample -> "x_t-1" + + # 1. get previous step value (=t-1) + prev_timestep = timestep - self.config.num_train_timesteps // self.num_inference_steps + + # 2. compute alphas, betas + alpha_prod_t = self.alphas_cumprod[timestep] + alpha_prod_t_prev = self.alphas_cumprod[prev_timestep] if prev_timestep >= 0 else self.final_alpha_cumprod + beta_prod_t = 1 - alpha_prod_t + + # 3. compute predicted original sample from predicted noise also called + # "predicted x_0" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf + pred_original_sample = (sample - beta_prod_t ** (0.5) * model_output) / alpha_prod_t ** (0.5) + + # 4. Clip "predicted x_0" + if self.config.clip_sample: + pred_original_sample = self.clip(pred_original_sample, -1, 1) + + # 5. compute variance: "sigma_t(η)" -> see formula (16) + # σ_t = sqrt((1 − α_t−1)/(1 − α_t)) * sqrt(1 − α_t/α_t−1) + variance = self._get_variance(timestep, prev_timestep) + std_dev_t = eta * variance ** (0.5) + + if use_clipped_model_output: + # the model_output is always re-derived from the clipped x_0 in Glide + model_output = (sample - alpha_prod_t ** (0.5) * pred_original_sample) / beta_prod_t ** (0.5) + + # 6. 
compute "direction pointing to x_t" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf + pred_sample_direction = (1 - alpha_prod_t_prev - std_dev_t**2) ** (0.5) * model_output + + # 7. compute x_t without "random noise" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf + prev_sample = alpha_prod_t_prev ** (0.5) * pred_original_sample + pred_sample_direction + + if eta > 0: + device = model_output.device if torch.is_tensor(model_output) else "cpu" + noise = torch.randn(model_output.shape, generator=generator).to(device) + variance = self._get_variance(timestep, prev_timestep) ** (0.5) * eta * noise + + if not torch.is_tensor(model_output): + variance = variance.numpy() + + prev_sample = prev_sample + variance + + if not return_dict: + return (prev_sample,) + + return SchedulerOutput(prev_sample=prev_sample) + + def add_noise( + self, + original_samples: Union[torch.FloatTensor, np.ndarray], + noise: Union[torch.FloatTensor, np.ndarray], + timesteps: Union[torch.IntTensor, np.ndarray], + ) -> Union[torch.FloatTensor, np.ndarray]: + sqrt_alpha_prod = self.alphas_cumprod[timesteps] ** 0.5 + sqrt_alpha_prod = self.match_shape(sqrt_alpha_prod, original_samples) + sqrt_one_minus_alpha_prod = (1 - self.alphas_cumprod[timesteps]) ** 0.5 + sqrt_one_minus_alpha_prod = self.match_shape(sqrt_one_minus_alpha_prod, original_samples) + + noisy_samples = sqrt_alpha_prod * original_samples + sqrt_one_minus_alpha_prod * noise + return noisy_samples + + def __len__(self): + return self.config.num_train_timesteps diff --git a/diffusers/schedulers/scheduling_ddpm.py b/diffusers/schedulers/scheduling_ddpm.py new file mode 100644 index 000000000..4fbfb9038 --- /dev/null +++ b/diffusers/schedulers/scheduling_ddpm.py @@ -0,0 +1,264 @@ +# Copyright 2022 UC Berkely Team and The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# DISCLAIMER: This file is strongly influenced by https://github.com/ermongroup/ddim + +import math +from typing import Optional, Tuple, Union + +import numpy as np +import torch + +from ..configuration_utils import ConfigMixin, register_to_config +from .scheduling_utils import SchedulerMixin, SchedulerOutput + + +def betas_for_alpha_bar(num_diffusion_timesteps, max_beta=0.999): + """ + Create a beta schedule that discretizes the given alpha_t_bar function, which defines the cumulative product of + (1-beta) over time from t = [0,1]. + + Contains a function alpha_bar that takes an argument t and transforms it to the cumulative product of (1-beta) up + to that part of the diffusion process. + + + Args: + num_diffusion_timesteps (`int`): the number of betas to produce. + max_beta (`float`): the maximum beta to use; use values lower than 1 to + prevent singularities. 
+ + Returns: + betas (`np.ndarray`): the betas used by the scheduler to step the model outputs + """ + + def alpha_bar(time_step): + return math.cos((time_step + 0.008) / 1.008 * math.pi / 2) ** 2 + + betas = [] + for i in range(num_diffusion_timesteps): + t1 = i / num_diffusion_timesteps + t2 = (i + 1) / num_diffusion_timesteps + betas.append(min(1 - alpha_bar(t2) / alpha_bar(t1), max_beta)) + return np.array(betas, dtype=np.float32) + + +class DDPMScheduler(SchedulerMixin, ConfigMixin): + """ + Denoising diffusion probabilistic models (DDPMs) explores the connections between denoising score matching and + Langevin dynamics sampling. + + [`~ConfigMixin`] takes care of storing all config attributes that are passed in the scheduler's `__init__` + function, such as `num_train_timesteps`. They can be accessed via `scheduler.config.num_train_timesteps`. + [`~ConfigMixin`] also provides general loading and saving functionality via the [`~ConfigMixin.save_config`] and + [`~ConfigMixin.from_config`] functios. + + For more details, see the original paper: https://arxiv.org/abs/2006.11239 + + Args: + num_train_timesteps (`int`): number of diffusion steps used to train the model. + beta_start (`float`): the starting `beta` value of inference. + beta_end (`float`): the final `beta` value. + beta_schedule (`str`): + the beta schedule, a mapping from a beta range to a sequence of betas for stepping the model. Choose from + `linear`, `scaled_linear`, or `squaredcos_cap_v2`. + trained_betas (`np.ndarray`, optional): TODO + variance_type (`str`): + options to clip the variance used when adding noise to the denoised sample. Choose from `fixed_small`, + `fixed_small_log`, `fixed_large`, `fixed_large_log`, `learned` or `learned_range`. + clip_sample (`bool`, default `True`): + option to clip predicted sample between -1 and 1 for numerical stability. + tensor_format (`str`): whether the scheduler expects pytorch or numpy arrays. + + """ + + @register_to_config + def __init__( + self, + num_train_timesteps: int = 1000, + beta_start: float = 0.0001, + beta_end: float = 0.02, + beta_schedule: str = "linear", + trained_betas: Optional[np.ndarray] = None, + variance_type: str = "fixed_small", + clip_sample: bool = True, + tensor_format: str = "pt", + ): + + if trained_betas is not None: + self.betas = np.asarray(trained_betas) + elif beta_schedule == "linear": + self.betas = np.linspace(beta_start, beta_end, num_train_timesteps, dtype=np.float32) + elif beta_schedule == "scaled_linear": + # this schedule is very specific to the latent diffusion model. + self.betas = np.linspace(beta_start**0.5, beta_end**0.5, num_train_timesteps, dtype=np.float32) ** 2 + elif beta_schedule == "squaredcos_cap_v2": + # Glide cosine schedule + self.betas = betas_for_alpha_bar(num_train_timesteps) + else: + raise NotImplementedError(f"{beta_schedule} does is not implemented for {self.__class__}") + + self.alphas = 1.0 - self.betas + self.alphas_cumprod = np.cumprod(self.alphas, axis=0) + self.one = np.array(1.0) + + # setable values + self.num_inference_steps = None + self.timesteps = np.arange(0, num_train_timesteps)[::-1].copy() + + self.tensor_format = tensor_format + self.set_format(tensor_format=tensor_format) + + self.variance_type = variance_type + + def set_timesteps(self, num_inference_steps: int): + """ + Sets the discrete timesteps used for the diffusion chain. Supporting function to be run before inference. 
+ + Args: + num_inference_steps (`int`): + the number of diffusion steps used when generating samples with a pre-trained model. + """ + num_inference_steps = min(self.config.num_train_timesteps, num_inference_steps) + self.num_inference_steps = num_inference_steps + self.timesteps = np.arange( + 0, self.config.num_train_timesteps, self.config.num_train_timesteps // self.num_inference_steps + )[::-1].copy() + self.set_format(tensor_format=self.tensor_format) + + def _get_variance(self, t, predicted_variance=None, variance_type=None): + alpha_prod_t = self.alphas_cumprod[t] + alpha_prod_t_prev = self.alphas_cumprod[t - 1] if t > 0 else self.one + + # For t > 0, compute predicted variance βt (see formula (6) and (7) from https://arxiv.org/pdf/2006.11239.pdf) + # and sample from it to get previous sample + # x_{t-1} ~ N(pred_prev_sample, variance) == add variance to pred_sample + variance = (1 - alpha_prod_t_prev) / (1 - alpha_prod_t) * self.betas[t] + + if variance_type is None: + variance_type = self.config.variance_type + + # hacks - were probs added for training stability + if variance_type == "fixed_small": + variance = self.clip(variance, min_value=1e-20) + # for rl-diffuser https://arxiv.org/abs/2205.09991 + elif variance_type == "fixed_small_log": + variance = self.log(self.clip(variance, min_value=1e-20)) + elif variance_type == "fixed_large": + variance = self.betas[t] + elif variance_type == "fixed_large_log": + # Glide max_log + variance = self.log(self.betas[t]) + elif variance_type == "learned": + return predicted_variance + elif variance_type == "learned_range": + min_log = variance + max_log = self.betas[t] + frac = (predicted_variance + 1) / 2 + variance = frac * max_log + (1 - frac) * min_log + + return variance + + def step( + self, + model_output: Union[torch.FloatTensor, np.ndarray], + timestep: int, + sample: Union[torch.FloatTensor, np.ndarray], + predict_epsilon=True, + generator=None, + return_dict: bool = True, + ) -> Union[SchedulerOutput, Tuple]: + """ + Predict the sample at the previous timestep by reversing the SDE. Core function to propagate the diffusion + process from the learned model outputs (most often the predicted noise). + + Args: + model_output (`torch.FloatTensor` or `np.ndarray`): direct output from learned diffusion model. + timestep (`int`): current discrete timestep in the diffusion chain. + sample (`torch.FloatTensor` or `np.ndarray`): + current instance of sample being created by diffusion process. + eta (`float`): weight of noise for added noise in diffusion step. + predict_epsilon (`bool`): + optional flag to use when model predicts the samples directly instead of the noise, epsilon. + generator: random number generator. + return_dict (`bool`): option for returning tuple rather than SchedulerOutput class + + Returns: + [`~schedulers.scheduling_utils.SchedulerOutput`] or `tuple`: + [`~schedulers.scheduling_utils.SchedulerOutput`] if `return_dict` is True, otherwise a `tuple`. When + returning a tuple, the first element is the sample tensor. + + """ + t = timestep + + if model_output.shape[1] == sample.shape[1] * 2 and self.variance_type in ["learned", "learned_range"]: + model_output, predicted_variance = torch.split(model_output, sample.shape[1], dim=1) + else: + predicted_variance = None + + # 1. compute alphas, betas + alpha_prod_t = self.alphas_cumprod[t] + alpha_prod_t_prev = self.alphas_cumprod[t - 1] if t > 0 else self.one + beta_prod_t = 1 - alpha_prod_t + beta_prod_t_prev = 1 - alpha_prod_t_prev + + # 2. 
compute predicted original sample from predicted noise also called + # "predicted x_0" of formula (15) from https://arxiv.org/pdf/2006.11239.pdf + if predict_epsilon: + pred_original_sample = (sample - beta_prod_t ** (0.5) * model_output) / alpha_prod_t ** (0.5) + else: + pred_original_sample = model_output + + # 3. Clip "predicted x_0" + if self.config.clip_sample: + pred_original_sample = self.clip(pred_original_sample, -1, 1) + + # 4. Compute coefficients for pred_original_sample x_0 and current sample x_t + # See formula (7) from https://arxiv.org/pdf/2006.11239.pdf + pred_original_sample_coeff = (alpha_prod_t_prev ** (0.5) * self.betas[t]) / beta_prod_t + current_sample_coeff = self.alphas[t] ** (0.5) * beta_prod_t_prev / beta_prod_t + + # 5. Compute predicted previous sample µ_t + # See formula (7) from https://arxiv.org/pdf/2006.11239.pdf + pred_prev_sample = pred_original_sample_coeff * pred_original_sample + current_sample_coeff * sample + + # 6. Add noise + variance = 0 + if t > 0: + noise = self.randn_like(model_output, generator=generator) + variance = (self._get_variance(t, predicted_variance=predicted_variance) ** 0.5) * noise + + pred_prev_sample = pred_prev_sample + variance + + if not return_dict: + return (pred_prev_sample,) + + return SchedulerOutput(prev_sample=pred_prev_sample) + + def add_noise( + self, + original_samples: Union[torch.FloatTensor, np.ndarray], + noise: Union[torch.FloatTensor, np.ndarray], + timesteps: Union[torch.IntTensor, np.ndarray], + ) -> Union[torch.FloatTensor, np.ndarray]: + + sqrt_alpha_prod = self.alphas_cumprod[timesteps] ** 0.5 + sqrt_alpha_prod = self.match_shape(sqrt_alpha_prod, original_samples) + sqrt_one_minus_alpha_prod = (1 - self.alphas_cumprod[timesteps]) ** 0.5 + sqrt_one_minus_alpha_prod = self.match_shape(sqrt_one_minus_alpha_prod, original_samples) + + noisy_samples = sqrt_alpha_prod * original_samples + sqrt_one_minus_alpha_prod * noise + return noisy_samples + + def __len__(self): + return self.config.num_train_timesteps diff --git a/diffusers/schedulers/scheduling_karras_ve.py b/diffusers/schedulers/scheduling_karras_ve.py new file mode 100644 index 000000000..3a2370cfc --- /dev/null +++ b/diffusers/schedulers/scheduling_karras_ve.py @@ -0,0 +1,208 @@ +# Copyright 2022 NVIDIA and The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +from dataclasses import dataclass +from typing import Optional, Tuple, Union + +import numpy as np +import torch + +from ..configuration_utils import ConfigMixin, register_to_config +from ..utils import BaseOutput +from .scheduling_utils import SchedulerMixin + + +@dataclass +class KarrasVeOutput(BaseOutput): + """ + Output class for the scheduler's step function output. + + Args: + prev_sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` for images): + Computed sample (x_{t-1}) of previous timestep. `prev_sample` should be used as next model input in the + denoising loop. 
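Similarly, the DDPM `step()` above reduces to two coefficients from formula (7) plus the `fixed_small` variance. An illustrative scalar version under the same linear beta schedule (editor's sketch with made-up inputs):

```python
import numpy as np

betas = np.linspace(1e-4, 0.02, 1000)
alphas = 1.0 - betas
alphas_cumprod = np.cumprod(alphas)

t = 500
alpha_prod_t = alphas_cumprod[t]
alpha_prod_t_prev = alphas_cumprod[t - 1] if t > 0 else 1.0
beta_prod_t = 1.0 - alpha_prod_t
beta_prod_t_prev = 1.0 - alpha_prod_t_prev

sample, model_output = 0.3, -0.1  # x_t and the predicted noise e_theta(x_t, t)

# predicted x_0, formula (15)
pred_original_sample = (sample - beta_prod_t**0.5 * model_output) / alpha_prod_t**0.5
# coefficients of formula (7)
pred_original_sample_coeff = alpha_prod_t_prev**0.5 * betas[t] / beta_prod_t
current_sample_coeff = alphas[t] ** 0.5 * beta_prod_t_prev / beta_prod_t
pred_prev_sample_mean = pred_original_sample_coeff * pred_original_sample + current_sample_coeff * sample

# "fixed_small" variance of formula (7), plus sampled noise
variance = (1.0 - alpha_prod_t_prev) / (1.0 - alpha_prod_t) * betas[t]
prev_sample = pred_prev_sample_mean + variance**0.5 * np.random.default_rng(0).normal()
print(prev_sample)
```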
+ derivative (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` for images): + Derivate of predicted original image sample (x_0). + """ + + prev_sample: torch.FloatTensor + derivative: torch.FloatTensor + + +class KarrasVeScheduler(SchedulerMixin, ConfigMixin): + """ + Stochastic sampling from Karras et al. [1] tailored to the Variance-Expanding (VE) models [2]. Use Algorithm 2 and + the VE column of Table 1 from [1] for reference. + + [1] Karras, Tero, et al. "Elucidating the Design Space of Diffusion-Based Generative Models." + https://arxiv.org/abs/2206.00364 [2] Song, Yang, et al. "Score-based generative modeling through stochastic + differential equations." https://arxiv.org/abs/2011.13456 + + [`~ConfigMixin`] takes care of storing all config attributes that are passed in the scheduler's `__init__` + function, such as `num_train_timesteps`. They can be accessed via `scheduler.config.num_train_timesteps`. + [`~ConfigMixin`] also provides general loading and saving functionality via the [`~ConfigMixin.save_config`] and + [`~ConfigMixin.from_config`] functios. + + For more details on the parameters, see the original paper's Appendix E.: "Elucidating the Design Space of + Diffusion-Based Generative Models." https://arxiv.org/abs/2206.00364. The grid search values used to find the + optimal {s_noise, s_churn, s_min, s_max} for a specific model are described in Table 5 of the paper. + + Args: + sigma_min (`float`): minimum noise magnitude + sigma_max (`float`): maximum noise magnitude + s_noise (`float`): the amount of additional noise to counteract loss of detail during sampling. + A reasonable range is [1.000, 1.011]. + s_churn (`float`): the parameter controlling the overall amount of stochasticity. + A reasonable range is [0, 100]. + s_min (`float`): the start value of the sigma range where we add noise (enable stochasticity). + A reasonable range is [0, 10]. + s_max (`float`): the end value of the sigma range where we add noise. + A reasonable range is [0.2, 80]. + tensor_format (`str`): whether the scheduler expects pytorch or numpy arrays. + + """ + + @register_to_config + def __init__( + self, + sigma_min: float = 0.02, + sigma_max: float = 100, + s_noise: float = 1.007, + s_churn: float = 80, + s_min: float = 0.05, + s_max: float = 50, + tensor_format: str = "pt", + ): + # setable values + self.num_inference_steps = None + self.timesteps = None + self.schedule = None # sigma(t_i) + + self.tensor_format = tensor_format + self.set_format(tensor_format=tensor_format) + + def set_timesteps(self, num_inference_steps: int): + """ + Sets the continuous timesteps used for the diffusion chain. Supporting function to be run before inference. + + Args: + num_inference_steps (`int`): + the number of diffusion steps used when generating samples with a pre-trained model. 
+ + """ + self.num_inference_steps = num_inference_steps + self.timesteps = np.arange(0, self.num_inference_steps)[::-1].copy() + self.schedule = [ + (self.sigma_max * (self.sigma_min**2 / self.sigma_max**2) ** (i / (num_inference_steps - 1))) + for i in self.timesteps + ] + self.schedule = np.array(self.schedule, dtype=np.float32) + + self.set_format(tensor_format=self.tensor_format) + + def add_noise_to_input( + self, sample: Union[torch.FloatTensor, np.ndarray], sigma: float, generator: Optional[torch.Generator] = None + ) -> Tuple[Union[torch.FloatTensor, np.ndarray], float]: + """ + Explicit Langevin-like "churn" step of adding noise to the sample according to a factor gamma_i ≥ 0 to reach a + higher noise level sigma_hat = sigma_i + gamma_i*sigma_i. + + TODO Args: + """ + if self.s_min <= sigma <= self.s_max: + gamma = min(self.s_churn / self.num_inference_steps, 2**0.5 - 1) + else: + gamma = 0 + + # sample eps ~ N(0, S_noise^2 * I) + eps = self.s_noise * torch.randn(sample.shape, generator=generator).to(sample.device) + sigma_hat = sigma + gamma * sigma + sample_hat = sample + ((sigma_hat**2 - sigma**2) ** 0.5 * eps) + + return sample_hat, sigma_hat + + def step( + self, + model_output: Union[torch.FloatTensor, np.ndarray], + sigma_hat: float, + sigma_prev: float, + sample_hat: Union[torch.FloatTensor, np.ndarray], + return_dict: bool = True, + ) -> Union[KarrasVeOutput, Tuple]: + """ + Predict the sample at the previous timestep by reversing the SDE. Core function to propagate the diffusion + process from the learned model outputs (most often the predicted noise). + + Args: + model_output (`torch.FloatTensor` or `np.ndarray`): direct output from learned diffusion model. + sigma_hat (`float`): TODO + sigma_prev (`float`): TODO + sample_hat (`torch.FloatTensor` or `np.ndarray`): TODO + return_dict (`bool`): option for returning tuple rather than SchedulerOutput class + + KarrasVeOutput: updated sample in the diffusion chain and derivative (TODO double check). + Returns: + [`~schedulers.scheduling_karras_ve.KarrasVeOutput`] or `tuple`: + [`~schedulers.scheduling_karras_ve.KarrasVeOutput`] if `return_dict` is True, otherwise a `tuple`. When + returning a tuple, the first element is the sample tensor. + + """ + + pred_original_sample = sample_hat + sigma_hat * model_output + derivative = (sample_hat - pred_original_sample) / sigma_hat + sample_prev = sample_hat + (sigma_prev - sigma_hat) * derivative + + if not return_dict: + return (sample_prev, derivative) + + return KarrasVeOutput(prev_sample=sample_prev, derivative=derivative) + + def step_correct( + self, + model_output: Union[torch.FloatTensor, np.ndarray], + sigma_hat: float, + sigma_prev: float, + sample_hat: Union[torch.FloatTensor, np.ndarray], + sample_prev: Union[torch.FloatTensor, np.ndarray], + derivative: Union[torch.FloatTensor, np.ndarray], + return_dict: bool = True, + ) -> Union[KarrasVeOutput, Tuple]: + """ + Correct the predicted sample based on the output model_output of the network. TODO complete description + + Args: + model_output (`torch.FloatTensor` or `np.ndarray`): direct output from learned diffusion model. + sigma_hat (`float`): TODO + sigma_prev (`float`): TODO + sample_hat (`torch.FloatTensor` or `np.ndarray`): TODO + sample_prev (`torch.FloatTensor` or `np.ndarray`): TODO + derivative (`torch.FloatTensor` or `np.ndarray`): TODO + return_dict (`bool`): option for returning tuple rather than SchedulerOutput class + + Returns: + prev_sample (TODO): updated sample in the diffusion chain. 
derivative (TODO): TODO + + """ + pred_original_sample = sample_prev + sigma_prev * model_output + derivative_corr = (sample_prev - pred_original_sample) / sigma_prev + sample_prev = sample_hat + (sigma_prev - sigma_hat) * (0.5 * derivative + 0.5 * derivative_corr) + + if not return_dict: + return (sample_prev, derivative) + + return KarrasVeOutput(prev_sample=sample_prev, derivative=derivative) + + def add_noise(self, original_samples, noise, timesteps): + raise NotImplementedError() diff --git a/diffusers/schedulers/scheduling_lms_discrete.py b/diffusers/schedulers/scheduling_lms_discrete.py new file mode 100644 index 000000000..1381587fe --- /dev/null +++ b/diffusers/schedulers/scheduling_lms_discrete.py @@ -0,0 +1,193 @@ +# Copyright 2022 Katherine Crowson and The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Optional, Tuple, Union + +import numpy as np +import torch + +from scipy import integrate + +from ..configuration_utils import ConfigMixin, register_to_config +from .scheduling_utils import SchedulerMixin, SchedulerOutput + + +class LMSDiscreteScheduler(SchedulerMixin, ConfigMixin): + """ + Linear Multistep Scheduler for discrete beta schedules. Based on the original k-diffusion implementation by + Katherine Crowson: + https://github.com/crowsonkb/k-diffusion/blob/481677d114f6ea445aa009cf5bd7a9cdee909e47/k_diffusion/sampling.py#L181 + + [`~ConfigMixin`] takes care of storing all config attributes that are passed in the scheduler's `__init__` + function, such as `num_train_timesteps`. They can be accessed via `scheduler.config.num_train_timesteps`. + [`~ConfigMixin`] also provides general loading and saving functionality via the [`~ConfigMixin.save_config`] and + [`~ConfigMixin.from_config`] functios. + + Args: + num_train_timesteps (`int`): number of diffusion steps used to train the model. + beta_start (`float`): the starting `beta` value of inference. + beta_end (`float`): the final `beta` value. + beta_schedule (`str`): + the beta schedule, a mapping from a beta range to a sequence of betas for stepping the model. Choose from + `linear` or `scaled_linear`. + trained_betas (`np.ndarray`, optional): TODO + options to clip the variance used when adding noise to the denoised sample. Choose from `fixed_small`, + `fixed_small_log`, `fixed_large`, `fixed_large_log`, `learned` or `learned_range`. + timestep_values (`np.ndarry`, optional): TODO + tensor_format (`str`): whether the scheduler expects pytorch or numpy arrays. 
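Putting the KarrasVeScheduler pieces together, one sampler iteration is: churn the sample up to `sigma_hat` (`add_noise_to_input`), take an Euler step (`step`), then apply the Heun-style correction (`step_correct`). The sketch below mirrors those formulas with a hypothetical `denoiser` callable standing in for the adjusted model output; it omits the `s_min <= sigma <= s_max` range check, which is satisfied by the chosen values.

```python
import torch

def denoiser(x, sigma):
    # hypothetical stand-in for the adjusted model output used by the pipeline
    return -0.1 * x

s_churn, s_noise, num_inference_steps = 80.0, 1.007, 50
sigma, sigma_prev = 10.0, 8.0
sample = torch.randn(1, 3, 8, 8) * sigma

# churn: temporarily raise the noise level from sigma to sigma_hat (add_noise_to_input)
gamma = min(s_churn / num_inference_steps, 2**0.5 - 1)
sigma_hat = sigma + gamma * sigma
eps = s_noise * torch.randn_like(sample)
sample_hat = sample + (sigma_hat**2 - sigma**2) ** 0.5 * eps

# Euler step (scheduler.step)
model_output = denoiser(sample_hat, sigma_hat)
pred_original_sample = sample_hat + sigma_hat * model_output
derivative = (sample_hat - pred_original_sample) / sigma_hat
sample_prev = sample_hat + (sigma_prev - sigma_hat) * derivative

# Heun-style 2nd order correction (scheduler.step_correct)
model_output = denoiser(sample_prev, sigma_prev)
pred_original_sample = sample_prev + sigma_prev * model_output
derivative_corr = (sample_prev - pred_original_sample) / sigma_prev
sample_prev = sample_hat + (sigma_prev - sigma_hat) * (0.5 * derivative + 0.5 * derivative_corr)
print(sample_prev.shape)
```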
+ + """ + + @register_to_config + def __init__( + self, + num_train_timesteps: int = 1000, + beta_start: float = 0.0001, + beta_end: float = 0.02, + beta_schedule: str = "linear", + trained_betas: Optional[np.ndarray] = None, + timestep_values: Optional[np.ndarray] = None, + tensor_format: str = "pt", + ): + if trained_betas is not None: + self.betas = np.asarray(trained_betas) + if beta_schedule == "linear": + self.betas = np.linspace(beta_start, beta_end, num_train_timesteps, dtype=np.float32) + elif beta_schedule == "scaled_linear": + # this schedule is very specific to the latent diffusion model. + self.betas = np.linspace(beta_start**0.5, beta_end**0.5, num_train_timesteps, dtype=np.float32) ** 2 + else: + raise NotImplementedError(f"{beta_schedule} does is not implemented for {self.__class__}") + + self.alphas = 1.0 - self.betas + self.alphas_cumprod = np.cumprod(self.alphas, axis=0) + + self.sigmas = ((1 - self.alphas_cumprod) / self.alphas_cumprod) ** 0.5 + + # setable values + self.num_inference_steps = None + self.timesteps = np.arange(0, num_train_timesteps)[::-1].copy() + self.derivatives = [] + + self.tensor_format = tensor_format + self.set_format(tensor_format=tensor_format) + + def get_lms_coefficient(self, order, t, current_order): + """ + Compute a linear multistep coefficient. + + Args: + order (TODO): + t (TODO): + current_order (TODO): + """ + + def lms_derivative(tau): + prod = 1.0 + for k in range(order): + if current_order == k: + continue + prod *= (tau - self.sigmas[t - k]) / (self.sigmas[t - current_order] - self.sigmas[t - k]) + return prod + + integrated_coeff = integrate.quad(lms_derivative, self.sigmas[t], self.sigmas[t + 1], epsrel=1e-4)[0] + + return integrated_coeff + + def set_timesteps(self, num_inference_steps: int): + """ + Sets the timesteps used for the diffusion chain. Supporting function to be run before inference. + + Args: + num_inference_steps (`int`): + the number of diffusion steps used when generating samples with a pre-trained model. + """ + self.num_inference_steps = num_inference_steps + self.timesteps = np.linspace(self.num_train_timesteps - 1, 0, num_inference_steps, dtype=float) + + low_idx = np.floor(self.timesteps).astype(int) + high_idx = np.ceil(self.timesteps).astype(int) + frac = np.mod(self.timesteps, 1.0) + sigmas = np.array(((1 - self.alphas_cumprod) / self.alphas_cumprod) ** 0.5) + sigmas = (1 - frac) * sigmas[low_idx] + frac * sigmas[high_idx] + self.sigmas = np.concatenate([sigmas, [0.0]]) + + self.derivatives = [] + + self.set_format(tensor_format=self.tensor_format) + + def step( + self, + model_output: Union[torch.FloatTensor, np.ndarray], + timestep: int, + sample: Union[torch.FloatTensor, np.ndarray], + order: int = 4, + return_dict: bool = True, + ) -> Union[SchedulerOutput, Tuple]: + """ + Predict the sample at the previous timestep by reversing the SDE. Core function to propagate the diffusion + process from the learned model outputs (most often the predicted noise). + + Args: + model_output (`torch.FloatTensor` or `np.ndarray`): direct output from learned diffusion model. + timestep (`int`): current discrete timestep in the diffusion chain. + sample (`torch.FloatTensor` or `np.ndarray`): + current instance of sample being created by diffusion process. + order: coefficient for multi-step inference. 
+ return_dict (`bool`): option for returning tuple rather than SchedulerOutput class + + Returns: + [`~schedulers.scheduling_utils.SchedulerOutput`] or `tuple`: + [`~schedulers.scheduling_utils.SchedulerOutput`] if `return_dict` is True, otherwise a `tuple`. When + returning a tuple, the first element is the sample tensor. + + """ + sigma = self.sigmas[timestep] + + # 1. compute predicted original sample (x_0) from sigma-scaled predicted noise + pred_original_sample = sample - sigma * model_output + + # 2. Convert to an ODE derivative + derivative = (sample - pred_original_sample) / sigma + self.derivatives.append(derivative) + if len(self.derivatives) > order: + self.derivatives.pop(0) + + # 3. Compute linear multistep coefficients + order = min(timestep + 1, order) + lms_coeffs = [self.get_lms_coefficient(order, timestep, curr_order) for curr_order in range(order)] + + # 4. Compute previous sample based on the derivatives path + prev_sample = sample + sum( + coeff * derivative for coeff, derivative in zip(lms_coeffs, reversed(self.derivatives)) + ) + + if not return_dict: + return (prev_sample,) + + return SchedulerOutput(prev_sample=prev_sample) + + def add_noise( + self, + original_samples: Union[torch.FloatTensor, np.ndarray], + noise: Union[torch.FloatTensor, np.ndarray], + timesteps: Union[torch.IntTensor, np.ndarray], + ) -> Union[torch.FloatTensor, np.ndarray]: + sigmas = self.match_shape(self.sigmas[timesteps], noise) + noisy_samples = original_samples + noise * sigmas + + return noisy_samples + + def __len__(self): + return self.config.num_train_timesteps diff --git a/diffusers/schedulers/scheduling_pndm.py b/diffusers/schedulers/scheduling_pndm.py new file mode 100644 index 000000000..b43d88bba --- /dev/null +++ b/diffusers/schedulers/scheduling_pndm.py @@ -0,0 +1,378 @@ +# Copyright 2022 Zhejiang University Team and The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# DISCLAIMER: This file is strongly influenced by https://github.com/ermongroup/ddim + +import math +from typing import Optional, Tuple, Union + +import numpy as np +import torch + +from ..configuration_utils import ConfigMixin, register_to_config +from .scheduling_utils import SchedulerMixin, SchedulerOutput + + +def betas_for_alpha_bar(num_diffusion_timesteps, max_beta=0.999): + """ + Create a beta schedule that discretizes the given alpha_t_bar function, which defines the cumulative product of + (1-beta) over time from t = [0,1]. + + Contains a function alpha_bar that takes an argument t and transforms it to the cumulative product of (1-beta) up + to that part of the diffusion process. + + + Args: + num_diffusion_timesteps (`int`): the number of betas to produce. + max_beta (`float`): the maximum beta to use; use values lower than 1 to + prevent singularities. 
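The LMS update above weights the stored derivatives by coefficients obtained from integrating Lagrange basis polynomials over [sigma_t, sigma_{t+1}]; `get_lms_coefficient` does this with `scipy.integrate.quad`. A standalone sketch with a toy sigma schedule and made-up derivatives (editor's illustration only):

```python
import numpy as np
from scipy import integrate

sigmas = np.array([14.6, 9.9, 6.5, 4.0, 2.2, 1.0, 0.0])  # toy decreasing noise schedule
derivatives = [0.5, 0.4, 0.3]                             # most recent last, as in self.derivatives

def lms_coefficient(order, t, current_order):
    # integral over [sigma_t, sigma_{t+1}] of the Lagrange basis polynomial for `current_order`
    def lms_derivative(tau):
        prod = 1.0
        for k in range(order):
            if current_order == k:
                continue
            prod *= (tau - sigmas[t - k]) / (sigmas[t - current_order] - sigmas[t - k])
        return prod
    return integrate.quad(lms_derivative, sigmas[t], sigmas[t + 1], epsrel=1e-4)[0]

t, order = 3, 3
coeffs = [lms_coefficient(order, t, k) for k in range(order)]
sample = 1.0
prev_sample = sample + sum(c * d for c, d in zip(coeffs, reversed(derivatives)))
print(coeffs, prev_sample)
```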
+ + Returns: + betas (`np.ndarray`): the betas used by the scheduler to step the model outputs + """ + + def alpha_bar(time_step): + return math.cos((time_step + 0.008) / 1.008 * math.pi / 2) ** 2 + + betas = [] + for i in range(num_diffusion_timesteps): + t1 = i / num_diffusion_timesteps + t2 = (i + 1) / num_diffusion_timesteps + betas.append(min(1 - alpha_bar(t2) / alpha_bar(t1), max_beta)) + return np.array(betas, dtype=np.float32) + + +class PNDMScheduler(SchedulerMixin, ConfigMixin): + """ + Pseudo numerical methods for diffusion models (PNDM) proposes using more advanced ODE integration techniques, + namely Runge-Kutta method and a linear multi-step method. + + [`~ConfigMixin`] takes care of storing all config attributes that are passed in the scheduler's `__init__` + function, such as `num_train_timesteps`. They can be accessed via `scheduler.config.num_train_timesteps`. + [`~ConfigMixin`] also provides general loading and saving functionality via the [`~ConfigMixin.save_config`] and + [`~ConfigMixin.from_config`] functios. + + For more details, see the original paper: https://arxiv.org/abs/2202.09778 + + Args: + num_train_timesteps (`int`): number of diffusion steps used to train the model. + beta_start (`float`): the starting `beta` value of inference. + beta_end (`float`): the final `beta` value. + beta_schedule (`str`): + the beta schedule, a mapping from a beta range to a sequence of betas for stepping the model. Choose from + `linear`, `scaled_linear`, or `squaredcos_cap_v2`. + trained_betas (`np.ndarray`, optional): TODO + tensor_format (`str`): whether the scheduler expects pytorch or numpy arrays + skip_prk_steps (`bool`): + allows the scheduler to skip the Runge-Kutta steps that are defined in the original paper as being required + before plms steps; defaults to `False`. + + """ + + @register_to_config + def __init__( + self, + num_train_timesteps: int = 1000, + beta_start: float = 0.0001, + beta_end: float = 0.02, + beta_schedule: str = "linear", + trained_betas: Optional[np.ndarray] = None, + tensor_format: str = "pt", + skip_prk_steps: bool = False, + ): + if trained_betas is not None: + self.betas = np.asarray(trained_betas) + if beta_schedule == "linear": + self.betas = np.linspace(beta_start, beta_end, num_train_timesteps, dtype=np.float32) + elif beta_schedule == "scaled_linear": + # this schedule is very specific to the latent diffusion model. + self.betas = np.linspace(beta_start**0.5, beta_end**0.5, num_train_timesteps, dtype=np.float32) ** 2 + elif beta_schedule == "squaredcos_cap_v2": + # Glide cosine schedule + self.betas = betas_for_alpha_bar(num_train_timesteps) + else: + raise NotImplementedError(f"{beta_schedule} does is not implemented for {self.__class__}") + + self.alphas = 1.0 - self.betas + self.alphas_cumprod = np.cumprod(self.alphas, axis=0) + + self.one = np.array(1.0) + + # For now we only support F-PNDM, i.e. the runge-kutta method + # For more information on the algorithm please take a look at the paper: https://arxiv.org/pdf/2202.09778.pdf + # mainly at formula (9), (12), (13) and the Algorithm 2. 
+ self.pndm_order = 4 + + # running values + self.cur_model_output = 0 + self.counter = 0 + self.cur_sample = None + self.ets = [] + + # setable values + self.num_inference_steps = None + self._timesteps = np.arange(0, num_train_timesteps)[::-1].copy() + self._offset = 0 + self.prk_timesteps = None + self.plms_timesteps = None + self.timesteps = None + + self.tensor_format = tensor_format + self.set_format(tensor_format=tensor_format) + + def set_timesteps(self, num_inference_steps: int, offset: int = 0) -> torch.FloatTensor: + """ + Sets the discrete timesteps used for the diffusion chain. Supporting function to be run before inference. + + Args: + num_inference_steps (`int`): + the number of diffusion steps used when generating samples with a pre-trained model. + offset (`int`): TODO + """ + self.num_inference_steps = num_inference_steps + self._timesteps = list( + range(0, self.config.num_train_timesteps, self.config.num_train_timesteps // num_inference_steps) + ) + self._offset = offset + self._timesteps = np.array([t + self._offset for t in self._timesteps]) + + if self.config.skip_prk_steps: + # for some models like stable diffusion the prk steps can/should be skipped to + # produce better results. When using PNDM with `self.config.skip_prk_steps` the implementation + # is based on crowsonkb's PLMS sampler implementation: https://github.com/CompVis/latent-diffusion/pull/51 + self.prk_timesteps = np.array([]) + self.plms_timesteps = np.concatenate([self._timesteps[:-1], self._timesteps[-2:-1], self._timesteps[-1:]])[ + ::-1 + ].copy() + else: + prk_timesteps = np.array(self._timesteps[-self.pndm_order :]).repeat(2) + np.tile( + np.array([0, self.config.num_train_timesteps // num_inference_steps // 2]), self.pndm_order + ) + self.prk_timesteps = (prk_timesteps[:-1].repeat(2)[1:-1])[::-1].copy() + self.plms_timesteps = self._timesteps[:-3][ + ::-1 + ].copy() # we copy to avoid having negative strides which are not supported by torch.from_numpy + + self.timesteps = np.concatenate([self.prk_timesteps, self.plms_timesteps]).astype(np.int64) + + self.ets = [] + self.counter = 0 + self.set_format(tensor_format=self.tensor_format) + + def step( + self, + model_output: Union[torch.FloatTensor, np.ndarray], + timestep: int, + sample: Union[torch.FloatTensor, np.ndarray], + return_dict: bool = True, + ) -> Union[SchedulerOutput, Tuple]: + """ + Predict the sample at the previous timestep by reversing the SDE. Core function to propagate the diffusion + process from the learned model outputs (most often the predicted noise). + + This function calls `step_prk()` or `step_plms()` depending on the internal variable `counter`. + + Args: + model_output (`torch.FloatTensor` or `np.ndarray`): direct output from learned diffusion model. + timestep (`int`): current discrete timestep in the diffusion chain. + sample (`torch.FloatTensor` or `np.ndarray`): + current instance of sample being created by diffusion process. + return_dict (`bool`): option for returning tuple rather than SchedulerOutput class + + Returns: + [`~schedulers.scheduling_utils.SchedulerOutput`] or `tuple`: + [`~schedulers.scheduling_utils.SchedulerOutput`] if `return_dict` is True, otherwise a `tuple`. When + returning a tuple, the first element is the sample tensor. 
+ + """ + if self.counter < len(self.prk_timesteps) and not self.config.skip_prk_steps: + return self.step_prk(model_output=model_output, timestep=timestep, sample=sample, return_dict=return_dict) + else: + return self.step_plms(model_output=model_output, timestep=timestep, sample=sample, return_dict=return_dict) + + def step_prk( + self, + model_output: Union[torch.FloatTensor, np.ndarray], + timestep: int, + sample: Union[torch.FloatTensor, np.ndarray], + return_dict: bool = True, + ) -> Union[SchedulerOutput, Tuple]: + """ + Step function propagating the sample with the Runge-Kutta method. RK takes 4 forward passes to approximate the + solution to the differential equation. + + Args: + model_output (`torch.FloatTensor` or `np.ndarray`): direct output from learned diffusion model. + timestep (`int`): current discrete timestep in the diffusion chain. + sample (`torch.FloatTensor` or `np.ndarray`): + current instance of sample being created by diffusion process. + return_dict (`bool`): option for returning tuple rather than SchedulerOutput class + + Returns: + [`~scheduling_utils.SchedulerOutput`] or `tuple`: [`~scheduling_utils.SchedulerOutput`] if `return_dict` is + True, otherwise a `tuple`. When returning a tuple, the first element is the sample tensor. + + """ + if self.num_inference_steps is None: + raise ValueError( + "Number of inference steps is 'None', you need to run 'set_timesteps' after creating the scheduler" + ) + + diff_to_prev = 0 if self.counter % 2 else self.config.num_train_timesteps // self.num_inference_steps // 2 + prev_timestep = max(timestep - diff_to_prev, self.prk_timesteps[-1]) + timestep = self.prk_timesteps[self.counter // 4 * 4] + + if self.counter % 4 == 0: + self.cur_model_output += 1 / 6 * model_output + self.ets.append(model_output) + self.cur_sample = sample + elif (self.counter - 1) % 4 == 0: + self.cur_model_output += 1 / 3 * model_output + elif (self.counter - 2) % 4 == 0: + self.cur_model_output += 1 / 3 * model_output + elif (self.counter - 3) % 4 == 0: + model_output = self.cur_model_output + 1 / 6 * model_output + self.cur_model_output = 0 + + # cur_sample should not be `None` + cur_sample = self.cur_sample if self.cur_sample is not None else sample + + prev_sample = self._get_prev_sample(cur_sample, timestep, prev_timestep, model_output) + self.counter += 1 + + if not return_dict: + return (prev_sample,) + + return SchedulerOutput(prev_sample=prev_sample) + + def step_plms( + self, + model_output: Union[torch.FloatTensor, np.ndarray], + timestep: int, + sample: Union[torch.FloatTensor, np.ndarray], + return_dict: bool = True, + ) -> Union[SchedulerOutput, Tuple]: + """ + Step function propagating the sample with the linear multi-step method. This has one forward pass with multiple + times to approximate the solution. + + Args: + model_output (`torch.FloatTensor` or `np.ndarray`): direct output from learned diffusion model. + timestep (`int`): current discrete timestep in the diffusion chain. + sample (`torch.FloatTensor` or `np.ndarray`): + current instance of sample being created by diffusion process. + return_dict (`bool`): option for returning tuple rather than SchedulerOutput class + + Returns: + [`~scheduling_utils.SchedulerOutput`] or `tuple`: [`~scheduling_utils.SchedulerOutput`] if `return_dict` is + True, otherwise a `tuple`. When returning a tuple, the first element is the sample tensor. 
+ + """ + if self.num_inference_steps is None: + raise ValueError( + "Number of inference steps is 'None', you need to run 'set_timesteps' after creating the scheduler" + ) + + if not self.config.skip_prk_steps and len(self.ets) < 3: + raise ValueError( + f"{self.__class__} can only be run AFTER scheduler has been run " + "in 'prk' mode for at least 12 iterations " + "See: https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/pipeline_pndm.py " + "for more information." + ) + + prev_timestep = max(timestep - self.config.num_train_timesteps // self.num_inference_steps, 0) + + if self.counter != 1: + self.ets.append(model_output) + else: + prev_timestep = timestep + timestep = timestep + self.config.num_train_timesteps // self.num_inference_steps + + if len(self.ets) == 1 and self.counter == 0: + model_output = model_output + self.cur_sample = sample + elif len(self.ets) == 1 and self.counter == 1: + model_output = (model_output + self.ets[-1]) / 2 + sample = self.cur_sample + self.cur_sample = None + elif len(self.ets) == 2: + model_output = (3 * self.ets[-1] - self.ets[-2]) / 2 + elif len(self.ets) == 3: + model_output = (23 * self.ets[-1] - 16 * self.ets[-2] + 5 * self.ets[-3]) / 12 + else: + model_output = (1 / 24) * (55 * self.ets[-1] - 59 * self.ets[-2] + 37 * self.ets[-3] - 9 * self.ets[-4]) + + prev_sample = self._get_prev_sample(sample, timestep, prev_timestep, model_output) + self.counter += 1 + + if not return_dict: + return (prev_sample,) + + return SchedulerOutput(prev_sample=prev_sample) + + def _get_prev_sample(self, sample, timestep, timestep_prev, model_output): + # See formula (9) of PNDM paper https://arxiv.org/pdf/2202.09778.pdf + # this function computes x_(t−δ) using the formula of (9) + # Note that x_t needs to be added to both sides of the equation + + # Notation ( -> + # alpha_prod_t -> α_t + # alpha_prod_t_prev -> α_(t−δ) + # beta_prod_t -> (1 - α_t) + # beta_prod_t_prev -> (1 - α_(t−δ)) + # sample -> x_t + # model_output -> e_θ(x_t, t) + # prev_sample -> x_(t−δ) + alpha_prod_t = self.alphas_cumprod[timestep + 1 - self._offset] + alpha_prod_t_prev = self.alphas_cumprod[timestep_prev + 1 - self._offset] + beta_prod_t = 1 - alpha_prod_t + beta_prod_t_prev = 1 - alpha_prod_t_prev + + # corresponds to (α_(t−δ) - α_t) divided by + # denominator of x_t in formula (9) and plus 1 + # Note: (α_(t−δ) - α_t) / (sqrt(α_t) * (sqrt(α_(t−δ)) + sqr(α_t))) = + # sqrt(α_(t−δ)) / sqrt(α_t)) + sample_coeff = (alpha_prod_t_prev / alpha_prod_t) ** (0.5) + + # corresponds to denominator of e_θ(x_t, t) in formula (9) + model_output_denom_coeff = alpha_prod_t * beta_prod_t_prev ** (0.5) + ( + alpha_prod_t * beta_prod_t * alpha_prod_t_prev + ) ** (0.5) + + # full formula (9) + prev_sample = ( + sample_coeff * sample - (alpha_prod_t_prev - alpha_prod_t) * model_output / model_output_denom_coeff + ) + + return prev_sample + + def add_noise( + self, + original_samples: Union[torch.FloatTensor, np.ndarray], + noise: Union[torch.FloatTensor, np.ndarray], + timesteps: Union[torch.IntTensor, np.ndarray], + ) -> torch.Tensor: + # mps requires indices to be in the same device, so we use cpu as is the default with cuda + timesteps = timesteps.to(self.alphas_cumprod.device) + sqrt_alpha_prod = self.alphas_cumprod[timesteps] ** 0.5 + sqrt_alpha_prod = self.match_shape(sqrt_alpha_prod, original_samples) + sqrt_one_minus_alpha_prod = (1 - self.alphas_cumprod[timesteps]) ** 0.5 + sqrt_one_minus_alpha_prod = self.match_shape(sqrt_one_minus_alpha_prod, original_samples) + + noisy_samples 
= sqrt_alpha_prod * original_samples + sqrt_one_minus_alpha_prod * noise + return noisy_samples + + def __len__(self): + return self.config.num_train_timesteps diff --git a/diffusers/schedulers/scheduling_sde_ve.py b/diffusers/schedulers/scheduling_sde_ve.py new file mode 100644 index 000000000..e187f0796 --- /dev/null +++ b/diffusers/schedulers/scheduling_sde_ve.py @@ -0,0 +1,283 @@ +# Copyright 2022 Google Brain and The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# DISCLAIMER: This file is strongly influenced by https://github.com/yang-song/score_sde_pytorch + +import warnings +from dataclasses import dataclass +from typing import Optional, Tuple, Union + +import numpy as np +import torch + +from ..configuration_utils import ConfigMixin, register_to_config +from ..utils import BaseOutput +from .scheduling_utils import SchedulerMixin, SchedulerOutput + + +@dataclass +class SdeVeOutput(BaseOutput): + """ + Output class for the ScoreSdeVeScheduler's step function output. + + Args: + prev_sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` for images): + Computed sample (x_{t-1}) of previous timestep. `prev_sample` should be used as next model input in the + denoising loop. + prev_sample_mean (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` for images): + Mean averaged `prev_sample`. Same as `prev_sample`, only mean-averaged over previous timesteps. + """ + + prev_sample: torch.FloatTensor + prev_sample_mean: torch.FloatTensor + + +class ScoreSdeVeScheduler(SchedulerMixin, ConfigMixin): + """ + The variance exploding stochastic differential equation (SDE) scheduler. + + For more information, see the original paper: https://arxiv.org/abs/2011.13456 + + [`~ConfigMixin`] takes care of storing all config attributes that are passed in the scheduler's `__init__` + function, such as `num_train_timesteps`. They can be accessed via `scheduler.config.num_train_timesteps`. + [`~ConfigMixin`] also provides general loading and saving functionality via the [`~ConfigMixin.save_config`] and + [`~ConfigMixin.from_config`] functios. + + Args: + snr (`float`): + coefficient weighting the step from the model_output sample (from the network) to the random noise. + sigma_min (`float`): + initial noise scale for sigma sequence in sampling procedure. The minimum sigma should mirror the + distribution of the data. + sigma_max (`float`): maximum value used for the range of continuous timesteps passed into the model. + sampling_eps (`float`): the end value of sampling, where timesteps decrease progessively from 1 to + epsilon. + correct_steps (`int`): number of correction steps performed on a produced sample. + tensor_format (`str`): "np" or "pt" for the expected format of samples passed to the Scheduler. 
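In the PLMS branch above, once four noise predictions have been stored they are blended with the 4th-order Adams-Bashforth weights (55, -59, 37, -9)/24 and passed through `_get_prev_sample` (formula (9)). A scalar illustration with `offset = 0` and made-up predictions (editor's sketch, not library code):

```python
import numpy as np

betas = np.linspace(1e-4, 0.02, 1000)
alphas_cumprod = np.cumprod(1.0 - betas)

ets = [0.12, 0.10, 0.09, 0.08]  # last four noise predictions, oldest first
model_output = (55 * ets[-1] - 59 * ets[-2] + 37 * ets[-3] - 9 * ets[-4]) / 24

def get_prev_sample(sample, timestep, timestep_prev, model_output):
    # formula (9) of https://arxiv.org/pdf/2202.09778.pdf, as in _get_prev_sample above (offset = 0)
    alpha_prod_t = alphas_cumprod[timestep + 1]
    alpha_prod_t_prev = alphas_cumprod[timestep_prev + 1]
    beta_prod_t = 1 - alpha_prod_t
    beta_prod_t_prev = 1 - alpha_prod_t_prev
    sample_coeff = (alpha_prod_t_prev / alpha_prod_t) ** 0.5
    denom = alpha_prod_t * beta_prod_t_prev**0.5 + (alpha_prod_t * beta_prod_t * alpha_prod_t_prev) ** 0.5
    return sample_coeff * sample - (alpha_prod_t_prev - alpha_prod_t) * model_output / denom

print(get_prev_sample(0.3, 500, 480, model_output))
```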
+ """ + + @register_to_config + def __init__( + self, + num_train_timesteps: int = 2000, + snr: float = 0.15, + sigma_min: float = 0.01, + sigma_max: float = 1348.0, + sampling_eps: float = 1e-5, + correct_steps: int = 1, + tensor_format: str = "pt", + ): + # setable values + self.timesteps = None + + self.set_sigmas(num_train_timesteps, sigma_min, sigma_max, sampling_eps) + + self.tensor_format = tensor_format + self.set_format(tensor_format=tensor_format) + + def set_timesteps(self, num_inference_steps: int, sampling_eps: float = None): + """ + Sets the continuous timesteps used for the diffusion chain. Supporting function to be run before inference. + + Args: + num_inference_steps (`int`): + the number of diffusion steps used when generating samples with a pre-trained model. + sampling_eps (`float`, optional): final timestep value (overrides value given at Scheduler instantiation). + + """ + sampling_eps = sampling_eps if sampling_eps is not None else self.config.sampling_eps + tensor_format = getattr(self, "tensor_format", "pt") + if tensor_format == "np": + self.timesteps = np.linspace(1, sampling_eps, num_inference_steps) + elif tensor_format == "pt": + self.timesteps = torch.linspace(1, sampling_eps, num_inference_steps) + else: + raise ValueError(f"`self.tensor_format`: {self.tensor_format} is not valid.") + + def set_sigmas( + self, num_inference_steps: int, sigma_min: float = None, sigma_max: float = None, sampling_eps: float = None + ): + """ + Sets the noise scales used for the diffusion chain. Supporting function to be run before inference. + + The sigmas control the weight of the `drift` and `diffusion` components of sample update. + + Args: + num_inference_steps (`int`): + the number of diffusion steps used when generating samples with a pre-trained model. + sigma_min (`float`, optional): + initial noise scale value (overrides value given at Scheduler instantiation). + sigma_max (`float`, optional): final noise scale value (overrides value given at Scheduler instantiation). + sampling_eps (`float`, optional): final timestep value (overrides value given at Scheduler instantiation). 
+ + """ + sigma_min = sigma_min if sigma_min is not None else self.config.sigma_min + sigma_max = sigma_max if sigma_max is not None else self.config.sigma_max + sampling_eps = sampling_eps if sampling_eps is not None else self.config.sampling_eps + if self.timesteps is None: + self.set_timesteps(num_inference_steps, sampling_eps) + + tensor_format = getattr(self, "tensor_format", "pt") + if tensor_format == "np": + self.discrete_sigmas = np.exp(np.linspace(np.log(sigma_min), np.log(sigma_max), num_inference_steps)) + self.sigmas = np.array([sigma_min * (sigma_max / sigma_min) ** t for t in self.timesteps]) + elif tensor_format == "pt": + self.discrete_sigmas = torch.exp(torch.linspace(np.log(sigma_min), np.log(sigma_max), num_inference_steps)) + self.sigmas = torch.tensor([sigma_min * (sigma_max / sigma_min) ** t for t in self.timesteps]) + else: + raise ValueError(f"`self.tensor_format`: {self.tensor_format} is not valid.") + + def get_adjacent_sigma(self, timesteps, t): + tensor_format = getattr(self, "tensor_format", "pt") + if tensor_format == "np": + return np.where(timesteps == 0, np.zeros_like(t), self.discrete_sigmas[timesteps - 1]) + elif tensor_format == "pt": + return torch.where( + timesteps == 0, + torch.zeros_like(t.to(timesteps.device)), + self.discrete_sigmas[timesteps - 1].to(timesteps.device), + ) + + raise ValueError(f"`self.tensor_format`: {self.tensor_format} is not valid.") + + def set_seed(self, seed): + warnings.warn( + "The method `set_seed` is deprecated and will be removed in version `0.4.0`. Please consider passing a" + " generator instead.", + DeprecationWarning, + ) + tensor_format = getattr(self, "tensor_format", "pt") + if tensor_format == "np": + np.random.seed(seed) + elif tensor_format == "pt": + torch.manual_seed(seed) + else: + raise ValueError(f"`self.tensor_format`: {self.tensor_format} is not valid.") + + def step_pred( + self, + model_output: Union[torch.FloatTensor, np.ndarray], + timestep: int, + sample: Union[torch.FloatTensor, np.ndarray], + generator: Optional[torch.Generator] = None, + return_dict: bool = True, + **kwargs, + ) -> Union[SdeVeOutput, Tuple]: + """ + Predict the sample at the previous timestep by reversing the SDE. Core function to propagate the diffusion + process from the learned model outputs (most often the predicted noise). + + Args: + model_output (`torch.FloatTensor` or `np.ndarray`): direct output from learned diffusion model. + timestep (`int`): current discrete timestep in the diffusion chain. + sample (`torch.FloatTensor` or `np.ndarray`): + current instance of sample being created by diffusion process. + generator: random number generator. + return_dict (`bool`): option for returning tuple rather than SchedulerOutput class + + Returns: + [`~schedulers.scheduling_sde_ve.SdeVeOutput`] or `tuple`: [`~schedulers.scheduling_sde_ve.SdeVeOutput`] if + `return_dict` is True, otherwise a `tuple`. When returning a tuple, the first element is the sample tensor. 
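# A minimal NumPy sketch (assuming the defaults sigma_min=0.01 and sigma_max=1348.0)
# of the noise scales computed by `set_sigmas` above: `discrete_sigmas` is log-uniform
# between the two extremes, and `sigmas` evaluates the geometric progression
# sigma_min * (sigma_max / sigma_min) ** t at each continuous timestep t.
import numpy as np

sigma_min, sigma_max, n = 0.01, 1348.0, 4
t = np.linspace(1, 1e-5, n)
discrete_sigmas = np.exp(np.linspace(np.log(sigma_min), np.log(sigma_max), n))
sigmas = sigma_min * (sigma_max / sigma_min) ** t
assert np.isclose(discrete_sigmas[0], sigma_min) and np.isclose(discrete_sigmas[-1], sigma_max)
assert np.isclose(sigmas[0], sigma_max)  # t=1 corresponds to the largest noise scale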
+ + """ + if "seed" in kwargs and kwargs["seed"] is not None: + self.set_seed(kwargs["seed"]) + + if self.timesteps is None: + raise ValueError( + "`self.timesteps` is not set, you need to run 'set_timesteps' after creating the scheduler" + ) + + timestep = timestep * torch.ones( + sample.shape[0], device=sample.device + ) # torch.repeat_interleave(timestep, sample.shape[0]) + timesteps = (timestep * (len(self.timesteps) - 1)).long() + + # mps requires indices to be in the same device, so we use cpu as is the default with cuda + timesteps = timesteps.to(self.discrete_sigmas.device) + + sigma = self.discrete_sigmas[timesteps].to(sample.device) + adjacent_sigma = self.get_adjacent_sigma(timesteps, timestep).to(sample.device) + drift = self.zeros_like(sample) + diffusion = (sigma**2 - adjacent_sigma**2) ** 0.5 + + # equation 6 in the paper: the model_output modeled by the network is grad_x log pt(x) + # also equation 47 shows the analog from SDE models to ancestral sampling methods + drift = drift - diffusion[:, None, None, None] ** 2 * model_output + + # equation 6: sample noise for the diffusion term of + noise = self.randn_like(sample, generator=generator) + prev_sample_mean = sample - drift # subtract because `dt` is a small negative timestep + # TODO is the variable diffusion the correct scaling term for the noise? + prev_sample = prev_sample_mean + diffusion[:, None, None, None] * noise # add impact of diffusion field g + + if not return_dict: + return (prev_sample, prev_sample_mean) + + return SdeVeOutput(prev_sample=prev_sample, prev_sample_mean=prev_sample_mean) + + def step_correct( + self, + model_output: Union[torch.FloatTensor, np.ndarray], + sample: Union[torch.FloatTensor, np.ndarray], + generator: Optional[torch.Generator] = None, + return_dict: bool = True, + **kwargs, + ) -> Union[SchedulerOutput, Tuple]: + """ + Correct the predicted sample based on the output model_output of the network. This is often run repeatedly + after making the prediction for the previous timestep. + + Args: + model_output (`torch.FloatTensor` or `np.ndarray`): direct output from learned diffusion model. + sample (`torch.FloatTensor` or `np.ndarray`): + current instance of sample being created by diffusion process. + generator: random number generator. + return_dict (`bool`): option for returning tuple rather than SchedulerOutput class + + Returns: + [`~schedulers.scheduling_sde_ve.SdeVeOutput`] or `tuple`: [`~schedulers.scheduling_sde_ve.SdeVeOutput`] if + `return_dict` is True, otherwise a `tuple`. When returning a tuple, the first element is the sample tensor. + + """ + if "seed" in kwargs and kwargs["seed"] is not None: + self.set_seed(kwargs["seed"]) + + if self.timesteps is None: + raise ValueError( + "`self.timesteps` is not set, you need to run 'set_timesteps' after creating the scheduler" + ) + + # For small batch sizes, the paper "suggest replacing norm(z) with sqrt(d), where d is the dim. 
of z" + # sample noise for correction + noise = self.randn_like(sample, generator=generator) + + # compute step size from the model_output, the noise, and the snr + grad_norm = self.norm(model_output) + noise_norm = self.norm(noise) + step_size = (self.config.snr * noise_norm / grad_norm) ** 2 * 2 + step_size = step_size * torch.ones(sample.shape[0]).to(sample.device) + # self.repeat_scalar(step_size, sample.shape[0]) + + # compute corrected sample: model_output term and noise term + prev_sample_mean = sample + step_size[:, None, None, None] * model_output + prev_sample = prev_sample_mean + ((step_size * 2) ** 0.5)[:, None, None, None] * noise + + if not return_dict: + return (prev_sample,) + + return SchedulerOutput(prev_sample=prev_sample) + + def __len__(self): + return self.config.num_train_timesteps diff --git a/diffusers/schedulers/scheduling_sde_vp.py b/diffusers/schedulers/scheduling_sde_vp.py new file mode 100644 index 000000000..66e6ec661 --- /dev/null +++ b/diffusers/schedulers/scheduling_sde_vp.py @@ -0,0 +1,81 @@ +# Copyright 2022 Google Brain and The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# DISCLAIMER: This file is strongly influenced by https://github.com/yang-song/score_sde_pytorch + +# TODO(Patrick, Anton, Suraj) - make scheduler framework indepedent and clean-up a bit + +import numpy as np +import torch + +from ..configuration_utils import ConfigMixin, register_to_config +from .scheduling_utils import SchedulerMixin + + +class ScoreSdeVpScheduler(SchedulerMixin, ConfigMixin): + """ + The variance preserving stochastic differential equation (SDE) scheduler. + + [`~ConfigMixin`] takes care of storing all config attributes that are passed in the scheduler's `__init__` + function, such as `num_train_timesteps`. They can be accessed via `scheduler.config.num_train_timesteps`. + [`~ConfigMixin`] also provides general loading and saving functionality via the [`~ConfigMixin.save_config`] and + [`~ConfigMixin.from_config`] functios. 
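# A rough sketch of how `step_pred` and `step_correct` above are typically combined
# into a predictor-corrector sampling loop. `score_model` and the way it is called are
# hypothetical placeholders (not defined in this patch); the shipped pipelines wire the
# actual network in themselves.
import torch

from diffusers import ScoreSdeVeScheduler  # assuming the usual top-level export

scheduler = ScoreSdeVeScheduler()  # defaults: snr=0.15, correct_steps=1, ...
scheduler.set_timesteps(num_inference_steps=50)
scheduler.set_sigmas(num_inference_steps=50)

sample = torch.randn(1, 3, 64, 64) * scheduler.config.sigma_max
for i, t in enumerate(scheduler.timesteps):
    sigma_t = scheduler.sigmas[i] * torch.ones(sample.shape[0])
    # predictor: one reverse-SDE step
    score = score_model(sample, sigma_t)  # hypothetical score network
    sample = scheduler.step_pred(score, t, sample).prev_sample
    # corrector: `correct_steps` Langevin-style updates on the prediction
    for _ in range(scheduler.config.correct_steps):
        score = score_model(sample, sigma_t)
        sample = scheduler.step_correct(score, sample).prev_sample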
+ + For more information, see the original paper: https://arxiv.org/abs/2011.13456 + + UNDER CONSTRUCTION + + """ + + @register_to_config + def __init__(self, num_train_timesteps=2000, beta_min=0.1, beta_max=20, sampling_eps=1e-3, tensor_format="np"): + + self.sigmas = None + self.discrete_sigmas = None + self.timesteps = None + + def set_timesteps(self, num_inference_steps): + self.timesteps = torch.linspace(1, self.config.sampling_eps, num_inference_steps) + + def step_pred(self, score, x, t): + if self.timesteps is None: + raise ValueError( + "`self.timesteps` is not set, you need to run 'set_timesteps' after creating the scheduler" + ) + + # TODO(Patrick) better comments + non-PyTorch + # postprocess model score + log_mean_coeff = ( + -0.25 * t**2 * (self.config.beta_max - self.config.beta_min) - 0.5 * t * self.config.beta_min + ) + std = torch.sqrt(1.0 - torch.exp(2.0 * log_mean_coeff)) + score = -score / std[:, None, None, None] + + # compute + dt = -1.0 / len(self.timesteps) + + beta_t = self.config.beta_min + t * (self.config.beta_max - self.config.beta_min) + drift = -0.5 * beta_t[:, None, None, None] * x + diffusion = torch.sqrt(beta_t) + drift = drift - diffusion[:, None, None, None] ** 2 * score + x_mean = x + drift * dt + + # add noise + noise = torch.randn_like(x) + x = x_mean + diffusion[:, None, None, None] * np.sqrt(-dt) * noise + + return x, x_mean + + def __len__(self): + return self.config.num_train_timesteps diff --git a/diffusers/schedulers/scheduling_utils.py b/diffusers/schedulers/scheduling_utils.py new file mode 100644 index 000000000..f2bcd73ac --- /dev/null +++ b/diffusers/schedulers/scheduling_utils.py @@ -0,0 +1,125 @@ +# Copyright 2022 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from dataclasses import dataclass +from typing import Union + +import numpy as np +import torch + +from ..utils import BaseOutput + + +SCHEDULER_CONFIG_NAME = "scheduler_config.json" + + +@dataclass +class SchedulerOutput(BaseOutput): + """ + Base class for the scheduler's step function output. + + Args: + prev_sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` for images): + Computed sample (x_{t-1}) of previous timestep. `prev_sample` should be used as next model input in the + denoising loop. + """ + + prev_sample: torch.FloatTensor + + +class SchedulerMixin: + """ + Mixin containing common functions for the schedulers. 
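# A small numerical sketch of the VP marginal used by the `step_pred` above: with the
# default beta_min=0.1 and beta_max=20, the perturbation std is close to 0 near t=0
# and close to 1 at t=1, which is why the raw network output is rescaled by -1/std to
# recover the score.
import torch

beta_min, beta_max = 0.1, 20.0
t = torch.tensor([1e-3, 0.5, 1.0])
log_mean_coeff = -0.25 * t**2 * (beta_max - beta_min) - 0.5 * t * beta_min
std = torch.sqrt(1.0 - torch.exp(2.0 * log_mean_coeff))
# std is roughly 0.01, 0.96 and 1.00 for t = 1e-3, 0.5 and 1.0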
+ """ + + config_name = SCHEDULER_CONFIG_NAME + ignore_for_config = ["tensor_format"] + + def set_format(self, tensor_format="pt"): + self.tensor_format = tensor_format + if tensor_format == "pt": + for key, value in vars(self).items(): + if isinstance(value, np.ndarray): + setattr(self, key, torch.from_numpy(value)) + + return self + + def clip(self, tensor, min_value=None, max_value=None): + tensor_format = getattr(self, "tensor_format", "pt") + + if tensor_format == "np": + return np.clip(tensor, min_value, max_value) + elif tensor_format == "pt": + return torch.clamp(tensor, min_value, max_value) + + raise ValueError(f"`self.tensor_format`: {self.tensor_format} is not valid.") + + def log(self, tensor): + tensor_format = getattr(self, "tensor_format", "pt") + + if tensor_format == "np": + return np.log(tensor) + elif tensor_format == "pt": + return torch.log(tensor) + + raise ValueError(f"`self.tensor_format`: {self.tensor_format} is not valid.") + + def match_shape(self, values: Union[np.ndarray, torch.Tensor], broadcast_array: Union[np.ndarray, torch.Tensor]): + """ + Turns a 1-D array into an array or tensor with len(broadcast_array.shape) dims. + + Args: + values: an array or tensor of values to extract. + broadcast_array: an array with a larger shape of K dimensions with the batch + dimension equal to the length of timesteps. + Returns: + a tensor of shape [batch_size, 1, ...] where the shape has K dims. + """ + + tensor_format = getattr(self, "tensor_format", "pt") + values = values.flatten() + + while len(values.shape) < len(broadcast_array.shape): + values = values[..., None] + if tensor_format == "pt": + values = values.to(broadcast_array.device) + + return values + + def norm(self, tensor): + tensor_format = getattr(self, "tensor_format", "pt") + if tensor_format == "np": + return np.linalg.norm(tensor) + elif tensor_format == "pt": + return torch.norm(tensor.reshape(tensor.shape[0], -1), dim=-1).mean() + + raise ValueError(f"`self.tensor_format`: {self.tensor_format} is not valid.") + + def randn_like(self, tensor, generator=None): + tensor_format = getattr(self, "tensor_format", "pt") + if tensor_format == "np": + return np.random.randn(*np.shape(tensor)) + elif tensor_format == "pt": + # return torch.randn_like(tensor) + return torch.randn(tensor.shape, layout=tensor.layout, generator=generator).to(tensor.device) + + raise ValueError(f"`self.tensor_format`: {self.tensor_format} is not valid.") + + def zeros_like(self, tensor): + tensor_format = getattr(self, "tensor_format", "pt") + if tensor_format == "np": + return np.zeros_like(tensor) + elif tensor_format == "pt": + return torch.zeros_like(tensor) + + raise ValueError(f"`self.tensor_format`: {self.tensor_format} is not valid.") diff --git a/diffusers/testing_utils.py b/diffusers/testing_utils.py new file mode 100644 index 000000000..ff8b6aa9b --- /dev/null +++ b/diffusers/testing_utils.py @@ -0,0 +1,61 @@ +import os +import random +import unittest +from distutils.util import strtobool + +import torch + +from packaging import version + + +global_rng = random.Random() +torch_device = "cuda" if torch.cuda.is_available() else "cpu" +is_torch_higher_equal_than_1_12 = version.parse(version.parse(torch.__version__).base_version) >= version.parse("1.12") + +if is_torch_higher_equal_than_1_12: + torch_device = "mps" if torch.backends.mps.is_available() else torch_device + + +def parse_flag_from_env(key, default=False): + try: + value = os.environ[key] + except KeyError: + # KEY isn't set, default to `default`. 
+ _value = default + else: + # KEY is set, convert it to True or False. + try: + _value = strtobool(value) + except ValueError: + # More values are supported, but let's keep the message simple. + raise ValueError(f"If set, {key} must be yes or no.") + return _value + + +_run_slow_tests = parse_flag_from_env("RUN_SLOW", default=False) + + +def floats_tensor(shape, scale=1.0, rng=None, name=None): + """Creates a random float32 tensor""" + if rng is None: + rng = global_rng + + total_dims = 1 + for dim in shape: + total_dims *= dim + + values = [] + for _ in range(total_dims): + values.append(rng.random() * scale) + + return torch.tensor(data=values, dtype=torch.float).view(shape).contiguous() + + +def slow(test_case): + """ + Decorator marking a test as slow. + + Slow tests are skipped by default. Set the RUN_SLOW environment variable to a truthy value to run them. + + """ + return unittest.skipUnless(_run_slow_tests, "test is slow")(test_case) diff --git a/diffusers/training_utils.py b/diffusers/training_utils.py new file mode 100644 index 000000000..fa1694161 --- /dev/null +++ b/diffusers/training_utils.py @@ -0,0 +1,125 @@ +import copy +import os +import random + +import numpy as np +import torch + + +def enable_full_determinism(seed: int): + """ + Helper function for reproducible behavior during distributed training. See + - https://pytorch.org/docs/stable/notes/randomness.html for pytorch + """ + # set seed first + set_seed(seed) + + # Enable PyTorch deterministic mode. This potentially requires either the environment + # variable 'CUDA_LAUNCH_BLOCKING' or 'CUBLAS_WORKSPACE_CONFIG' to be set, + # depending on the CUDA version, so we set them both here + os.environ["CUDA_LAUNCH_BLOCKING"] = "1" + os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":16:8" + torch.use_deterministic_algorithms(True) + + # Enable CUDNN deterministic mode + torch.backends.cudnn.deterministic = True + torch.backends.cudnn.benchmark = False + + +def set_seed(seed: int): + """ + Args: + Helper function for reproducible behavior to set the seed in `random`, `numpy`, `torch`. + seed (`int`): The seed to set. + """ + random.seed(seed) + np.random.seed(seed) + torch.manual_seed(seed) + torch.cuda.manual_seed_all(seed) + # ^^ safe to call this function even if cuda is not available + + +class EMAModel: + """ + Exponential Moving Average of models weights + """ + + def __init__( + self, + model, + update_after_step=0, + inv_gamma=1.0, + power=2 / 3, + min_value=0.0, + max_value=0.9999, + device=None, + ): + """ + @crowsonkb's notes on EMA Warmup: + If gamma=1 and power=1, implements a simple average. gamma=1, power=2/3 are good values for models you plan + to train for a million or more steps (reaches decay factor 0.999 at 31.6K steps, 0.9999 at 1M steps), + gamma=1, power=3/4 for models you plan to train for less (reaches decay factor 0.999 at 10K steps, 0.9999 + at 215.4k steps). + Args: + inv_gamma (float): Inverse multiplicative factor of EMA warmup. Default: 1. + power (float): Exponential factor of EMA warmup. Default: 2/3. + min_value (float): The minimum EMA decay rate. Default: 0. 
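# A quick numerical check of the EMA warm-up note above: with the defaults inv_gamma=1.0
# and power=2/3, the decay 1 - (1 + step / inv_gamma) ** -power reaches ~0.999 around
# 31.6K steps and ~0.9999 around 1M steps, matching the docstring.
step = 31_600
decay = 1 - (1 + step / 1.0) ** -(2 / 3)  # ~0.999
decay_1m = 1 - (1 + 1_000_000 / 1.0) ** -(2 / 3)  # ~0.9999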
+ """ + + self.averaged_model = copy.deepcopy(model).eval() + self.averaged_model.requires_grad_(False) + + self.update_after_step = update_after_step + self.inv_gamma = inv_gamma + self.power = power + self.min_value = min_value + self.max_value = max_value + + if device is not None: + self.averaged_model = self.averaged_model.to(device=device) + + self.decay = 0.0 + self.optimization_step = 0 + + def get_decay(self, optimization_step): + """ + Compute the decay factor for the exponential moving average. + """ + step = max(0, optimization_step - self.update_after_step - 1) + value = 1 - (1 + step / self.inv_gamma) ** -self.power + + if step <= 0: + return 0.0 + + return max(self.min_value, min(value, self.max_value)) + + @torch.no_grad() + def step(self, new_model): + ema_state_dict = {} + ema_params = self.averaged_model.state_dict() + + self.decay = self.get_decay(self.optimization_step) + + for key, param in new_model.named_parameters(): + if isinstance(param, dict): + continue + try: + ema_param = ema_params[key] + except KeyError: + ema_param = param.float().clone() if param.ndim == 1 else copy.deepcopy(param) + ema_params[key] = ema_param + + if not param.requires_grad: + ema_params[key].copy_(param.to(dtype=ema_param.dtype).data) + ema_param = ema_params[key] + else: + ema_param.mul_(self.decay) + ema_param.add_(param.data.to(dtype=ema_param.dtype), alpha=1 - self.decay) + + ema_state_dict[key] = ema_param + + for key, param in new_model.named_buffers(): + ema_state_dict[key] = param + + self.averaged_model.load_state_dict(ema_state_dict, strict=False) + self.optimization_step += 1 diff --git a/diffusers/utils/__init__.py b/diffusers/utils/__init__.py new file mode 100644 index 000000000..c00a28e10 --- /dev/null +++ b/diffusers/utils/__init__.py @@ -0,0 +1,53 @@ +# Copyright 2022 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
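# A rough usage sketch for the EMAModel above inside a training loop. `model`,
# `dataloader`, `optimizer` and `batch_loss` are hypothetical placeholders; the point is
# only that `ema.step(model)` runs once per optimization step and that
# `ema.averaged_model` holds the shadow weights used for evaluation or sampling.
from diffusers.training_utils import EMAModel  # assuming this module path

ema = EMAModel(model, inv_gamma=1.0, power=2 / 3, max_value=0.9999)
for batch in dataloader:
    loss = batch_loss(model, batch)
    loss.backward()
    optimizer.step()
    optimizer.zero_grad()
    ema.step(model)  # update the exponential moving average of the weights
eval_model = ema.averaged_model  # EMA weights for evaluation / checkpointing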
+ + +import os + +from .import_utils import ( + ENV_VARS_TRUE_AND_AUTO_VALUES, + ENV_VARS_TRUE_VALUES, + USE_JAX, + USE_TF, + USE_TORCH, + DummyObject, + is_flax_available, + is_inflect_available, + is_modelcards_available, + is_onnx_available, + is_scipy_available, + is_tf_available, + is_torch_available, + is_transformers_available, + is_unidecode_available, + requires_backends, +) +from .logging import get_logger +from .outputs import BaseOutput + + +logger = get_logger(__name__) + + +hf_cache_home = os.path.expanduser( + os.getenv("HF_HOME", os.path.join(os.getenv("XDG_CACHE_HOME", "~/.cache"), "huggingface")) +) +default_cache_path = os.path.join(hf_cache_home, "diffusers") + + +CONFIG_NAME = "config.json" +HUGGINGFACE_CO_RESOLVE_ENDPOINT = "https://huggingface.co" +DIFFUSERS_CACHE = default_cache_path +DIFFUSERS_DYNAMIC_MODULE_NAME = "diffusers_modules" +HF_MODULES_CACHE = os.getenv("HF_MODULES_CACHE", os.path.join(hf_cache_home, "modules")) diff --git a/diffusers/utils/dummy_scipy_objects.py b/diffusers/utils/dummy_scipy_objects.py new file mode 100644 index 000000000..3706c5754 --- /dev/null +++ b/diffusers/utils/dummy_scipy_objects.py @@ -0,0 +1,11 @@ +# This file is autogenerated by the command `make fix-copies`, do not edit. +# flake8: noqa + +from ..utils import DummyObject, requires_backends + + +class LMSDiscreteScheduler(metaclass=DummyObject): + _backends = ["scipy"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["scipy"]) diff --git a/diffusers/utils/dummy_transformers_and_inflect_and_unidecode_objects.py b/diffusers/utils/dummy_transformers_and_inflect_and_unidecode_objects.py new file mode 100644 index 000000000..8c2aec218 --- /dev/null +++ b/diffusers/utils/dummy_transformers_and_inflect_and_unidecode_objects.py @@ -0,0 +1,10 @@ +# This file is autogenerated by the command `make fix-copies`, do not edit. +# flake8: noqa +from ..utils import DummyObject, requires_backends + + +class GradTTSPipeline(metaclass=DummyObject): + _backends = ["transformers", "inflect", "unidecode"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["transformers", "inflect", "unidecode"]) diff --git a/diffusers/utils/dummy_transformers_and_onnx_objects.py b/diffusers/utils/dummy_transformers_and_onnx_objects.py new file mode 100644 index 000000000..2e34b5ce0 --- /dev/null +++ b/diffusers/utils/dummy_transformers_and_onnx_objects.py @@ -0,0 +1,11 @@ +# This file is autogenerated by the command `make fix-copies`, do not edit. +# flake8: noqa + +from ..utils import DummyObject, requires_backends + + +class StableDiffusionOnnxPipeline(metaclass=DummyObject): + _backends = ["transformers", "onnx"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["transformers", "onnx"]) diff --git a/diffusers/utils/dummy_transformers_objects.py b/diffusers/utils/dummy_transformers_objects.py new file mode 100644 index 000000000..e05eb814d --- /dev/null +++ b/diffusers/utils/dummy_transformers_objects.py @@ -0,0 +1,32 @@ +# This file is autogenerated by the command `make fix-copies`, do not edit. 
+# flake8: noqa + +from ..utils import DummyObject, requires_backends + + +class LDMTextToImagePipeline(metaclass=DummyObject): + _backends = ["transformers"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["transformers"]) + + +class StableDiffusionImg2ImgPipeline(metaclass=DummyObject): + _backends = ["transformers"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["transformers"]) + + +class StableDiffusionInpaintPipeline(metaclass=DummyObject): + _backends = ["transformers"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["transformers"]) + + +class StableDiffusionPipeline(metaclass=DummyObject): + _backends = ["transformers"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["transformers"]) diff --git a/diffusers/utils/import_utils.py b/diffusers/utils/import_utils.py new file mode 100644 index 000000000..1f5e95ada --- /dev/null +++ b/diffusers/utils/import_utils.py @@ -0,0 +1,274 @@ +# Copyright 2022 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +Import utilities: Utilities related to imports and our lazy inits. +""" +import importlib.util +import os +import sys +from collections import OrderedDict + +from packaging import version + +from . import logging + + +# The package importlib_metadata is in a different place, depending on the python version. 
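# A minimal sketch of the availability-check pattern used throughout this module below
# (assuming Python 3.8+, so importlib.metadata is in the standard library; the code
# below also handles older interpreters): `find_spec` answers "is the package
# importable?" and `importlib_metadata.version` confirms it is actually installed.
import importlib.util
import importlib.metadata as importlib_metadata


def sketch_is_available(pkg_name: str) -> bool:
    available = importlib.util.find_spec(pkg_name) is not None
    if available:
        try:
            importlib_metadata.version(pkg_name)
        except importlib_metadata.PackageNotFoundError:
            available = False
    return available

# e.g. sketch_is_available("torch") mirrors what `_torch_available` computes below.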
+if sys.version_info < (3, 8): + import importlib_metadata +else: + import importlib.metadata as importlib_metadata + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + +ENV_VARS_TRUE_VALUES = {"1", "ON", "YES", "TRUE"} +ENV_VARS_TRUE_AND_AUTO_VALUES = ENV_VARS_TRUE_VALUES.union({"AUTO"}) + +USE_TF = os.environ.get("USE_TF", "AUTO").upper() +USE_TORCH = os.environ.get("USE_TORCH", "AUTO").upper() +USE_JAX = os.environ.get("USE_FLAX", "AUTO").upper() + +_torch_version = "N/A" +if USE_TORCH in ENV_VARS_TRUE_AND_AUTO_VALUES and USE_TF not in ENV_VARS_TRUE_VALUES: + _torch_available = importlib.util.find_spec("torch") is not None + if _torch_available: + try: + _torch_version = importlib_metadata.version("torch") + logger.info(f"PyTorch version {_torch_version} available.") + except importlib_metadata.PackageNotFoundError: + _torch_available = False +else: + logger.info("Disabling PyTorch because USE_TF is set") + _torch_available = False + + +_tf_version = "N/A" +if USE_TF in ENV_VARS_TRUE_AND_AUTO_VALUES and USE_TORCH not in ENV_VARS_TRUE_VALUES: + _tf_available = importlib.util.find_spec("tensorflow") is not None + if _tf_available: + candidates = ( + "tensorflow", + "tensorflow-cpu", + "tensorflow-gpu", + "tf-nightly", + "tf-nightly-cpu", + "tf-nightly-gpu", + "intel-tensorflow", + "intel-tensorflow-avx512", + "tensorflow-rocm", + "tensorflow-macos", + "tensorflow-aarch64", + ) + _tf_version = None + # For the metadata, we have to look for both tensorflow and tensorflow-cpu + for pkg in candidates: + try: + _tf_version = importlib_metadata.version(pkg) + break + except importlib_metadata.PackageNotFoundError: + pass + _tf_available = _tf_version is not None + if _tf_available: + if version.parse(_tf_version) < version.parse("2"): + logger.info(f"TensorFlow found but with version {_tf_version}. 
Diffusers requires version 2 minimum.") + _tf_available = False + else: + logger.info(f"TensorFlow version {_tf_version} available.") +else: + logger.info("Disabling Tensorflow because USE_TORCH is set") + _tf_available = False + + +if USE_JAX in ENV_VARS_TRUE_AND_AUTO_VALUES: + _flax_available = importlib.util.find_spec("jax") is not None and importlib.util.find_spec("flax") is not None + if _flax_available: + try: + _jax_version = importlib_metadata.version("jax") + _flax_version = importlib_metadata.version("flax") + logger.info(f"JAX version {_jax_version}, Flax version {_flax_version} available.") + except importlib_metadata.PackageNotFoundError: + _flax_available = False +else: + _flax_available = False + + +_transformers_available = importlib.util.find_spec("transformers") is not None +try: + _transformers_version = importlib_metadata.version("transformers") + logger.debug(f"Successfully imported transformers version {_transformers_version}") +except importlib_metadata.PackageNotFoundError: + _transformers_available = False + + +_inflect_available = importlib.util.find_spec("inflect") is not None +try: + _inflect_version = importlib_metadata.version("inflect") + logger.debug(f"Successfully imported inflect version {_inflect_version}") +except importlib_metadata.PackageNotFoundError: + _inflect_available = False + + +_unidecode_available = importlib.util.find_spec("unidecode") is not None +try: + _unidecode_version = importlib_metadata.version("unidecode") + logger.debug(f"Successfully imported unidecode version {_unidecode_version}") +except importlib_metadata.PackageNotFoundError: + _unidecode_available = False + + +_modelcards_available = importlib.util.find_spec("modelcards") is not None +try: + _modelcards_version = importlib_metadata.version("modelcards") + logger.debug(f"Successfully imported modelcards version {_modelcards_version}") +except importlib_metadata.PackageNotFoundError: + _modelcards_available = False + + +_onnx_available = importlib.util.find_spec("onnxruntime") is not None +try: + _onnxruntime_version = importlib_metadata.version("onnxruntime") + logger.debug(f"Successfully imported onnxruntime version {_onnxruntime_version}") +except importlib_metadata.PackageNotFoundError: + _onnx_available = False + + +_scipy_available = importlib.util.find_spec("scipy") is not None +try: + _scipy_version = importlib_metadata.version("scipy") + logger.debug(f"Successfully imported transformers version {_scipy_version}") +except importlib_metadata.PackageNotFoundError: + _scipy_available = False + + +def is_torch_available(): + return _torch_available + + +def is_tf_available(): + return _tf_available + + +def is_flax_available(): + return _flax_available + + +def is_transformers_available(): + return _transformers_available + + +def is_inflect_available(): + return _inflect_available + + +def is_unidecode_available(): + return _unidecode_available + + +def is_modelcards_available(): + return _modelcards_available + + +def is_onnx_available(): + return _onnx_available + + +def is_scipy_available(): + return _scipy_available + + +# docstyle-ignore +FLAX_IMPORT_ERROR = """ +{0} requires the FLAX library but it was not found in your environment. Checkout the instructions on the +installation page: https://github.com/google/flax and follow the ones that match your environment. +""" + +# docstyle-ignore +INFLECT_IMPORT_ERROR = """ +{0} requires the inflect library but it was not found in your environment. 
You can install it with pip: `pip install +inflect` +""" + +# docstyle-ignore +PYTORCH_IMPORT_ERROR = """ +{0} requires the PyTorch library but it was not found in your environment. Checkout the instructions on the +installation page: https://pytorch.org/get-started/locally/ and follow the ones that match your environment. +""" + +# docstyle-ignore +ONNX_IMPORT_ERROR = """ +{0} requires the onnxruntime library but it was not found in your environment. You can install it with pip: `pip +install onnxruntime` +""" + +# docstyle-ignore +SCIPY_IMPORT_ERROR = """ +{0} requires the scipy library but it was not found in your environment. You can install it with pip: `pip install +scipy` +""" + +# docstyle-ignore +TENSORFLOW_IMPORT_ERROR = """ +{0} requires the TensorFlow library but it was not found in your environment. Checkout the instructions on the +installation page: https://www.tensorflow.org/install and follow the ones that match your environment. +""" + +# docstyle-ignore +TRANSFORMERS_IMPORT_ERROR = """ +{0} requires the transformers library but it was not found in your environment. You can install it with pip: `pip +install transformers` +""" + +# docstyle-ignore +UNIDECODE_IMPORT_ERROR = """ +{0} requires the unidecode library but it was not found in your environment. You can install it with pip: `pip install +Unidecode` +""" + + +BACKENDS_MAPPING = OrderedDict( + [ + ("flax", (is_flax_available, FLAX_IMPORT_ERROR)), + ("inflect", (is_inflect_available, INFLECT_IMPORT_ERROR)), + ("onnx", (is_onnx_available, ONNX_IMPORT_ERROR)), + ("scipy", (is_scipy_available, SCIPY_IMPORT_ERROR)), + ("tf", (is_tf_available, TENSORFLOW_IMPORT_ERROR)), + ("torch", (is_torch_available, PYTORCH_IMPORT_ERROR)), + ("transformers", (is_transformers_available, TRANSFORMERS_IMPORT_ERROR)), + ("unidecode", (is_unidecode_available, UNIDECODE_IMPORT_ERROR)), + ] +) + + +def requires_backends(obj, backends): + if not isinstance(backends, (list, tuple)): + backends = [backends] + + name = obj.__name__ if hasattr(obj, "__name__") else obj.__class__.__name__ + checks = (BACKENDS_MAPPING[backend] for backend in backends) + failed = [msg.format(name) for available, msg in checks if not available()] + if failed: + raise ImportError("".join(failed)) + + +class DummyObject(type): + """ + Metaclass for the dummy objects. Any class inheriting from it will return the ImportError generated by + `requires_backend` each time a user tries to access any method of that class. + """ + + def __getattr__(cls, key): + if key.startswith("_"): + return super().__getattr__(cls, key) + requires_backends(cls, cls._backends) diff --git a/diffusers/utils/logging.py b/diffusers/utils/logging.py new file mode 100644 index 000000000..1f2d0227b --- /dev/null +++ b/diffusers/utils/logging.py @@ -0,0 +1,344 @@ +# coding=utf-8 +# Copyright 2020 Optuna, Hugging Face +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
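# A minimal sketch of how `requires_backends` and `DummyObject` above work together:
# a placeholder class declares the backends it needs, and any attempt to instantiate it
# (or touch one of its attributes) raises an ImportError built from BACKENDS_MAPPING.
# `SketchPipeline` is illustrative only, not a class shipped by the library.
from diffusers.utils import DummyObject, requires_backends


class SketchPipeline(metaclass=DummyObject):
    _backends = ["scipy"]

    def __init__(self, *args, **kwargs):
        requires_backends(self, ["scipy"])

# With scipy missing, both `SketchPipeline()` and `SketchPipeline.some_attr`
# raise ImportError with SCIPY_IMPORT_ERROR as the message.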
+""" Logging utilities.""" + +import logging +import os +import sys +import threading +from logging import CRITICAL # NOQA +from logging import DEBUG # NOQA +from logging import ERROR # NOQA +from logging import FATAL # NOQA +from logging import INFO # NOQA +from logging import NOTSET # NOQA +from logging import WARN # NOQA +from logging import WARNING # NOQA +from typing import Optional + +from tqdm import auto as tqdm_lib + + +_lock = threading.Lock() +_default_handler: Optional[logging.Handler] = None + +log_levels = { + "debug": logging.DEBUG, + "info": logging.INFO, + "warning": logging.WARNING, + "error": logging.ERROR, + "critical": logging.CRITICAL, +} + +_default_log_level = logging.WARNING + +_tqdm_active = True + + +def _get_default_logging_level(): + """ + If DIFFUSERS_VERBOSITY env var is set to one of the valid choices return that as the new default level. If it is + not - fall back to `_default_log_level` + """ + env_level_str = os.getenv("DIFFUSERS_VERBOSITY", None) + if env_level_str: + if env_level_str in log_levels: + return log_levels[env_level_str] + else: + logging.getLogger().warning( + f"Unknown option DIFFUSERS_VERBOSITY={env_level_str}, " + f"has to be one of: { ', '.join(log_levels.keys()) }" + ) + return _default_log_level + + +def _get_library_name() -> str: + + return __name__.split(".")[0] + + +def _get_library_root_logger() -> logging.Logger: + + return logging.getLogger(_get_library_name()) + + +def _configure_library_root_logger() -> None: + + global _default_handler + + with _lock: + if _default_handler: + # This library has already configured the library root logger. + return + _default_handler = logging.StreamHandler() # Set sys.stderr as stream. + _default_handler.flush = sys.stderr.flush + + # Apply our default configuration to the library root logger. + library_root_logger = _get_library_root_logger() + library_root_logger.addHandler(_default_handler) + library_root_logger.setLevel(_get_default_logging_level()) + library_root_logger.propagate = False + + +def _reset_library_root_logger() -> None: + + global _default_handler + + with _lock: + if not _default_handler: + return + + library_root_logger = _get_library_root_logger() + library_root_logger.removeHandler(_default_handler) + library_root_logger.setLevel(logging.NOTSET) + _default_handler = None + + +def get_log_levels_dict(): + return log_levels + + +def get_logger(name: Optional[str] = None) -> logging.Logger: + """ + Return a logger with the specified name. + + This function is not supposed to be directly accessed unless you are writing a custom diffusers module. + """ + + if name is None: + name = _get_library_name() + + _configure_library_root_logger() + return logging.getLogger(name) + + +def get_verbosity() -> int: + """ + Return the current level for the 🤗 Diffusers' root logger as an int. + + Returns: + `int`: The logging level. + + + + 🤗 Diffusers has following logging levels: + + - 50: `diffusers.logging.CRITICAL` or `diffusers.logging.FATAL` + - 40: `diffusers.logging.ERROR` + - 30: `diffusers.logging.WARNING` or `diffusers.logging.WARN` + - 20: `diffusers.logging.INFO` + - 10: `diffusers.logging.DEBUG` + + """ + + _configure_library_root_logger() + return _get_library_root_logger().getEffectiveLevel() + + +def set_verbosity(verbosity: int) -> None: + """ + Set the verbosity level for the 🤗 Diffusers' root logger. 
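# A short usage sketch for the helpers in this module: library code asks for a named
# logger via `get_logger`, and verbosity is controlled either through the
# DIFFUSERS_VERBOSITY environment variable ("debug", "info", "warning", "error",
# "critical") or programmatically, e.g.:
from diffusers.utils import logging as diffusers_logging  # assuming this module path

logger = diffusers_logging.get_logger(__name__)
diffusers_logging.set_verbosity(diffusers_logging.INFO)
logger.info("visible once the verbosity is set to INFO")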
+ + Args: + verbosity (`int`): + Logging level, e.g., one of: + + - `diffusers.logging.CRITICAL` or `diffusers.logging.FATAL` + - `diffusers.logging.ERROR` + - `diffusers.logging.WARNING` or `diffusers.logging.WARN` + - `diffusers.logging.INFO` + - `diffusers.logging.DEBUG` + """ + + _configure_library_root_logger() + _get_library_root_logger().setLevel(verbosity) + + +def set_verbosity_info(): + """Set the verbosity to the `INFO` level.""" + return set_verbosity(INFO) + + +def set_verbosity_warning(): + """Set the verbosity to the `WARNING` level.""" + return set_verbosity(WARNING) + + +def set_verbosity_debug(): + """Set the verbosity to the `DEBUG` level.""" + return set_verbosity(DEBUG) + + +def set_verbosity_error(): + """Set the verbosity to the `ERROR` level.""" + return set_verbosity(ERROR) + + +def disable_default_handler() -> None: + """Disable the default handler of the HuggingFace Diffusers' root logger.""" + + _configure_library_root_logger() + + assert _default_handler is not None + _get_library_root_logger().removeHandler(_default_handler) + + +def enable_default_handler() -> None: + """Enable the default handler of the HuggingFace Diffusers' root logger.""" + + _configure_library_root_logger() + + assert _default_handler is not None + _get_library_root_logger().addHandler(_default_handler) + + +def add_handler(handler: logging.Handler) -> None: + """adds a handler to the HuggingFace Diffusers' root logger.""" + + _configure_library_root_logger() + + assert handler is not None + _get_library_root_logger().addHandler(handler) + + +def remove_handler(handler: logging.Handler) -> None: + """removes given handler from the HuggingFace Diffusers' root logger.""" + + _configure_library_root_logger() + + assert handler is not None and handler not in _get_library_root_logger().handlers + _get_library_root_logger().removeHandler(handler) + + +def disable_propagation() -> None: + """ + Disable propagation of the library log outputs. Note that log propagation is disabled by default. + """ + + _configure_library_root_logger() + _get_library_root_logger().propagate = False + + +def enable_propagation() -> None: + """ + Enable propagation of the library log outputs. Please disable the HuggingFace Diffusers' default handler to prevent + double logging if the root logger has been configured. + """ + + _configure_library_root_logger() + _get_library_root_logger().propagate = True + + +def enable_explicit_format() -> None: + """ + Enable explicit formatting for every HuggingFace Diffusers' logger. The explicit formatter is as follows: + ``` + [LEVELNAME|FILENAME|LINE NUMBER] TIME >> MESSAGE + ``` + All handlers currently bound to the root logger are affected by this method. + """ + handlers = _get_library_root_logger().handlers + + for handler in handlers: + formatter = logging.Formatter("[%(levelname)s|%(filename)s:%(lineno)s] %(asctime)s >> %(message)s") + handler.setFormatter(formatter) + + +def reset_format() -> None: + """ + Resets the formatting for HuggingFace Diffusers' loggers. + + All handlers currently bound to the root logger are affected by this method. 
+ """ + handlers = _get_library_root_logger().handlers + + for handler in handlers: + handler.setFormatter(None) + + +def warning_advice(self, *args, **kwargs): + """ + This method is identical to `logger.warninging()`, but if env var DIFFUSERS_NO_ADVISORY_WARNINGS=1 is set, this + warning will not be printed + """ + no_advisory_warnings = os.getenv("DIFFUSERS_NO_ADVISORY_WARNINGS", False) + if no_advisory_warnings: + return + self.warning(*args, **kwargs) + + +logging.Logger.warning_advice = warning_advice + + +class EmptyTqdm: + """Dummy tqdm which doesn't do anything.""" + + def __init__(self, *args, **kwargs): # pylint: disable=unused-argument + self._iterator = args[0] if args else None + + def __iter__(self): + return iter(self._iterator) + + def __getattr__(self, _): + """Return empty function.""" + + def empty_fn(*args, **kwargs): # pylint: disable=unused-argument + return + + return empty_fn + + def __enter__(self): + return self + + def __exit__(self, type_, value, traceback): + return + + +class _tqdm_cls: + def __call__(self, *args, **kwargs): + if _tqdm_active: + return tqdm_lib.tqdm(*args, **kwargs) + else: + return EmptyTqdm(*args, **kwargs) + + def set_lock(self, *args, **kwargs): + self._lock = None + if _tqdm_active: + return tqdm_lib.tqdm.set_lock(*args, **kwargs) + + def get_lock(self): + if _tqdm_active: + return tqdm_lib.tqdm.get_lock() + + +tqdm = _tqdm_cls() + + +def is_progress_bar_enabled() -> bool: + """Return a boolean indicating whether tqdm progress bars are enabled.""" + global _tqdm_active + return bool(_tqdm_active) + + +def enable_progress_bar(): + """Enable tqdm progress bar.""" + global _tqdm_active + _tqdm_active = True + + +def disable_progress_bar(): + """Disable tqdm progress bar.""" + global _tqdm_active + _tqdm_active = False diff --git a/diffusers/utils/outputs.py b/diffusers/utils/outputs.py new file mode 100644 index 000000000..b02f62d02 --- /dev/null +++ b/diffusers/utils/outputs.py @@ -0,0 +1,109 @@ +# Copyright 2022 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +Generic utilities +""" + +import warnings +from collections import OrderedDict +from dataclasses import fields +from typing import Any, Tuple + +import numpy as np + +from .import_utils import is_torch_available + + +def is_tensor(x): + """ + Tests if `x` is a `torch.Tensor` or `np.ndarray`. + """ + if is_torch_available(): + import torch + + if isinstance(x, torch.Tensor): + return True + + return isinstance(x, np.ndarray) + + +class BaseOutput(OrderedDict): + """ + Base class for all model outputs as dataclass. Has a `__getitem__` that allows indexing by integer or slice (like a + tuple) or strings (like a dictionary) that will ignore the `None` attributes. Otherwise behaves like a regular + python dictionary. + + + + You can't unpack a `BaseOutput` directly. Use the [`~utils.BaseOutput.to_tuple`] method to convert it to a tuple + before. 
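# A small usage sketch for the tqdm wrapper defined above: pipelines call `tqdm(...)`
# from this module, so progress bars can be switched off globally without touching the
# call sites.
from diffusers.utils import logging as diffusers_logging  # assuming this module path

diffusers_logging.disable_progress_bar()
assert not diffusers_logging.is_progress_bar_enabled()
for _ in diffusers_logging.tqdm(range(3)):  # falls back to EmptyTqdm, prints nothing
    pass
diffusers_logging.enable_progress_bar()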
+ + + """ + + def __post_init__(self): + class_fields = fields(self) + + # Safety and consistency checks + if not len(class_fields): + raise ValueError(f"{self.__class__.__name__} has no fields.") + + for field in class_fields: + v = getattr(self, field.name) + if v is not None: + self[field.name] = v + + def __delitem__(self, *args, **kwargs): + raise Exception(f"You cannot use ``__delitem__`` on a {self.__class__.__name__} instance.") + + def setdefault(self, *args, **kwargs): + raise Exception(f"You cannot use ``setdefault`` on a {self.__class__.__name__} instance.") + + def pop(self, *args, **kwargs): + raise Exception(f"You cannot use ``pop`` on a {self.__class__.__name__} instance.") + + def update(self, *args, **kwargs): + raise Exception(f"You cannot use ``update`` on a {self.__class__.__name__} instance.") + + def __getitem__(self, k): + if isinstance(k, str): + inner_dict = {k: v for (k, v) in self.items()} + if self.__class__.__name__ in ["StableDiffusionPipelineOutput", "ImagePipelineOutput"] and k == "sample": + warnings.warn( + "The keyword 'samples' is deprecated and will be removed in version 0.4.0. Please use `.images` or" + " `'images'` instead.", + DeprecationWarning, + ) + return inner_dict["images"] + return inner_dict[k] + else: + return self.to_tuple()[k] + + def __setattr__(self, name, value): + if name in self.keys() and value is not None: + # Don't call self.__setitem__ to avoid recursion errors + super().__setitem__(name, value) + super().__setattr__(name, value) + + def __setitem__(self, key, value): + # Will raise a KeyException if needed + super().__setitem__(key, value) + # Don't call self.__setattr__ to avoid recursion errors + super().__setattr__(key, value) + + def to_tuple(self) -> Tuple[Any]: + """ + Convert self to a tuple containing all the attributes/keys that are not `None`. 
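# A minimal sketch of how BaseOutput subclasses behave: fields are declared as a
# dataclass, `None` entries are dropped, and the result can be read both like a dict and
# like a tuple. `SketchOutput` is illustrative only, not a class shipped by the library.
from dataclasses import dataclass
from typing import Optional

import torch

from diffusers.utils import BaseOutput  # assuming this module path


@dataclass
class SketchOutput(BaseOutput):
    prev_sample: torch.FloatTensor
    extra: Optional[torch.FloatTensor] = None


out = SketchOutput(prev_sample=torch.zeros(1))
assert out["prev_sample"] is out.prev_sample  # dict-style and attribute access agree
assert len(out.to_tuple()) == 1  # `extra` is None, so it is skipped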
+ """ + return tuple(self[k] for k in self.keys()) From e77c5d2d58ad5131607965c2c57d1ba23af00ab9 Mon Sep 17 00:00:00 2001 From: Olivier Teytaud Date: Sun, 2 Oct 2022 16:54:53 +0200 Subject: [PATCH 61/76] fix --- local_diffusers/__init__.py | 60 - local_diffusers/commands/__init__.py | 27 - local_diffusers/commands/diffusers_cli.py | 41 - local_diffusers/commands/env.py | 70 - local_diffusers/configuration_utils.py | 403 ----- local_diffusers/dependency_versions_check.py | 47 - local_diffusers/dependency_versions_table.py | 26 - local_diffusers/dynamic_modules_utils.py | 335 ---- local_diffusers/hub_utils.py | 197 --- local_diffusers/modeling_utils.py | 542 ------ local_diffusers/models/__init__.py | 17 - local_diffusers/models/attention.py | 333 ---- local_diffusers/models/embeddings.py | 115 -- local_diffusers/models/resnet.py | 483 ------ local_diffusers/models/unet_2d.py | 246 --- local_diffusers/models/unet_2d_condition.py | 270 --- local_diffusers/models/unet_blocks.py | 1481 ----------------- local_diffusers/models/vae.py | 581 ------- local_diffusers/onnx_utils.py | 189 --- local_diffusers/optimization.py | 275 --- local_diffusers/pipeline_utils.py | 417 ----- local_diffusers/pipelines/__init__.py | 19 - local_diffusers/pipelines/ddim/__init__.py | 2 - .../pipelines/ddim/pipeline_ddim.py | 117 -- local_diffusers/pipelines/ddpm/__init__.py | 2 - .../pipelines/ddpm/pipeline_ddpm.py | 106 -- .../pipelines/latent_diffusion/__init__.py | 6 - .../pipeline_latent_diffusion.py | 705 -------- .../latent_diffusion_uncond/__init__.py | 2 - .../pipeline_latent_diffusion_uncond.py | 108 -- local_diffusers/pipelines/pndm/__init__.py | 2 - .../pipelines/pndm/pipeline_pndm.py | 111 -- .../pipelines/score_sde_ve/__init__.py | 2 - .../score_sde_ve/pipeline_score_sde_ve.py | 101 -- .../pipelines/stable_diffusion/__init__.py | 37 - .../pipeline_stable_diffusion.py | 397 ----- .../pipeline_stable_diffusion_img2img.py | 291 ---- .../pipeline_stable_diffusion_inpaint.py | 309 ---- .../pipeline_stable_diffusion_onnx.py | 165 -- .../stable_diffusion/safety_checker.py | 106 -- .../stochastic_karras_ve/__init__.py | 2 - .../pipeline_stochastic_karras_ve.py | 129 -- local_diffusers/schedulers/__init__.py | 28 - local_diffusers/schedulers/scheduling_ddim.py | 261 --- local_diffusers/schedulers/scheduling_ddpm.py | 264 --- .../schedulers/scheduling_karras_ve.py | 208 --- .../schedulers/scheduling_lms_discrete.py | 193 --- local_diffusers/schedulers/scheduling_pndm.py | 378 ----- .../schedulers/scheduling_sde_ve.py | 283 ---- .../schedulers/scheduling_sde_vp.py | 81 - .../schedulers/scheduling_utils.py | 125 -- local_diffusers/testing_utils.py | 61 - local_diffusers/training_utils.py | 125 -- local_diffusers/utils/__init__.py | 53 - local_diffusers/utils/dummy_scipy_objects.py | 11 - ...rmers_and_inflect_and_unidecode_objects.py | 10 - .../dummy_transformers_and_onnx_objects.py | 11 - .../utils/dummy_transformers_objects.py | 32 - local_diffusers/utils/import_utils.py | 274 --- local_diffusers/utils/logging.py | 344 ---- local_diffusers/utils/outputs.py | 109 -- minisd.py | 18 +- 62 files changed, 10 insertions(+), 11733 deletions(-) delete mode 100644 local_diffusers/__init__.py delete mode 100644 local_diffusers/commands/__init__.py delete mode 100644 local_diffusers/commands/diffusers_cli.py delete mode 100644 local_diffusers/commands/env.py delete mode 100644 local_diffusers/configuration_utils.py delete mode 100644 local_diffusers/dependency_versions_check.py delete mode 100644 
local_diffusers/dependency_versions_table.py delete mode 100644 local_diffusers/dynamic_modules_utils.py delete mode 100644 local_diffusers/hub_utils.py delete mode 100644 local_diffusers/modeling_utils.py delete mode 100644 local_diffusers/models/__init__.py delete mode 100644 local_diffusers/models/attention.py delete mode 100644 local_diffusers/models/embeddings.py delete mode 100644 local_diffusers/models/resnet.py delete mode 100644 local_diffusers/models/unet_2d.py delete mode 100644 local_diffusers/models/unet_2d_condition.py delete mode 100644 local_diffusers/models/unet_blocks.py delete mode 100644 local_diffusers/models/vae.py delete mode 100644 local_diffusers/onnx_utils.py delete mode 100644 local_diffusers/optimization.py delete mode 100644 local_diffusers/pipeline_utils.py delete mode 100644 local_diffusers/pipelines/__init__.py delete mode 100644 local_diffusers/pipelines/ddim/__init__.py delete mode 100644 local_diffusers/pipelines/ddim/pipeline_ddim.py delete mode 100644 local_diffusers/pipelines/ddpm/__init__.py delete mode 100644 local_diffusers/pipelines/ddpm/pipeline_ddpm.py delete mode 100644 local_diffusers/pipelines/latent_diffusion/__init__.py delete mode 100644 local_diffusers/pipelines/latent_diffusion/pipeline_latent_diffusion.py delete mode 100644 local_diffusers/pipelines/latent_diffusion_uncond/__init__.py delete mode 100644 local_diffusers/pipelines/latent_diffusion_uncond/pipeline_latent_diffusion_uncond.py delete mode 100644 local_diffusers/pipelines/pndm/__init__.py delete mode 100644 local_diffusers/pipelines/pndm/pipeline_pndm.py delete mode 100644 local_diffusers/pipelines/score_sde_ve/__init__.py delete mode 100644 local_diffusers/pipelines/score_sde_ve/pipeline_score_sde_ve.py delete mode 100644 local_diffusers/pipelines/stable_diffusion/__init__.py delete mode 100644 local_diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py delete mode 100644 local_diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py delete mode 100644 local_diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py delete mode 100644 local_diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_onnx.py delete mode 100644 local_diffusers/pipelines/stable_diffusion/safety_checker.py delete mode 100644 local_diffusers/pipelines/stochastic_karras_ve/__init__.py delete mode 100644 local_diffusers/pipelines/stochastic_karras_ve/pipeline_stochastic_karras_ve.py delete mode 100644 local_diffusers/schedulers/__init__.py delete mode 100644 local_diffusers/schedulers/scheduling_ddim.py delete mode 100644 local_diffusers/schedulers/scheduling_ddpm.py delete mode 100644 local_diffusers/schedulers/scheduling_karras_ve.py delete mode 100644 local_diffusers/schedulers/scheduling_lms_discrete.py delete mode 100644 local_diffusers/schedulers/scheduling_pndm.py delete mode 100644 local_diffusers/schedulers/scheduling_sde_ve.py delete mode 100644 local_diffusers/schedulers/scheduling_sde_vp.py delete mode 100644 local_diffusers/schedulers/scheduling_utils.py delete mode 100644 local_diffusers/testing_utils.py delete mode 100644 local_diffusers/training_utils.py delete mode 100644 local_diffusers/utils/__init__.py delete mode 100644 local_diffusers/utils/dummy_scipy_objects.py delete mode 100644 local_diffusers/utils/dummy_transformers_and_inflect_and_unidecode_objects.py delete mode 100644 local_diffusers/utils/dummy_transformers_and_onnx_objects.py delete mode 100644 local_diffusers/utils/dummy_transformers_objects.py delete mode 100644 
local_diffusers/utils/import_utils.py delete mode 100644 local_diffusers/utils/logging.py delete mode 100644 local_diffusers/utils/outputs.py diff --git a/local_diffusers/__init__.py b/local_diffusers/__init__.py deleted file mode 100644 index bf2f183c9..000000000 --- a/local_diffusers/__init__.py +++ /dev/null @@ -1,60 +0,0 @@ -from .utils import ( - is_inflect_available, - is_onnx_available, - is_scipy_available, - is_transformers_available, - is_unidecode_available, -) - - -__version__ = "0.3.0" - -from .configuration_utils import ConfigMixin -from .modeling_utils import ModelMixin -from .models import AutoencoderKL, UNet2DConditionModel, UNet2DModel, VQModel -from .onnx_utils import OnnxRuntimeModel -from .optimization import ( - get_constant_schedule, - get_constant_schedule_with_warmup, - get_cosine_schedule_with_warmup, - get_cosine_with_hard_restarts_schedule_with_warmup, - get_linear_schedule_with_warmup, - get_polynomial_decay_schedule_with_warmup, - get_scheduler, -) -from .pipeline_utils import DiffusionPipeline -from .pipelines import DDIMPipeline, DDPMPipeline, KarrasVePipeline, LDMPipeline, PNDMPipeline, ScoreSdeVePipeline -from .schedulers import ( - DDIMScheduler, - DDPMScheduler, - KarrasVeScheduler, - PNDMScheduler, - SchedulerMixin, - ScoreSdeVeScheduler, -) -from .utils import logging - - -if is_scipy_available(): - from .schedulers import LMSDiscreteScheduler -else: - from .utils.dummy_scipy_objects import * # noqa F403 - -from .training_utils import EMAModel - - -if is_transformers_available(): - from .pipelines import ( - LDMTextToImagePipeline, - StableDiffusionImg2ImgPipeline, - StableDiffusionInpaintPipeline, - StableDiffusionPipeline, - ) -else: - from .utils.dummy_transformers_objects import * # noqa F403 - - -if is_transformers_available() and is_onnx_available(): - from .pipelines import StableDiffusionOnnxPipeline -else: - from .utils.dummy_transformers_and_onnx_objects import * # noqa F403 diff --git a/local_diffusers/commands/__init__.py b/local_diffusers/commands/__init__.py deleted file mode 100644 index 902bd46ce..000000000 --- a/local_diffusers/commands/__init__.py +++ /dev/null @@ -1,27 +0,0 @@ -# Copyright 2022 The HuggingFace Team. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -from abc import ABC, abstractmethod -from argparse import ArgumentParser - - -class BaseDiffusersCLICommand(ABC): - @staticmethod - @abstractmethod - def register_subcommand(parser: ArgumentParser): - raise NotImplementedError() - - @abstractmethod - def run(self): - raise NotImplementedError() diff --git a/local_diffusers/commands/diffusers_cli.py b/local_diffusers/commands/diffusers_cli.py deleted file mode 100644 index 30084e55b..000000000 --- a/local_diffusers/commands/diffusers_cli.py +++ /dev/null @@ -1,41 +0,0 @@ -#!/usr/bin/env python -# Copyright 2022 The HuggingFace Team. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. 
-# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -from argparse import ArgumentParser - -from .env import EnvironmentCommand - - -def main(): - parser = ArgumentParser("Diffusers CLI tool", usage="diffusers-cli []") - commands_parser = parser.add_subparsers(help="diffusers-cli command helpers") - - # Register commands - EnvironmentCommand.register_subcommand(commands_parser) - - # Let's go - args = parser.parse_args() - - if not hasattr(args, "func"): - parser.print_help() - exit(1) - - # Run - service = args.func(args) - service.run() - - -if __name__ == "__main__": - main() diff --git a/local_diffusers/commands/env.py b/local_diffusers/commands/env.py deleted file mode 100644 index 81a878bff..000000000 --- a/local_diffusers/commands/env.py +++ /dev/null @@ -1,70 +0,0 @@ -# Copyright 2022 The HuggingFace Team. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import platform -from argparse import ArgumentParser - -import huggingface_hub - -from .. import __version__ as version -from ..utils import is_torch_available, is_transformers_available -from . 
import BaseDiffusersCLICommand - - -def info_command_factory(_): - return EnvironmentCommand() - - -class EnvironmentCommand(BaseDiffusersCLICommand): - @staticmethod - def register_subcommand(parser: ArgumentParser): - download_parser = parser.add_parser("env") - download_parser.set_defaults(func=info_command_factory) - - def run(self): - hub_version = huggingface_hub.__version__ - - pt_version = "not installed" - pt_cuda_available = "NA" - if is_torch_available(): - import torch - - pt_version = torch.__version__ - pt_cuda_available = torch.cuda.is_available() - - transformers_version = "not installed" - if is_transformers_available: - import transformers - - transformers_version = transformers.__version__ - - info = { - "`diffusers` version": version, - "Platform": platform.platform(), - "Python version": platform.python_version(), - "PyTorch version (GPU?)": f"{pt_version} ({pt_cuda_available})", - "Huggingface_hub version": hub_version, - "Transformers version": transformers_version, - "Using GPU in script?": "", - "Using distributed or parallel set-up in script?": "", - } - - print("\nCopy-and-paste the text below in your GitHub issue and FILL OUT the two last points.\n") - print(self.format_dict(info)) - - return info - - @staticmethod - def format_dict(d): - return "\n".join([f"- {prop}: {val}" for prop, val in d.items()]) + "\n" diff --git a/local_diffusers/configuration_utils.py b/local_diffusers/configuration_utils.py deleted file mode 100644 index fbe75f3f1..000000000 --- a/local_diffusers/configuration_utils.py +++ /dev/null @@ -1,403 +0,0 @@ -# coding=utf-8 -# Copyright 2022 The HuggingFace Inc. team. -# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -""" ConfigMixinuration base class and utilities.""" -import functools -import inspect -import json -import os -import re -from collections import OrderedDict -from typing import Any, Dict, Tuple, Union - -from huggingface_hub import hf_hub_download -from huggingface_hub.utils import EntryNotFoundError, RepositoryNotFoundError, RevisionNotFoundError -from requests import HTTPError - -from . import __version__ -from .utils import DIFFUSERS_CACHE, HUGGINGFACE_CO_RESOLVE_ENDPOINT, logging - - -logger = logging.get_logger(__name__) - -_re_configuration_file = re.compile(r"config\.(.*)\.json") - - -class ConfigMixin: - r""" - Base class for all configuration classes. Stores all configuration parameters under `self.config` Also handles all - methods for loading/downloading/saving classes inheriting from [`ConfigMixin`] with - - [`~ConfigMixin.from_config`] - - [`~ConfigMixin.save_config`] - - Class attributes: - - **config_name** (`str`) -- A filename under which the config should stored when calling - [`~ConfigMixin.save_config`] (should be overriden by parent class). - - **ignore_for_config** (`List[str]`) -- A list of attributes that should not be saved in the config (should be - overriden by parent class). 
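As a hedged illustration of the ConfigMixin API removed here (the class `MyScheduler`, its arguments, and the paths are made up; `config_name`, `register_to_config`, `save_config` and `from_config` are from the deleted code):

```python
# Minimal sketch of a ConfigMixin subclass; hypothetical class and arguments.
from local_diffusers.configuration_utils import ConfigMixin


class MyScheduler(ConfigMixin):
    config_name = "scheduler_config.json"  # file written by save_config / read by from_config

    def __init__(self, num_train_timesteps=1000, beta_start=0.0001):
        # Records the arguments in self.config (a FrozenDict) and in the saved JSON.
        self.register_to_config(num_train_timesteps=num_train_timesteps, beta_start=beta_start)


sched = MyScheduler(num_train_timesteps=500)
sched.save_config("./my_scheduler")                   # writes ./my_scheduler/scheduler_config.json
restored = MyScheduler.from_config("./my_scheduler")  # rebuilds the object from that JSON
print(restored.config)
```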
- """ - config_name = None - ignore_for_config = [] - - def register_to_config(self, **kwargs): - if self.config_name is None: - raise NotImplementedError(f"Make sure that {self.__class__} has defined a class name `config_name`") - kwargs["_class_name"] = self.__class__.__name__ - kwargs["_diffusers_version"] = __version__ - - for key, value in kwargs.items(): - try: - setattr(self, key, value) - except AttributeError as err: - logger.error(f"Can't set {key} with value {value} for {self}") - raise err - - if not hasattr(self, "_internal_dict"): - internal_dict = kwargs - else: - previous_dict = dict(self._internal_dict) - internal_dict = {**self._internal_dict, **kwargs} - logger.debug(f"Updating config from {previous_dict} to {internal_dict}") - - self._internal_dict = FrozenDict(internal_dict) - - def save_config(self, save_directory: Union[str, os.PathLike], push_to_hub: bool = False, **kwargs): - """ - Save a configuration object to the directory `save_directory`, so that it can be re-loaded using the - [`~ConfigMixin.from_config`] class method. - - Args: - save_directory (`str` or `os.PathLike`): - Directory where the configuration JSON file will be saved (will be created if it does not exist). - """ - if os.path.isfile(save_directory): - raise AssertionError(f"Provided path ({save_directory}) should be a directory, not a file") - - os.makedirs(save_directory, exist_ok=True) - - # If we save using the predefined names, we can load using `from_config` - output_config_file = os.path.join(save_directory, self.config_name) - - self.to_json_file(output_config_file) - logger.info(f"ConfigMixinuration saved in {output_config_file}") - - @classmethod - def from_config(cls, pretrained_model_name_or_path: Union[str, os.PathLike], return_unused_kwargs=False, **kwargs): - r""" - Instantiate a Python class from a pre-defined JSON-file. - - Parameters: - pretrained_model_name_or_path (`str` or `os.PathLike`, *optional*): - Can be either: - - - A string, the *model id* of a model repo on huggingface.co. Valid model ids should have an - organization name, like `google/ddpm-celebahq-256`. - - A path to a *directory* containing model weights saved using [`~ConfigMixin.save_config`], e.g., - `./my_model_directory/`. - - cache_dir (`Union[str, os.PathLike]`, *optional*): - Path to a directory in which a downloaded pretrained model configuration should be cached if the - standard cache should not be used. - ignore_mismatched_sizes (`bool`, *optional*, defaults to `False`): - Whether or not to raise an error if some of the weights from the checkpoint do not have the same size - as the weights of the model (if for instance, you are instantiating a model with 10 labels from a - checkpoint with 3 labels). - force_download (`bool`, *optional*, defaults to `False`): - Whether or not to force the (re-)download of the model weights and configuration files, overriding the - cached versions if they exist. - resume_download (`bool`, *optional*, defaults to `False`): - Whether or not to delete incompletely received files. Will attempt to resume the download if such a - file exists. - proxies (`Dict[str, str]`, *optional*): - A dictionary of proxy servers to use by protocol or endpoint, e.g., `{'http': 'foo.bar:3128', - 'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request. - output_loading_info(`bool`, *optional*, defaults to `False`): - Whether ot not to also return a dictionary containing missing keys, unexpected keys and error messages. 
- local_files_only(`bool`, *optional*, defaults to `False`): - Whether or not to only look at local files (i.e., do not try to download the model). - use_auth_token (`str` or *bool*, *optional*): - The token to use as HTTP bearer authorization for remote files. If `True`, will use the token generated - when running `transformers-cli login` (stored in `~/.huggingface`). - revision (`str`, *optional*, defaults to `"main"`): - The specific model version to use. It can be a branch name, a tag name, or a commit id, since we use a - git-based system for storing models and other artifacts on huggingface.co, so `revision` can be any - identifier allowed by git. - mirror (`str`, *optional*): - Mirror source to accelerate downloads in China. If you are from China and have an accessibility - problem, you can set this option to resolve it. Note that we do not guarantee the timeliness or safety. - Please refer to the mirror site for more information. - - - - Passing `use_auth_token=True`` is required when you want to use a private model. - - - - - - Activate the special ["offline-mode"](https://huggingface.co/transformers/installation.html#offline-mode) to - use this method in a firewalled environment. - - - - """ - config_dict = cls.get_config_dict(pretrained_model_name_or_path=pretrained_model_name_or_path, **kwargs) - - init_dict, unused_kwargs = cls.extract_init_dict(config_dict, **kwargs) - - model = cls(**init_dict) - - if return_unused_kwargs: - return model, unused_kwargs - else: - return model - - @classmethod - def get_config_dict( - cls, pretrained_model_name_or_path: Union[str, os.PathLike], **kwargs - ) -> Tuple[Dict[str, Any], Dict[str, Any]]: - cache_dir = kwargs.pop("cache_dir", DIFFUSERS_CACHE) - force_download = kwargs.pop("force_download", False) - resume_download = kwargs.pop("resume_download", False) - proxies = kwargs.pop("proxies", None) - use_auth_token = kwargs.pop("use_auth_token", None) - local_files_only = kwargs.pop("local_files_only", False) - revision = kwargs.pop("revision", None) - subfolder = kwargs.pop("subfolder", None) - - user_agent = {"file_type": "config"} - - pretrained_model_name_or_path = str(pretrained_model_name_or_path) - - if cls.config_name is None: - raise ValueError( - "`self.config_name` is not defined. Note that one should not load a config from " - "`ConfigMixin`. Please make sure to define `config_name` in a class inheriting from `ConfigMixin`" - ) - - if os.path.isfile(pretrained_model_name_or_path): - config_file = pretrained_model_name_or_path - elif os.path.isdir(pretrained_model_name_or_path): - if os.path.isfile(os.path.join(pretrained_model_name_or_path, cls.config_name)): - # Load from a PyTorch checkpoint - config_file = os.path.join(pretrained_model_name_or_path, cls.config_name) - elif subfolder is not None and os.path.isfile( - os.path.join(pretrained_model_name_or_path, subfolder, cls.config_name) - ): - config_file = os.path.join(pretrained_model_name_or_path, subfolder, cls.config_name) - else: - raise EnvironmentError( - f"Error no file named {cls.config_name} found in directory {pretrained_model_name_or_path}." 
- ) - else: - try: - # Load from URL or cache if already cached - config_file = hf_hub_download( - pretrained_model_name_or_path, - filename=cls.config_name, - cache_dir=cache_dir, - force_download=force_download, - proxies=proxies, - resume_download=resume_download, - local_files_only=local_files_only, - use_auth_token=use_auth_token, - user_agent=user_agent, - subfolder=subfolder, - revision=revision, - ) - - except RepositoryNotFoundError: - raise EnvironmentError( - f"{pretrained_model_name_or_path} is not a local folder and is not a valid model identifier" - " listed on 'https://huggingface.co/models'\nIf this is a private repository, make sure to pass a" - " token having permission to this repo with `use_auth_token` or log in with `huggingface-cli" - " login` and pass `use_auth_token=True`." - ) - except RevisionNotFoundError: - raise EnvironmentError( - f"{revision} is not a valid git identifier (branch name, tag name or commit id) that exists for" - " this model name. Check the model page at" - f" 'https://huggingface.co/{pretrained_model_name_or_path}' for available revisions." - ) - except EntryNotFoundError: - raise EnvironmentError( - f"{pretrained_model_name_or_path} does not appear to have a file named {cls.config_name}." - ) - except HTTPError as err: - raise EnvironmentError( - "There was a specific connection error when trying to load" - f" {pretrained_model_name_or_path}:\n{err}" - ) - except ValueError: - raise EnvironmentError( - f"We couldn't connect to '{HUGGINGFACE_CO_RESOLVE_ENDPOINT}' to load this model, couldn't find it" - f" in the cached files and it looks like {pretrained_model_name_or_path} is not the path to a" - f" directory containing a {cls.config_name} file.\nCheckout your internet connection or see how to" - " run the library in offline mode at" - " 'https://huggingface.co/docs/diffusers/installation#offline-mode'." - ) - except EnvironmentError: - raise EnvironmentError( - f"Can't load config for '{pretrained_model_name_or_path}'. If you were trying to load it from " - "'https://huggingface.co/models', make sure you don't have a local directory with the same name. " - f"Otherwise, make sure '{pretrained_model_name_or_path}' is the correct path to a directory " - f"containing a {cls.config_name} file" - ) - - try: - # Load config dict - config_dict = cls._dict_from_json_file(config_file) - except (json.JSONDecodeError, UnicodeDecodeError): - raise EnvironmentError(f"It looks like the config file at '{config_file}' is not a valid JSON file.") - - return config_dict - - @classmethod - def extract_init_dict(cls, config_dict, **kwargs): - expected_keys = set(dict(inspect.signature(cls.__init__).parameters).keys()) - expected_keys.remove("self") - # remove general kwargs if present in dict - if "kwargs" in expected_keys: - expected_keys.remove("kwargs") - # remove keys to be ignored - if len(cls.ignore_for_config) > 0: - expected_keys = expected_keys - set(cls.ignore_for_config) - init_dict = {} - for key in expected_keys: - if key in kwargs: - # overwrite key - init_dict[key] = kwargs.pop(key) - elif key in config_dict: - # use value from config dict - init_dict[key] = config_dict.pop(key) - - unused_kwargs = config_dict.update(kwargs) - - passed_keys = set(init_dict.keys()) - if len(expected_keys - passed_keys) > 0: - logger.warning( - f"{expected_keys - passed_keys} was not found in config. Values will be initialized to default values." 
- ) - - return init_dict, unused_kwargs - - @classmethod - def _dict_from_json_file(cls, json_file: Union[str, os.PathLike]): - with open(json_file, "r", encoding="utf-8") as reader: - text = reader.read() - return json.loads(text) - - def __repr__(self): - return f"{self.__class__.__name__} {self.to_json_string()}" - - @property - def config(self) -> Dict[str, Any]: - return self._internal_dict - - def to_json_string(self) -> str: - """ - Serializes this instance to a JSON string. - - Returns: - `str`: String containing all the attributes that make up this configuration instance in JSON format. - """ - config_dict = self._internal_dict if hasattr(self, "_internal_dict") else {} - return json.dumps(config_dict, indent=2, sort_keys=True) + "\n" - - def to_json_file(self, json_file_path: Union[str, os.PathLike]): - """ - Save this instance to a JSON file. - - Args: - json_file_path (`str` or `os.PathLike`): - Path to the JSON file in which this configuration instance's parameters will be saved. - """ - with open(json_file_path, "w", encoding="utf-8") as writer: - writer.write(self.to_json_string()) - - -class FrozenDict(OrderedDict): - def __init__(self, *args, **kwargs): - super().__init__(*args, **kwargs) - - for key, value in self.items(): - setattr(self, key, value) - - self.__frozen = True - - def __delitem__(self, *args, **kwargs): - raise Exception(f"You cannot use ``__delitem__`` on a {self.__class__.__name__} instance.") - - def setdefault(self, *args, **kwargs): - raise Exception(f"You cannot use ``setdefault`` on a {self.__class__.__name__} instance.") - - def pop(self, *args, **kwargs): - raise Exception(f"You cannot use ``pop`` on a {self.__class__.__name__} instance.") - - def update(self, *args, **kwargs): - raise Exception(f"You cannot use ``update`` on a {self.__class__.__name__} instance.") - - def __setattr__(self, name, value): - if hasattr(self, "__frozen") and self.__frozen: - raise Exception(f"You cannot use ``__setattr__`` on a {self.__class__.__name__} instance.") - super().__setattr__(name, value) - - def __setitem__(self, name, value): - if hasattr(self, "__frozen") and self.__frozen: - raise Exception(f"You cannot use ``__setattr__`` on a {self.__class__.__name__} instance.") - super().__setitem__(name, value) - - -def register_to_config(init): - r""" - Decorator to apply on the init of classes inheriting from [`ConfigMixin`] so that all the arguments are - automatically sent to `self.register_for_config`. To ignore a specific argument accepted by the init but that - shouldn't be registered in the config, use the `ignore_for_config` class variable - - Warning: Once decorated, all private arguments (beginning with an underscore) are trashed and not sent to the init! - """ - - @functools.wraps(init) - def inner_init(self, *args, **kwargs): - # Ignore private kwargs in the init. - init_kwargs = {k: v for k, v in kwargs.items() if not k.startswith("_")} - init(self, *args, **init_kwargs) - if not isinstance(self, ConfigMixin): - raise RuntimeError( - f"`@register_for_config` was applied to {self.__class__.__name__} init method, but this class does " - "not inherit from `ConfigMixin`." 
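For reference, a hedged sketch of how the `FrozenDict` defined above behaves once constructed (the keys and values are placeholders):

```python
# Values are readable as items or attributes; pop/setdefault/update/__delitem__
# are explicitly disabled by the overrides above.
fd = FrozenDict({"beta_start": 0.0001, "num_train_timesteps": 1000})
print(fd["beta_start"], fd.num_train_timesteps)  # 0.0001 1000
try:
    fd.pop("beta_start")
except Exception as err:
    print(err)  # You cannot use ``pop`` on a FrozenDict instance.
```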
- ) - - ignore = getattr(self, "ignore_for_config", []) - # Get positional arguments aligned with kwargs - new_kwargs = {} - signature = inspect.signature(init) - parameters = { - name: p.default for i, (name, p) in enumerate(signature.parameters.items()) if i > 0 and name not in ignore - } - for arg, name in zip(args, parameters.keys()): - new_kwargs[name] = arg - - # Then add all kwargs - new_kwargs.update( - { - k: init_kwargs.get(k, default) - for k, default in parameters.items() - if k not in ignore and k not in new_kwargs - } - ) - getattr(self, "register_to_config")(**new_kwargs) - - return inner_init diff --git a/local_diffusers/dependency_versions_check.py b/local_diffusers/dependency_versions_check.py deleted file mode 100644 index bbf863222..000000000 --- a/local_diffusers/dependency_versions_check.py +++ /dev/null @@ -1,47 +0,0 @@ -# Copyright 2020 The HuggingFace Team. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -import sys - -from .dependency_versions_table import deps -from .utils.versions import require_version, require_version_core - - -# define which module versions we always want to check at run time -# (usually the ones defined in `install_requires` in setup.py) -# -# order specific notes: -# - tqdm must be checked before tokenizers - -pkgs_to_check_at_runtime = "python tqdm regex requests packaging filelock numpy tokenizers".split() -if sys.version_info < (3, 7): - pkgs_to_check_at_runtime.append("dataclasses") -if sys.version_info < (3, 8): - pkgs_to_check_at_runtime.append("importlib_metadata") - -for pkg in pkgs_to_check_at_runtime: - if pkg in deps: - if pkg == "tokenizers": - # must be loaded here, or else tqdm check may fail - from .utils import is_tokenizers_available - - if not is_tokenizers_available(): - continue # not required, check version only if installed - - require_version_core(deps[pkg]) - else: - raise ValueError(f"can't find {pkg} in {deps.keys()}, check dependency_versions_table.py") - - -def dep_version_check(pkg, hint=None): - require_version(deps[pkg], hint) diff --git a/local_diffusers/dependency_versions_table.py b/local_diffusers/dependency_versions_table.py deleted file mode 100644 index 74c5331e5..000000000 --- a/local_diffusers/dependency_versions_table.py +++ /dev/null @@ -1,26 +0,0 @@ -# THIS FILE HAS BEEN AUTOGENERATED. To update: -# 1. modify the `_deps` dict in setup.py -# 2. 
run `make deps_table_update`` -deps = { - "Pillow": "Pillow", - "accelerate": "accelerate>=0.11.0", - "black": "black==22.3", - "datasets": "datasets", - "filelock": "filelock", - "flake8": "flake8>=3.8.3", - "hf-doc-builder": "hf-doc-builder>=0.3.0", - "huggingface-hub": "huggingface-hub>=0.8.1", - "importlib_metadata": "importlib_metadata", - "isort": "isort>=5.5.4", - "modelcards": "modelcards==0.1.4", - "numpy": "numpy", - "pytest": "pytest", - "pytest-timeout": "pytest-timeout", - "pytest-xdist": "pytest-xdist", - "scipy": "scipy", - "regex": "regex!=2019.12.17", - "requests": "requests", - "tensorboard": "tensorboard", - "torch": "torch>=1.4", - "transformers": "transformers>=4.21.0", -} diff --git a/local_diffusers/dynamic_modules_utils.py b/local_diffusers/dynamic_modules_utils.py deleted file mode 100644 index 0ebf916e7..000000000 --- a/local_diffusers/dynamic_modules_utils.py +++ /dev/null @@ -1,335 +0,0 @@ -# coding=utf-8 -# Copyright 2021 The HuggingFace Inc. team. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -"""Utilities to dynamically load objects from the Hub.""" - -import importlib -import os -import re -import shutil -import sys -from pathlib import Path -from typing import Dict, Optional, Union - -from huggingface_hub import cached_download - -from .utils import DIFFUSERS_DYNAMIC_MODULE_NAME, HF_MODULES_CACHE, logging - - -logger = logging.get_logger(__name__) # pylint: disable=invalid-name - - -def init_hf_modules(): - """ - Creates the cache directory for modules with an init, and adds it to the Python path. - """ - # This function has already been executed if HF_MODULES_CACHE already is in the Python path. - if HF_MODULES_CACHE in sys.path: - return - - sys.path.append(HF_MODULES_CACHE) - os.makedirs(HF_MODULES_CACHE, exist_ok=True) - init_path = Path(HF_MODULES_CACHE) / "__init__.py" - if not init_path.exists(): - init_path.touch() - - -def create_dynamic_module(name: Union[str, os.PathLike]): - """ - Creates a dynamic module in the cache directory for modules. - """ - init_hf_modules() - dynamic_module_path = Path(HF_MODULES_CACHE) / name - # If the parent module does not exist yet, recursively create it. - if not dynamic_module_path.parent.exists(): - create_dynamic_module(dynamic_module_path.parent) - os.makedirs(dynamic_module_path, exist_ok=True) - init_path = dynamic_module_path / "__init__.py" - if not init_path.exists(): - init_path.touch() - - -def get_relative_imports(module_file): - """ - Get the list of modules that are relatively imported in a module file. - - Args: - module_file (`str` or `os.PathLike`): The module file to inspect. 
- """ - with open(module_file, "r", encoding="utf-8") as f: - content = f.read() - - # Imports of the form `import .xxx` - relative_imports = re.findall("^\s*import\s+\.(\S+)\s*$", content, flags=re.MULTILINE) - # Imports of the form `from .xxx import yyy` - relative_imports += re.findall("^\s*from\s+\.(\S+)\s+import", content, flags=re.MULTILINE) - # Unique-ify - return list(set(relative_imports)) - - -def get_relative_import_files(module_file): - """ - Get the list of all files that are needed for a given module. Note that this function recurses through the relative - imports (if a imports b and b imports c, it will return module files for b and c). - - Args: - module_file (`str` or `os.PathLike`): The module file to inspect. - """ - no_change = False - files_to_check = [module_file] - all_relative_imports = [] - - # Let's recurse through all relative imports - while not no_change: - new_imports = [] - for f in files_to_check: - new_imports.extend(get_relative_imports(f)) - - module_path = Path(module_file).parent - new_import_files = [str(module_path / m) for m in new_imports] - new_import_files = [f for f in new_import_files if f not in all_relative_imports] - files_to_check = [f"{f}.py" for f in new_import_files] - - no_change = len(new_import_files) == 0 - all_relative_imports.extend(files_to_check) - - return all_relative_imports - - -def check_imports(filename): - """ - Check if the current Python environment contains all the libraries that are imported in a file. - """ - with open(filename, "r", encoding="utf-8") as f: - content = f.read() - - # Imports of the form `import xxx` - imports = re.findall("^\s*import\s+(\S+)\s*$", content, flags=re.MULTILINE) - # Imports of the form `from xxx import yyy` - imports += re.findall("^\s*from\s+(\S+)\s+import", content, flags=re.MULTILINE) - # Only keep the top-level module - imports = [imp.split(".")[0] for imp in imports if not imp.startswith(".")] - - # Unique-ify and test we got them all - imports = list(set(imports)) - missing_packages = [] - for imp in imports: - try: - importlib.import_module(imp) - except ImportError: - missing_packages.append(imp) - - if len(missing_packages) > 0: - raise ImportError( - "This modeling file requires the following packages that were not found in your environment: " - f"{', '.join(missing_packages)}. Run `pip install {' '.join(missing_packages)}`" - ) - - return get_relative_imports(filename) - - -def get_class_in_module(class_name, module_path): - """ - Import a module on the cache directory for modules and extract a class from it. - """ - module_path = module_path.replace(os.path.sep, ".") - module = importlib.import_module(module_path) - return getattr(module, class_name) - - -def get_cached_module_file( - pretrained_model_name_or_path: Union[str, os.PathLike], - module_file: str, - cache_dir: Optional[Union[str, os.PathLike]] = None, - force_download: bool = False, - resume_download: bool = False, - proxies: Optional[Dict[str, str]] = None, - use_auth_token: Optional[Union[bool, str]] = None, - revision: Optional[str] = None, - local_files_only: bool = False, -): - """ - Prepares Downloads a module from a local folder or a distant repo and returns its path inside the cached - Transformers module. - - Args: - pretrained_model_name_or_path (`str` or `os.PathLike`): - This can be either: - - - a string, the *model id* of a pretrained model configuration hosted inside a model repo on - huggingface.co. 
Valid model ids can be located at the root-level, like `bert-base-uncased`, or namespaced - under a user or organization name, like `dbmdz/bert-base-german-cased`. - - a path to a *directory* containing a configuration file saved using the - [`~PreTrainedTokenizer.save_pretrained`] method, e.g., `./my_model_directory/`. - - module_file (`str`): - The name of the module file containing the class to look for. - cache_dir (`str` or `os.PathLike`, *optional*): - Path to a directory in which a downloaded pretrained model configuration should be cached if the standard - cache should not be used. - force_download (`bool`, *optional*, defaults to `False`): - Whether or not to force to (re-)download the configuration files and override the cached versions if they - exist. - resume_download (`bool`, *optional*, defaults to `False`): - Whether or not to delete incompletely received file. Attempts to resume the download if such a file exists. - proxies (`Dict[str, str]`, *optional*): - A dictionary of proxy servers to use by protocol or endpoint, e.g., `{'http': 'foo.bar:3128', - 'http://hostname': 'foo.bar:4012'}.` The proxies are used on each request. - use_auth_token (`str` or *bool*, *optional*): - The token to use as HTTP bearer authorization for remote files. If `True`, will use the token generated - when running `transformers-cli login` (stored in `~/.huggingface`). - revision (`str`, *optional*, defaults to `"main"`): - The specific model version to use. It can be a branch name, a tag name, or a commit id, since we use a - git-based system for storing models and other artifacts on huggingface.co, so `revision` can be any - identifier allowed by git. - local_files_only (`bool`, *optional*, defaults to `False`): - If `True`, will only try to load the tokenizer configuration from local files. - - - - Passing `use_auth_token=True` is required when you want to use a private model. - - - - Returns: - `str`: The path to the module inside the cache. - """ - # Download and cache module_file from the repo `pretrained_model_name_or_path` of grab it if it's a local file. - pretrained_model_name_or_path = str(pretrained_model_name_or_path) - module_file_or_url = os.path.join(pretrained_model_name_or_path, module_file) - submodule = "local" - - if os.path.isfile(module_file_or_url): - resolved_module_file = module_file_or_url - else: - try: - # Load from URL or cache if already cached - resolved_module_file = cached_download( - module_file_or_url, - cache_dir=cache_dir, - force_download=force_download, - proxies=proxies, - resume_download=resume_download, - local_files_only=local_files_only, - use_auth_token=use_auth_token, - ) - - except EnvironmentError: - logger.error(f"Could not locate the {module_file} inside {pretrained_model_name_or_path}.") - raise - - # Check we have all the requirements in our environment - modules_needed = check_imports(resolved_module_file) - - # Now we move the module inside our cached dynamic modules. - full_submodule = DIFFUSERS_DYNAMIC_MODULE_NAME + os.path.sep + submodule - create_dynamic_module(full_submodule) - submodule_path = Path(HF_MODULES_CACHE) / full_submodule - # We always copy local files (we could hash the file to see if there was a change, and give them the name of - # that hash, to only copy when there is a modification but it seems overkill for now). - # The only reason we do the copy is to avoid putting too many folders in sys.path. 
- shutil.copy(resolved_module_file, submodule_path / module_file) - for module_needed in modules_needed: - module_needed = f"{module_needed}.py" - shutil.copy(os.path.join(pretrained_model_name_or_path, module_needed), submodule_path / module_needed) - return os.path.join(full_submodule, module_file) - - -def get_class_from_dynamic_module( - pretrained_model_name_or_path: Union[str, os.PathLike], - module_file: str, - class_name: str, - cache_dir: Optional[Union[str, os.PathLike]] = None, - force_download: bool = False, - resume_download: bool = False, - proxies: Optional[Dict[str, str]] = None, - use_auth_token: Optional[Union[bool, str]] = None, - revision: Optional[str] = None, - local_files_only: bool = False, - **kwargs, -): - """ - Extracts a class from a module file, present in the local folder or repository of a model. - - - - Calling this function will execute the code in the module file found locally or downloaded from the Hub. It should - therefore only be called on trusted repos. - - - - Args: - pretrained_model_name_or_path (`str` or `os.PathLike`): - This can be either: - - - a string, the *model id* of a pretrained model configuration hosted inside a model repo on - huggingface.co. Valid model ids can be located at the root-level, like `bert-base-uncased`, or namespaced - under a user or organization name, like `dbmdz/bert-base-german-cased`. - - a path to a *directory* containing a configuration file saved using the - [`~PreTrainedTokenizer.save_pretrained`] method, e.g., `./my_model_directory/`. - - module_file (`str`): - The name of the module file containing the class to look for. - class_name (`str`): - The name of the class to import in the module. - cache_dir (`str` or `os.PathLike`, *optional*): - Path to a directory in which a downloaded pretrained model configuration should be cached if the standard - cache should not be used. - force_download (`bool`, *optional*, defaults to `False`): - Whether or not to force to (re-)download the configuration files and override the cached versions if they - exist. - resume_download (`bool`, *optional*, defaults to `False`): - Whether or not to delete incompletely received file. Attempts to resume the download if such a file exists. - proxies (`Dict[str, str]`, *optional*): - A dictionary of proxy servers to use by protocol or endpoint, e.g., `{'http': 'foo.bar:3128', - 'http://hostname': 'foo.bar:4012'}.` The proxies are used on each request. - use_auth_token (`str` or `bool`, *optional*): - The token to use as HTTP bearer authorization for remote files. If `True`, will use the token generated - when running `transformers-cli login` (stored in `~/.huggingface`). - revision (`str`, *optional*, defaults to `"main"`): - The specific model version to use. It can be a branch name, a tag name, or a commit id, since we use a - git-based system for storing models and other artifacts on huggingface.co, so `revision` can be any - identifier allowed by git. - local_files_only (`bool`, *optional*, defaults to `False`): - If `True`, will only try to load the tokenizer configuration from local files. - - - - Passing `use_auth_token=True` is required when you want to use a private model. - - - - Returns: - `type`: The class, dynamically imported from the module. - - Examples: - - ```python - # Download module `modeling.py` from huggingface.co and cache then extract the class `MyBertModel` from this - # module. 
- cls = get_class_from_dynamic_module("sgugger/my-bert-model", "modeling.py", "MyBertModel") - ```""" - # And lastly we get the class inside our newly created module - final_module = get_cached_module_file( - pretrained_model_name_or_path, - module_file, - cache_dir=cache_dir, - force_download=force_download, - resume_download=resume_download, - proxies=proxies, - use_auth_token=use_auth_token, - revision=revision, - local_files_only=local_files_only, - ) - return get_class_in_module(class_name, final_module.replace(".py", "")) diff --git a/local_diffusers/hub_utils.py b/local_diffusers/hub_utils.py deleted file mode 100644 index c07329e36..000000000 --- a/local_diffusers/hub_utils.py +++ /dev/null @@ -1,197 +0,0 @@ -# coding=utf-8 -# Copyright 2022 The HuggingFace Inc. team. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - - -import os -import shutil -from pathlib import Path -from typing import Optional - -from huggingface_hub import HfFolder, Repository, whoami - -from .pipeline_utils import DiffusionPipeline -from .utils import is_modelcards_available, logging - - -if is_modelcards_available(): - from modelcards import CardData, ModelCard - - -logger = logging.get_logger(__name__) - - -MODEL_CARD_TEMPLATE_PATH = Path(__file__).parent / "utils" / "model_card_template.md" - - -def get_full_repo_name(model_id: str, organization: Optional[str] = None, token: Optional[str] = None): - if token is None: - token = HfFolder.get_token() - if organization is None: - username = whoami(token)["name"] - return f"{username}/{model_id}" - else: - return f"{organization}/{model_id}" - - -def init_git_repo(args, at_init: bool = False): - """ - Args: - Initializes a git repo in `args.hub_model_id`. - at_init (`bool`, *optional*, defaults to `False`): - Whether this function is called before any training or not. If `self.args.overwrite_output_dir` is `True` - and `at_init` is `True`, the path to the repo (which is `self.args.output_dir`) might be wiped out. 
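A hedged sketch of `get_full_repo_name`, defined just above (the model id and organization are placeholders):

```python
# With an explicit organization the helper only formats the string; without one it
# resolves the username from the stored Hub token via whoami(), so a login is needed.
print(get_full_repo_name("ddpm-butterflies", organization="my-org"))  # my-org/ddpm-butterflies
print(get_full_repo_name("ddpm-butterflies"))                         # <your-username>/ddpm-butterflies
```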
- """ - if hasattr(args, "local_rank") and args.local_rank not in [-1, 0]: - return - hub_token = args.hub_token if hasattr(args, "hub_token") else None - use_auth_token = True if hub_token is None else hub_token - if not hasattr(args, "hub_model_id") or args.hub_model_id is None: - repo_name = Path(args.output_dir).absolute().name - else: - repo_name = args.hub_model_id - if "/" not in repo_name: - repo_name = get_full_repo_name(repo_name, token=hub_token) - - try: - repo = Repository( - args.output_dir, - clone_from=repo_name, - use_auth_token=use_auth_token, - private=args.hub_private_repo, - ) - except EnvironmentError: - if args.overwrite_output_dir and at_init: - # Try again after wiping output_dir - shutil.rmtree(args.output_dir) - repo = Repository( - args.output_dir, - clone_from=repo_name, - use_auth_token=use_auth_token, - ) - else: - raise - - repo.git_pull() - - # By default, ignore the checkpoint folders - if not os.path.exists(os.path.join(args.output_dir, ".gitignore")): - with open(os.path.join(args.output_dir, ".gitignore"), "w", encoding="utf-8") as writer: - writer.writelines(["checkpoint-*/"]) - - return repo - - -def push_to_hub( - args, - pipeline: DiffusionPipeline, - repo: Repository, - commit_message: Optional[str] = "End of training", - blocking: bool = True, - **kwargs, -) -> str: - """ - Parameters: - Upload *self.model* and *self.tokenizer* to the 🤗 model hub on the repo *self.args.hub_model_id*. - commit_message (`str`, *optional*, defaults to `"End of training"`): - Message to commit while pushing. - blocking (`bool`, *optional*, defaults to `True`): - Whether the function should return only when the `git push` has finished. - kwargs: - Additional keyword arguments passed along to [`create_model_card`]. - Returns: - The url of the commit of your model in the given repository if `blocking=False`, a tuple with the url of the - commit and an object to track the progress of the commit if `blocking=True` - """ - - if not hasattr(args, "hub_model_id") or args.hub_model_id is None: - model_name = Path(args.output_dir).name - else: - model_name = args.hub_model_id.split("/")[-1] - - output_dir = args.output_dir - os.makedirs(output_dir, exist_ok=True) - logger.info(f"Saving pipeline checkpoint to {output_dir}") - pipeline.save_pretrained(output_dir) - - # Only push from one node. - if hasattr(args, "local_rank") and args.local_rank not in [-1, 0]: - return - - # Cancel any async push in progress if blocking=True. The commits will all be pushed together. - if ( - blocking - and len(repo.command_queue) > 0 - and repo.command_queue[-1] is not None - and not repo.command_queue[-1].is_done - ): - repo.command_queue[-1]._process.kill() - - git_head_commit_url = repo.push_to_hub(commit_message=commit_message, blocking=blocking, auto_lfs_prune=True) - # push separately the model card to be independent from the rest of the model - create_model_card(args, model_name=model_name) - try: - repo.push_to_hub(commit_message="update model card README.md", blocking=blocking, auto_lfs_prune=True) - except EnvironmentError as exc: - logger.error(f"Error pushing update to the model card. Please read logs and retry.\n${exc}") - - return git_head_commit_url - - -def create_model_card(args, model_name): - if not is_modelcards_available: - raise ValueError( - "Please make sure to have `modelcards` installed when using the `create_model_card` function. You can" - " install the package with `pip install modelcards`." 
- ) - - if hasattr(args, "local_rank") and args.local_rank not in [-1, 0]: - return - - hub_token = args.hub_token if hasattr(args, "hub_token") else None - repo_name = get_full_repo_name(model_name, token=hub_token) - - model_card = ModelCard.from_template( - card_data=CardData( # Card metadata object that will be converted to YAML block - language="en", - license="apache-2.0", - library_name="diffusers", - tags=[], - datasets=args.dataset_name, - metrics=[], - ), - template_path=MODEL_CARD_TEMPLATE_PATH, - model_name=model_name, - repo_name=repo_name, - dataset_name=args.dataset_name if hasattr(args, "dataset_name") else None, - learning_rate=args.learning_rate, - train_batch_size=args.train_batch_size, - eval_batch_size=args.eval_batch_size, - gradient_accumulation_steps=args.gradient_accumulation_steps - if hasattr(args, "gradient_accumulation_steps") - else None, - adam_beta1=args.adam_beta1 if hasattr(args, "adam_beta1") else None, - adam_beta2=args.adam_beta2 if hasattr(args, "adam_beta2") else None, - adam_weight_decay=args.adam_weight_decay if hasattr(args, "adam_weight_decay") else None, - adam_epsilon=args.adam_epsilon if hasattr(args, "adam_epsilon") else None, - lr_scheduler=args.lr_scheduler if hasattr(args, "lr_scheduler") else None, - lr_warmup_steps=args.lr_warmup_steps if hasattr(args, "lr_warmup_steps") else None, - ema_inv_gamma=args.ema_inv_gamma if hasattr(args, "ema_inv_gamma") else None, - ema_power=args.ema_power if hasattr(args, "ema_power") else None, - ema_max_decay=args.ema_max_decay if hasattr(args, "ema_max_decay") else None, - mixed_precision=args.mixed_precision, - ) - - card_path = os.path.join(args.output_dir, "README.md") - model_card.save(card_path) diff --git a/local_diffusers/modeling_utils.py b/local_diffusers/modeling_utils.py deleted file mode 100644 index fb613614a..000000000 --- a/local_diffusers/modeling_utils.py +++ /dev/null @@ -1,542 +0,0 @@ -# coding=utf-8 -# Copyright 2022 The HuggingFace Inc. team. -# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
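The `create_model_card` helper above reads its hyper-parameters from an `args` object. A hedged sketch of the minimum it needs (all values are placeholders; it assumes `modelcards` is installed and a Hub token is configured, since `get_full_repo_name` calls `whoami()` when no organization is given):

```python
import os
from argparse import Namespace

args = Namespace(
    output_dir="./ddpm-butterflies",
    dataset_name="huggan/smithsonian_butterflies_subset",  # hypothetical dataset id
    learning_rate=1e-4,
    train_batch_size=16,
    eval_batch_size=16,
    mixed_precision="no",
)
os.makedirs(args.output_dir, exist_ok=True)
create_model_card(args, model_name="ddpm-butterflies")  # writes ./ddpm-butterflies/README.md
# Optional settings (adam_*, lr_scheduler, ema_*, ...) are read with hasattr() and default to None.
```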
- -import os -from typing import Callable, List, Optional, Tuple, Union - -import torch -from torch import Tensor, device - -from huggingface_hub import hf_hub_download -from huggingface_hub.utils import EntryNotFoundError, RepositoryNotFoundError, RevisionNotFoundError -from requests import HTTPError - -from .utils import CONFIG_NAME, DIFFUSERS_CACHE, HUGGINGFACE_CO_RESOLVE_ENDPOINT, logging - - -WEIGHTS_NAME = "diffusion_pytorch_model.bin" - - -logger = logging.get_logger(__name__) - - -def get_parameter_device(parameter: torch.nn.Module): - try: - return next(parameter.parameters()).device - except StopIteration: - # For torch.nn.DataParallel compatibility in PyTorch 1.5 - - def find_tensor_attributes(module: torch.nn.Module) -> List[Tuple[str, Tensor]]: - tuples = [(k, v) for k, v in module.__dict__.items() if torch.is_tensor(v)] - return tuples - - gen = parameter._named_members(get_members_fn=find_tensor_attributes) - first_tuple = next(gen) - return first_tuple[1].device - - -def get_parameter_dtype(parameter: torch.nn.Module): - try: - return next(parameter.parameters()).dtype - except StopIteration: - # For torch.nn.DataParallel compatibility in PyTorch 1.5 - - def find_tensor_attributes(module: torch.nn.Module) -> List[Tuple[str, Tensor]]: - tuples = [(k, v) for k, v in module.__dict__.items() if torch.is_tensor(v)] - return tuples - - gen = parameter._named_members(get_members_fn=find_tensor_attributes) - first_tuple = next(gen) - return first_tuple[1].dtype - - -def load_state_dict(checkpoint_file: Union[str, os.PathLike]): - """ - Reads a PyTorch checkpoint file, returning properly formatted errors if they arise. - """ - try: - return torch.load(checkpoint_file, map_location="cpu") - except Exception as e: - try: - with open(checkpoint_file) as f: - if f.read().startswith("version"): - raise OSError( - "You seem to have cloned a repository without having git-lfs installed. Please install " - "git-lfs and run `git lfs install` followed by `git lfs pull` in the folder " - "you cloned." - ) - else: - raise ValueError( - f"Unable to locate the file {checkpoint_file} which is necessary to load this pretrained " - "model. Make sure you have saved the model properly." - ) from e - except (UnicodeDecodeError, ValueError): - raise OSError( - f"Unable to load weights from pytorch checkpoint file for '{checkpoint_file}' " - f"at '{checkpoint_file}'. " - "If you tried to load a PyTorch model from a TF 2.0 checkpoint, please set from_tf=True." - ) - - -def _load_state_dict_into_model(model_to_load, state_dict): - # Convert old format to new format if needed from a PyTorch state_dict - # copy state_dict so _load_from_state_dict can modify it - state_dict = state_dict.copy() - error_msgs = [] - - # PyTorch's `_load_from_state_dict` does not copy parameters in a module's descendants - # so we need to apply the function recursively. - def load(module: torch.nn.Module, prefix=""): - args = (state_dict, prefix, {}, True, [], [], error_msgs) - module._load_from_state_dict(*args) - - for name, child in module._modules.items(): - if child is not None: - load(child, prefix + name + ".") - - load(model_to_load) - - return error_msgs - - -class ModelMixin(torch.nn.Module): - r""" - Base class for all models. - - [`ModelMixin`] takes care of storing the configuration of the models and handles methods for loading, downloading - and saving models. - - - **config_name** ([`str`]) -- A filename under which the model should be stored when calling - [`~modeling_utils.ModelMixin.save_pretrained`]. 
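The two introspection helpers above are easiest to see on a toy module (a hedged, self-contained example):

```python
import torch

net = torch.nn.Linear(4, 2)
print(get_parameter_device(net))  # cpu (or wherever the parameters live)
print(get_parameter_dtype(net))   # torch.float32 by default
```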
- """ - config_name = CONFIG_NAME - _automatically_saved_args = ["_diffusers_version", "_class_name", "_name_or_path"] - - def __init__(self): - super().__init__() - - def save_pretrained( - self, - save_directory: Union[str, os.PathLike], - is_main_process: bool = True, - save_function: Callable = torch.save, - ): - """ - Save a model and its configuration file to a directory, so that it can be re-loaded using the - `[`~modeling_utils.ModelMixin.from_pretrained`]` class method. - - Arguments: - save_directory (`str` or `os.PathLike`): - Directory to which to save. Will be created if it doesn't exist. - is_main_process (`bool`, *optional*, defaults to `True`): - Whether the process calling this is the main process or not. Useful when in distributed training like - TPUs and need to call this function on all processes. In this case, set `is_main_process=True` only on - the main process to avoid race conditions. - save_function (`Callable`): - The function to use to save the state dictionary. Useful on distributed training like TPUs when one - need to replace `torch.save` by another method. - """ - if os.path.isfile(save_directory): - logger.error(f"Provided path ({save_directory}) should be a directory, not a file") - return - - os.makedirs(save_directory, exist_ok=True) - - model_to_save = self - - # Attach architecture to the config - # Save the config - if is_main_process: - model_to_save.save_config(save_directory) - - # Save the model - state_dict = model_to_save.state_dict() - - # Clean the folder from a previous save - for filename in os.listdir(save_directory): - full_filename = os.path.join(save_directory, filename) - # If we have a shard file that is not going to be replaced, we delete it, but only from the main process - # in distributed settings to avoid race conditions. - if filename.startswith(WEIGHTS_NAME[:-4]) and os.path.isfile(full_filename) and is_main_process: - os.remove(full_filename) - - # Save the model - save_function(state_dict, os.path.join(save_directory, WEIGHTS_NAME)) - - logger.info(f"Model weights saved in {os.path.join(save_directory, WEIGHTS_NAME)}") - - @classmethod - def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.PathLike]], **kwargs): - r""" - Instantiate a pretrained pytorch model from a pre-trained model configuration. - - The model is set in evaluation mode by default using `model.eval()` (Dropout modules are deactivated). To train - the model, you should first set it back in training mode with `model.train()`. - - The warning *Weights from XXX not initialized from pretrained model* means that the weights of XXX do not come - pretrained with the rest of the model. It is up to you to train those weights with a downstream fine-tuning - task. - - The warning *Weights from XXX not used in YYY* means that the layer XXX is not used by YYY, therefore those - weights are discarded. - - Parameters: - pretrained_model_name_or_path (`str` or `os.PathLike`, *optional*): - Can be either: - - - A string, the *model id* of a pretrained model hosted inside a model repo on huggingface.co. - Valid model ids should have an organization name, like `google/ddpm-celebahq-256`. - - A path to a *directory* containing model weights saved using [`~ModelMixin.save_config`], e.g., - `./my_model_directory/`. - - cache_dir (`Union[str, os.PathLike]`, *optional*): - Path to a directory in which a downloaded pretrained model configuration should be cached if the - standard cache should not be used. 
- torch_dtype (`str` or `torch.dtype`, *optional*): - Override the default `torch.dtype` and load the model under this dtype. If `"auto"` is passed the dtype - will be automatically derived from the model's weights. - force_download (`bool`, *optional*, defaults to `False`): - Whether or not to force the (re-)download of the model weights and configuration files, overriding the - cached versions if they exist. - resume_download (`bool`, *optional*, defaults to `False`): - Whether or not to delete incompletely received files. Will attempt to resume the download if such a - file exists. - proxies (`Dict[str, str]`, *optional*): - A dictionary of proxy servers to use by protocol or endpoint, e.g., `{'http': 'foo.bar:3128', - 'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request. - output_loading_info(`bool`, *optional*, defaults to `False`): - Whether ot not to also return a dictionary containing missing keys, unexpected keys and error messages. - local_files_only(`bool`, *optional*, defaults to `False`): - Whether or not to only look at local files (i.e., do not try to download the model). - use_auth_token (`str` or *bool*, *optional*): - The token to use as HTTP bearer authorization for remote files. If `True`, will use the token generated - when running `diffusers-cli login` (stored in `~/.huggingface`). - revision (`str`, *optional*, defaults to `"main"`): - The specific model version to use. It can be a branch name, a tag name, or a commit id, since we use a - git-based system for storing models and other artifacts on huggingface.co, so `revision` can be any - identifier allowed by git. - mirror (`str`, *optional*): - Mirror source to accelerate downloads in China. If you are from China and have an accessibility - problem, you can set this option to resolve it. Note that we do not guarantee the timeliness or safety. - Please refer to the mirror site for more information. - - - - Passing `use_auth_token=True`` is required when you want to use a private model. - - - - - - Activate the special ["offline-mode"](https://huggingface.co/diffusers/installation.html#offline-mode) to use - this method in a firewalled environment. - - - - """ - cache_dir = kwargs.pop("cache_dir", DIFFUSERS_CACHE) - ignore_mismatched_sizes = kwargs.pop("ignore_mismatched_sizes", False) - force_download = kwargs.pop("force_download", False) - resume_download = kwargs.pop("resume_download", False) - proxies = kwargs.pop("proxies", None) - output_loading_info = kwargs.pop("output_loading_info", False) - local_files_only = kwargs.pop("local_files_only", False) - use_auth_token = kwargs.pop("use_auth_token", None) - revision = kwargs.pop("revision", None) - from_auto_class = kwargs.pop("_from_auto", False) - torch_dtype = kwargs.pop("torch_dtype", None) - subfolder = kwargs.pop("subfolder", None) - - user_agent = {"file_type": "model", "framework": "pytorch", "from_auto_class": from_auto_class} - - # Load config if we don't provide a configuration - config_path = pretrained_model_name_or_path - model, unused_kwargs = cls.from_config( - config_path, - cache_dir=cache_dir, - return_unused_kwargs=True, - force_download=force_download, - resume_download=resume_download, - proxies=proxies, - local_files_only=local_files_only, - use_auth_token=use_auth_token, - revision=revision, - subfolder=subfolder, - **kwargs, - ) - - if torch_dtype is not None and not isinstance(torch_dtype, torch.dtype): - raise ValueError( - f"{torch_dtype} needs to be of type `torch.dtype`, e.g. 
`torch.float16`, but is {type(torch_dtype)}." - ) - elif torch_dtype is not None: - model = model.to(torch_dtype) - - model.register_to_config(_name_or_path=pretrained_model_name_or_path) - # This variable will flag if we're loading a sharded checkpoint. In this case the archive file is just the - # Load model - pretrained_model_name_or_path = str(pretrained_model_name_or_path) - if os.path.isdir(pretrained_model_name_or_path): - if os.path.isfile(os.path.join(pretrained_model_name_or_path, WEIGHTS_NAME)): - # Load from a PyTorch checkpoint - model_file = os.path.join(pretrained_model_name_or_path, WEIGHTS_NAME) - elif subfolder is not None and os.path.isfile( - os.path.join(pretrained_model_name_or_path, subfolder, WEIGHTS_NAME) - ): - model_file = os.path.join(pretrained_model_name_or_path, subfolder, WEIGHTS_NAME) - else: - raise EnvironmentError( - f"Error no file named {WEIGHTS_NAME} found in directory {pretrained_model_name_or_path}." - ) - else: - try: - # Load from URL or cache if already cached - model_file = hf_hub_download( - pretrained_model_name_or_path, - filename=WEIGHTS_NAME, - cache_dir=cache_dir, - force_download=force_download, - proxies=proxies, - resume_download=resume_download, - local_files_only=local_files_only, - use_auth_token=use_auth_token, - user_agent=user_agent, - subfolder=subfolder, - revision=revision, - ) - - except RepositoryNotFoundError: - raise EnvironmentError( - f"{pretrained_model_name_or_path} is not a local folder and is not a valid model identifier " - "listed on 'https://huggingface.co/models'\nIf this is a private repository, make sure to pass a " - "token having permission to this repo with `use_auth_token` or log in with `huggingface-cli " - "login` and pass `use_auth_token=True`." - ) - except RevisionNotFoundError: - raise EnvironmentError( - f"{revision} is not a valid git identifier (branch name, tag name or commit id) that exists for " - "this model name. Check the model page at " - f"'https://huggingface.co/{pretrained_model_name_or_path}' for available revisions." - ) - except EntryNotFoundError: - raise EnvironmentError( - f"{pretrained_model_name_or_path} does not appear to have a file named {WEIGHTS_NAME}." - ) - except HTTPError as err: - raise EnvironmentError( - "There was a specific connection error when trying to load" - f" {pretrained_model_name_or_path}:\n{err}" - ) - except ValueError: - raise EnvironmentError( - f"We couldn't connect to '{HUGGINGFACE_CO_RESOLVE_ENDPOINT}' to load this model, couldn't find it" - f" in the cached files and it looks like {pretrained_model_name_or_path} is not the path to a" - f" directory containing a file named {WEIGHTS_NAME} or" - " \nCheckout your internet connection or see how to run the library in" - " offline mode at 'https://huggingface.co/docs/diffusers/installation#offline-mode'." - ) - except EnvironmentError: - raise EnvironmentError( - f"Can't load the model for '{pretrained_model_name_or_path}'. If you were trying to load it from " - "'https://huggingface.co/models', make sure you don't have a local directory with the same name. 
" - f"Otherwise, make sure '{pretrained_model_name_or_path}' is the correct path to a directory " - f"containing a file named {WEIGHTS_NAME}" - ) - - # restore default dtype - state_dict = load_state_dict(model_file) - model, missing_keys, unexpected_keys, mismatched_keys, error_msgs = cls._load_pretrained_model( - model, - state_dict, - model_file, - pretrained_model_name_or_path, - ignore_mismatched_sizes=ignore_mismatched_sizes, - ) - - # Set model in evaluation mode to deactivate DropOut modules by default - model.eval() - - if output_loading_info: - loading_info = { - "missing_keys": missing_keys, - "unexpected_keys": unexpected_keys, - "mismatched_keys": mismatched_keys, - "error_msgs": error_msgs, - } - return model, loading_info - - return model - - @classmethod - def _load_pretrained_model( - cls, - model, - state_dict, - resolved_archive_file, - pretrained_model_name_or_path, - ignore_mismatched_sizes=False, - ): - # Retrieve missing & unexpected_keys - model_state_dict = model.state_dict() - loaded_keys = [k for k in state_dict.keys()] - - expected_keys = list(model_state_dict.keys()) - - original_loaded_keys = loaded_keys - - missing_keys = list(set(expected_keys) - set(loaded_keys)) - unexpected_keys = list(set(loaded_keys) - set(expected_keys)) - - # Make sure we are able to load base models as well as derived models (with heads) - model_to_load = model - - def _find_mismatched_keys( - state_dict, - model_state_dict, - loaded_keys, - ignore_mismatched_sizes, - ): - mismatched_keys = [] - if ignore_mismatched_sizes: - for checkpoint_key in loaded_keys: - model_key = checkpoint_key - - if ( - model_key in model_state_dict - and state_dict[checkpoint_key].shape != model_state_dict[model_key].shape - ): - mismatched_keys.append( - (checkpoint_key, state_dict[checkpoint_key].shape, model_state_dict[model_key].shape) - ) - del state_dict[checkpoint_key] - return mismatched_keys - - if state_dict is not None: - # Whole checkpoint - mismatched_keys = _find_mismatched_keys( - state_dict, - model_state_dict, - original_loaded_keys, - ignore_mismatched_sizes, - ) - error_msgs = _load_state_dict_into_model(model_to_load, state_dict) - - if len(error_msgs) > 0: - error_msg = "\n\t".join(error_msgs) - if "size mismatch" in error_msg: - error_msg += ( - "\n\tYou may consider adding `ignore_mismatched_sizes=True` in the model `from_pretrained` method." - ) - raise RuntimeError(f"Error(s) in loading state_dict for {model.__class__.__name__}:\n\t{error_msg}") - - if len(unexpected_keys) > 0: - logger.warning( - f"Some weights of the model checkpoint at {pretrained_model_name_or_path} were not used when" - f" initializing {model.__class__.__name__}: {unexpected_keys}\n- This IS expected if you are" - f" initializing {model.__class__.__name__} from the checkpoint of a model trained on another task" - " or with another architecture (e.g. initializing a BertForSequenceClassification model from a" - " BertForPreTraining model).\n- This IS NOT expected if you are initializing" - f" {model.__class__.__name__} from the checkpoint of a model that you expect to be exactly" - " identical (initializing a BertForSequenceClassification model from a" - " BertForSequenceClassification model)." 
- ) - else: - logger.info(f"All model checkpoint weights were used when initializing {model.__class__.__name__}.\n") - if len(missing_keys) > 0: - logger.warning( - f"Some weights of {model.__class__.__name__} were not initialized from the model checkpoint at" - f" {pretrained_model_name_or_path} and are newly initialized: {missing_keys}\nYou should probably" - " TRAIN this model on a down-stream task to be able to use it for predictions and inference." - ) - elif len(mismatched_keys) == 0: - logger.info( - f"All the weights of {model.__class__.__name__} were initialized from the model checkpoint at" - f" {pretrained_model_name_or_path}.\nIf your task is similar to the task the model of the" - f" checkpoint was trained on, you can already use {model.__class__.__name__} for predictions" - " without further training." - ) - if len(mismatched_keys) > 0: - mismatched_warning = "\n".join( - [ - f"- {key}: found shape {shape1} in the checkpoint and {shape2} in the model instantiated" - for key, shape1, shape2 in mismatched_keys - ] - ) - logger.warning( - f"Some weights of {model.__class__.__name__} were not initialized from the model checkpoint at" - f" {pretrained_model_name_or_path} and are newly initialized because the shapes did not" - f" match:\n{mismatched_warning}\nYou should probably TRAIN this model on a down-stream task to be" - " able to use it for predictions and inference." - ) - - return model, missing_keys, unexpected_keys, mismatched_keys, error_msgs - - @property - def device(self) -> device: - """ - `torch.device`: The device on which the module is (assuming that all the module parameters are on the same - device). - """ - return get_parameter_device(self) - - @property - def dtype(self) -> torch.dtype: - """ - `torch.dtype`: The dtype of the module (assuming that all the module parameters have the same dtype). - """ - return get_parameter_dtype(self) - - def num_parameters(self, only_trainable: bool = False, exclude_embeddings: bool = False) -> int: - """ - Get number of (optionally, trainable or non-embeddings) parameters in the module. - - Args: - only_trainable (`bool`, *optional*, defaults to `False`): - Whether or not to return only the number of trainable parameters - - exclude_embeddings (`bool`, *optional*, defaults to `False`): - Whether or not to return only the number of non-embeddings parameters - - Returns: - `int`: The number of parameters. - """ - - if exclude_embeddings: - embedding_param_names = [ - f"{name}.weight" - for name, module_type in self.named_modules() - if isinstance(module_type, torch.nn.Embedding) - ] - non_embedding_parameters = [ - parameter for name, parameter in self.named_parameters() if name not in embedding_param_names - ] - return sum(p.numel() for p in non_embedding_parameters if p.requires_grad or not only_trainable) - else: - return sum(p.numel() for p in self.parameters() if p.requires_grad or not only_trainable) - - -def unwrap_model(model: torch.nn.Module) -> torch.nn.Module: - """ - Recursively unwraps a model from potential containers (as used in distributed training). - - Args: - model (`torch.nn.Module`): The model to unwrap. 
- """ - # since there could be multiple levels of wrapping, unwrap recursively - if hasattr(model, "module"): - return unwrap_model(model.module) - else: - return model diff --git a/local_diffusers/models/__init__.py b/local_diffusers/models/__init__.py deleted file mode 100644 index e0ac5c8d5..000000000 --- a/local_diffusers/models/__init__.py +++ /dev/null @@ -1,17 +0,0 @@ -# Copyright 2022 The HuggingFace Team. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -from .unet_2d import UNet2DModel -from .unet_2d_condition import UNet2DConditionModel -from .vae import AutoencoderKL, VQModel diff --git a/local_diffusers/models/attention.py b/local_diffusers/models/attention.py deleted file mode 100644 index de9c92691..000000000 --- a/local_diffusers/models/attention.py +++ /dev/null @@ -1,333 +0,0 @@ -import math -from typing import Optional - -import torch -import torch.nn.functional as F -from torch import nn - - -class AttentionBlock(nn.Module): - """ - An attention block that allows spatial positions to attend to each other. Originally ported from here, but adapted - to the N-d case. - https://github.com/hojonathanho/diffusion/blob/1e0dceb3b3495bbe19116a5e1b3596cd0706c543/diffusion_tf/models/unet.py#L66. - Uses three q, k, v linear layers to compute attention. - - Parameters: - channels (:obj:`int`): The number of channels in the input and output. - num_head_channels (:obj:`int`, *optional*): - The number of channels in each head. If None, then `num_heads` = 1. - num_groups (:obj:`int`, *optional*, defaults to 32): The number of groups to use for group norm. - rescale_output_factor (:obj:`float`, *optional*, defaults to 1.0): The factor to rescale the output by. - eps (:obj:`float`, *optional*, defaults to 1e-5): The epsilon value to use for group norm. 
- """ - - def __init__( - self, - channels: int, - num_head_channels: Optional[int] = None, - num_groups: int = 32, - rescale_output_factor: float = 1.0, - eps: float = 1e-5, - ): - super().__init__() - self.channels = channels - - self.num_heads = channels // num_head_channels if num_head_channels is not None else 1 - self.num_head_size = num_head_channels - self.group_norm = nn.GroupNorm(num_channels=channels, num_groups=num_groups, eps=eps, affine=True) - - # define q,k,v as linear layers - self.query = nn.Linear(channels, channels) - self.key = nn.Linear(channels, channels) - self.value = nn.Linear(channels, channels) - - self.rescale_output_factor = rescale_output_factor - self.proj_attn = nn.Linear(channels, channels, 1) - - def transpose_for_scores(self, projection: torch.Tensor) -> torch.Tensor: - new_projection_shape = projection.size()[:-1] + (self.num_heads, -1) - # move heads to 2nd position (B, T, H * D) -> (B, T, H, D) -> (B, H, T, D) - new_projection = projection.view(new_projection_shape).permute(0, 2, 1, 3) - return new_projection - - def forward(self, hidden_states): - residual = hidden_states - batch, channel, height, width = hidden_states.shape - - # norm - hidden_states = self.group_norm(hidden_states) - - hidden_states = hidden_states.view(batch, channel, height * width).transpose(1, 2) - - # proj to q, k, v - query_proj = self.query(hidden_states) - key_proj = self.key(hidden_states) - value_proj = self.value(hidden_states) - - # transpose - query_states = self.transpose_for_scores(query_proj) - key_states = self.transpose_for_scores(key_proj) - value_states = self.transpose_for_scores(value_proj) - - # get scores - scale = 1 / math.sqrt(math.sqrt(self.channels / self.num_heads)) - - attention_scores = torch.matmul(query_states * scale, key_states.transpose(-1, -2) * scale) - attention_probs = torch.softmax(attention_scores.float(), dim=-1).type(attention_scores.dtype) - - # compute attention output - hidden_states = torch.matmul(attention_probs, value_states) - - hidden_states = hidden_states.permute(0, 2, 1, 3).contiguous() - new_hidden_states_shape = hidden_states.size()[:-2] + (self.channels,) - hidden_states = hidden_states.view(new_hidden_states_shape) - - # compute next hidden_states - hidden_states = self.proj_attn(hidden_states) - hidden_states = hidden_states.transpose(-1, -2).reshape(batch, channel, height, width) - - # res connect and rescale - hidden_states = (hidden_states + residual) / self.rescale_output_factor - return hidden_states - - -class SpatialTransformer(nn.Module): - """ - Transformer block for image-like data. First, project the input (aka embedding) and reshape to b, t, d. Then apply - standard transformer action. Finally, reshape to image. - - Parameters: - in_channels (:obj:`int`): The number of channels in the input and output. - n_heads (:obj:`int`): The number of heads to use for multi-head attention. - d_head (:obj:`int`): The number of channels in each head. - depth (:obj:`int`, *optional*, defaults to 1): The number of layers of Transformer blocks to use. - dropout (:obj:`float`, *optional*, defaults to 0.1): The dropout probability to use. - context_dim (:obj:`int`, *optional*): The number of context dimensions to use. 
- """ - - def __init__( - self, - in_channels: int, - n_heads: int, - d_head: int, - depth: int = 1, - dropout: float = 0.0, - context_dim: Optional[int] = None, - ): - super().__init__() - self.n_heads = n_heads - self.d_head = d_head - self.in_channels = in_channels - inner_dim = n_heads * d_head - self.norm = torch.nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True) - - self.proj_in = nn.Conv2d(in_channels, inner_dim, kernel_size=1, stride=1, padding=0) - - self.transformer_blocks = nn.ModuleList( - [ - BasicTransformerBlock(inner_dim, n_heads, d_head, dropout=dropout, context_dim=context_dim) - for d in range(depth) - ] - ) - - self.proj_out = nn.Conv2d(inner_dim, in_channels, kernel_size=1, stride=1, padding=0) - - def _set_attention_slice(self, slice_size): - for block in self.transformer_blocks: - block._set_attention_slice(slice_size) - - def forward(self, x, context=None): - # note: if no context is given, cross-attention defaults to self-attention - b, c, h, w = x.shape - x_in = x - x = self.norm(x) - x = self.proj_in(x) - x = x.permute(0, 2, 3, 1).reshape(b, h * w, c) - for block in self.transformer_blocks: - x = block(x, context=context) - x = x.reshape(b, h, w, c).permute(0, 3, 1, 2) - x = self.proj_out(x) - return x + x_in - - -class BasicTransformerBlock(nn.Module): - r""" - A basic Transformer block. - - Parameters: - dim (:obj:`int`): The number of channels in the input and output. - n_heads (:obj:`int`): The number of heads to use for multi-head attention. - d_head (:obj:`int`): The number of channels in each head. - dropout (:obj:`float`, *optional*, defaults to 0.0): The dropout probability to use. - context_dim (:obj:`int`, *optional*): The size of the context vector for cross attention. - gated_ff (:obj:`bool`, *optional*, defaults to :obj:`False`): Whether to use a gated feed-forward network. - checkpoint (:obj:`bool`, *optional*, defaults to :obj:`False`): Whether to use checkpointing. - """ - - def __init__( - self, - dim: int, - n_heads: int, - d_head: int, - dropout=0.0, - context_dim: Optional[int] = None, - gated_ff: bool = True, - checkpoint: bool = True, - ): - super().__init__() - self.attn1 = CrossAttention( - query_dim=dim, heads=n_heads, dim_head=d_head, dropout=dropout - ) # is a self-attention - self.ff = FeedForward(dim, dropout=dropout, glu=gated_ff) - self.attn2 = CrossAttention( - query_dim=dim, context_dim=context_dim, heads=n_heads, dim_head=d_head, dropout=dropout - ) # is self-attn if context is none - self.norm1 = nn.LayerNorm(dim) - self.norm2 = nn.LayerNorm(dim) - self.norm3 = nn.LayerNorm(dim) - self.checkpoint = checkpoint - - def _set_attention_slice(self, slice_size): - self.attn1._slice_size = slice_size - self.attn2._slice_size = slice_size - - def forward(self, x, context=None): - x = x.contiguous() if x.device.type == "mps" else x - x = self.attn1(self.norm1(x)) + x - x = self.attn2(self.norm2(x), context=context) + x - x = self.ff(self.norm3(x)) + x - return x - - -class CrossAttention(nn.Module): - r""" - A cross attention layer. - - Parameters: - query_dim (:obj:`int`): The number of channels in the query. - context_dim (:obj:`int`, *optional*): - The number of channels in the context. If not given, defaults to `query_dim`. - heads (:obj:`int`, *optional*, defaults to 8): The number of heads to use for multi-head attention. - dim_head (:obj:`int`, *optional*, defaults to 64): The number of channels in each head. - dropout (:obj:`float`, *optional*, defaults to 0.0): The dropout probability to use. 
- """ - - def __init__( - self, query_dim: int, context_dim: Optional[int] = None, heads: int = 8, dim_head: int = 64, dropout: int = 0.0 - ): - super().__init__() - inner_dim = dim_head * heads - context_dim = context_dim if context_dim is not None else query_dim - - self.scale = dim_head**-0.5 - self.heads = heads - # for slice_size > 0 the attention score computation - # is split across the batch axis to save memory - # You can set slice_size with `set_attention_slice` - self._slice_size = None - - self.to_q = nn.Linear(query_dim, inner_dim, bias=False) - self.to_k = nn.Linear(context_dim, inner_dim, bias=False) - self.to_v = nn.Linear(context_dim, inner_dim, bias=False) - - self.to_out = nn.Sequential(nn.Linear(inner_dim, query_dim), nn.Dropout(dropout)) - - def reshape_heads_to_batch_dim(self, tensor): - batch_size, seq_len, dim = tensor.shape - head_size = self.heads - tensor = tensor.reshape(batch_size, seq_len, head_size, dim // head_size) - tensor = tensor.permute(0, 2, 1, 3).reshape(batch_size * head_size, seq_len, dim // head_size) - return tensor - - def reshape_batch_dim_to_heads(self, tensor): - batch_size, seq_len, dim = tensor.shape - head_size = self.heads - tensor = tensor.reshape(batch_size // head_size, head_size, seq_len, dim) - tensor = tensor.permute(0, 2, 1, 3).reshape(batch_size // head_size, seq_len, dim * head_size) - return tensor - - def forward(self, x, context=None, mask=None): - batch_size, sequence_length, dim = x.shape - - q = self.to_q(x) - context = context if context is not None else x - k = self.to_k(context) - v = self.to_v(context) - - q = self.reshape_heads_to_batch_dim(q) - k = self.reshape_heads_to_batch_dim(k) - v = self.reshape_heads_to_batch_dim(v) - - # TODO(PVP) - mask is currently never used. Remember to re-implement when used - - # attention, what we cannot get enough of - hidden_states = self._attention(q, k, v, sequence_length, dim) - - return self.to_out(hidden_states) - - def _attention(self, query, key, value, sequence_length, dim): - batch_size_attention = query.shape[0] - hidden_states = torch.zeros( - (batch_size_attention, sequence_length, dim // self.heads), device=query.device, dtype=query.dtype - ) - slice_size = self._slice_size if self._slice_size is not None else hidden_states.shape[0] - for i in range(hidden_states.shape[0] // slice_size): - start_idx = i * slice_size - end_idx = (i + 1) * slice_size - attn_slice = ( - torch.einsum("b i d, b j d -> b i j", query[start_idx:end_idx], key[start_idx:end_idx]) * self.scale - ) - attn_slice = attn_slice.softmax(dim=-1) - attn_slice = torch.einsum("b i j, b j d -> b i d", attn_slice, value[start_idx:end_idx]) - - hidden_states[start_idx:end_idx] = attn_slice - - # reshape hidden_states - hidden_states = self.reshape_batch_dim_to_heads(hidden_states) - return hidden_states - - -class FeedForward(nn.Module): - r""" - A feed-forward layer. - - Parameters: - dim (:obj:`int`): The number of channels in the input. - dim_out (:obj:`int`, *optional*): The number of channels in the output. If not given, defaults to `dim`. - mult (:obj:`int`, *optional*, defaults to 4): The multiplier to use for the hidden dimension. - glu (:obj:`bool`, *optional*, defaults to :obj:`False`): Whether to use GLU activation. - dropout (:obj:`float`, *optional*, defaults to 0.0): The dropout probability to use. 
- """ - - def __init__( - self, dim: int, dim_out: Optional[int] = None, mult: int = 4, glu: bool = False, dropout: float = 0.0 - ): - super().__init__() - inner_dim = int(dim * mult) - dim_out = dim_out if dim_out is not None else dim - project_in = GEGLU(dim, inner_dim) - - self.net = nn.Sequential(project_in, nn.Dropout(dropout), nn.Linear(inner_dim, dim_out)) - - def forward(self, x): - return self.net(x) - - -# feedforward -class GEGLU(nn.Module): - r""" - A variant of the gated linear unit activation function from https://arxiv.org/abs/2002.05202. - - Parameters: - dim_in (:obj:`int`): The number of channels in the input. - dim_out (:obj:`int`): The number of channels in the output. - """ - - def __init__(self, dim_in: int, dim_out: int): - super().__init__() - self.proj = nn.Linear(dim_in, dim_out * 2) - - def forward(self, x): - x, gate = self.proj(x).chunk(2, dim=-1) - return x * F.gelu(gate) diff --git a/local_diffusers/models/embeddings.py b/local_diffusers/models/embeddings.py deleted file mode 100644 index 86ac074c1..000000000 --- a/local_diffusers/models/embeddings.py +++ /dev/null @@ -1,115 +0,0 @@ -# Copyright 2022 The HuggingFace Team. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -import math - -import numpy as np -import torch -from torch import nn - - -def get_timestep_embedding( - timesteps: torch.Tensor, - embedding_dim: int, - flip_sin_to_cos: bool = False, - downscale_freq_shift: float = 1, - scale: float = 1, - max_period: int = 10000, -): - """ - This matches the implementation in Denoising Diffusion Probabilistic Models: Create sinusoidal timestep embeddings. - - :param timesteps: a 1-D Tensor of N indices, one per batch element. - These may be fractional. - :param embedding_dim: the dimension of the output. :param max_period: controls the minimum frequency of the - embeddings. :return: an [N x dim] Tensor of positional embeddings. 
- """ - assert len(timesteps.shape) == 1, "Timesteps should be a 1d-array" - - half_dim = embedding_dim // 2 - exponent = -math.log(max_period) * torch.arange(start=0, end=half_dim, dtype=torch.float32) - exponent = exponent / (half_dim - downscale_freq_shift) - - emb = torch.exp(exponent).to(device=timesteps.device) - emb = timesteps[:, None].float() * emb[None, :] - - # scale embeddings - emb = scale * emb - - # concat sine and cosine embeddings - emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=-1) - - # flip sine and cosine embeddings - if flip_sin_to_cos: - emb = torch.cat([emb[:, half_dim:], emb[:, :half_dim]], dim=-1) - - # zero pad - if embedding_dim % 2 == 1: - emb = torch.nn.functional.pad(emb, (0, 1, 0, 0)) - return emb - - -class TimestepEmbedding(nn.Module): - def __init__(self, channel: int, time_embed_dim: int, act_fn: str = "silu"): - super().__init__() - - self.linear_1 = nn.Linear(channel, time_embed_dim) - self.act = None - if act_fn == "silu": - self.act = nn.SiLU() - self.linear_2 = nn.Linear(time_embed_dim, time_embed_dim) - - def forward(self, sample): - sample = self.linear_1(sample) - - if self.act is not None: - sample = self.act(sample) - - sample = self.linear_2(sample) - return sample - - -class Timesteps(nn.Module): - def __init__(self, num_channels: int, flip_sin_to_cos: bool, downscale_freq_shift: float): - super().__init__() - self.num_channels = num_channels - self.flip_sin_to_cos = flip_sin_to_cos - self.downscale_freq_shift = downscale_freq_shift - - def forward(self, timesteps): - t_emb = get_timestep_embedding( - timesteps, - self.num_channels, - flip_sin_to_cos=self.flip_sin_to_cos, - downscale_freq_shift=self.downscale_freq_shift, - ) - return t_emb - - -class GaussianFourierProjection(nn.Module): - """Gaussian Fourier embeddings for noise levels.""" - - def __init__(self, embedding_size: int = 256, scale: float = 1.0): - super().__init__() - self.weight = nn.Parameter(torch.randn(embedding_size) * scale, requires_grad=False) - - # to delete later - self.W = nn.Parameter(torch.randn(embedding_size) * scale, requires_grad=False) - - self.weight = self.W - - def forward(self, x): - x = torch.log(x) - x_proj = x[:, None] * self.weight[None, :] * 2 * np.pi - out = torch.cat([torch.sin(x_proj), torch.cos(x_proj)], dim=-1) - return out diff --git a/local_diffusers/models/resnet.py b/local_diffusers/models/resnet.py deleted file mode 100644 index 27fae24f7..000000000 --- a/local_diffusers/models/resnet.py +++ /dev/null @@ -1,483 +0,0 @@ -from functools import partial - -import numpy as np -import torch -import torch.nn as nn -import torch.nn.functional as F - - -class Upsample2D(nn.Module): - """ - An upsampling layer with an optional convolution. - - :param channels: channels in the inputs and outputs. :param use_conv: a bool determining if a convolution is - applied. :param dims: determines if the signal is 1D, 2D, or 3D. If 3D, then - upsampling occurs in the inner-two dimensions. 
- """ - - def __init__(self, channels, use_conv=False, use_conv_transpose=False, out_channels=None, name="conv"): - super().__init__() - self.channels = channels - self.out_channels = out_channels or channels - self.use_conv = use_conv - self.use_conv_transpose = use_conv_transpose - self.name = name - - conv = None - if use_conv_transpose: - conv = nn.ConvTranspose2d(channels, self.out_channels, 4, 2, 1) - elif use_conv: - conv = nn.Conv2d(self.channels, self.out_channels, 3, padding=1) - - # TODO(Suraj, Patrick) - clean up after weight dicts are correctly renamed - if name == "conv": - self.conv = conv - else: - self.Conv2d_0 = conv - - def forward(self, x): - assert x.shape[1] == self.channels - if self.use_conv_transpose: - return self.conv(x) - - x = F.interpolate(x, scale_factor=2.0, mode="nearest") - - # TODO(Suraj, Patrick) - clean up after weight dicts are correctly renamed - if self.use_conv: - if self.name == "conv": - x = self.conv(x) - else: - x = self.Conv2d_0(x) - - return x - - -class Downsample2D(nn.Module): - """ - A downsampling layer with an optional convolution. - - :param channels: channels in the inputs and outputs. :param use_conv: a bool determining if a convolution is - applied. :param dims: determines if the signal is 1D, 2D, or 3D. If 3D, then - downsampling occurs in the inner-two dimensions. - """ - - def __init__(self, channels, use_conv=False, out_channels=None, padding=1, name="conv"): - super().__init__() - self.channels = channels - self.out_channels = out_channels or channels - self.use_conv = use_conv - self.padding = padding - stride = 2 - self.name = name - - if use_conv: - conv = nn.Conv2d(self.channels, self.out_channels, 3, stride=stride, padding=padding) - else: - assert self.channels == self.out_channels - conv = nn.AvgPool2d(kernel_size=stride, stride=stride) - - # TODO(Suraj, Patrick) - clean up after weight dicts are correctly renamed - if name == "conv": - self.Conv2d_0 = conv - self.conv = conv - elif name == "Conv2d_0": - self.conv = conv - else: - self.conv = conv - - def forward(self, x): - assert x.shape[1] == self.channels - if self.use_conv and self.padding == 0: - pad = (0, 1, 0, 1) - x = F.pad(x, pad, mode="constant", value=0) - - assert x.shape[1] == self.channels - x = self.conv(x) - - return x - - -class FirUpsample2D(nn.Module): - def __init__(self, channels=None, out_channels=None, use_conv=False, fir_kernel=(1, 3, 3, 1)): - super().__init__() - out_channels = out_channels if out_channels else channels - if use_conv: - self.Conv2d_0 = nn.Conv2d(channels, out_channels, kernel_size=3, stride=1, padding=1) - self.use_conv = use_conv - self.fir_kernel = fir_kernel - self.out_channels = out_channels - - def _upsample_2d(self, x, weight=None, kernel=None, factor=2, gain=1): - """Fused `upsample_2d()` followed by `Conv2d()`. - - Args: - Padding is performed only once at the beginning, not between the operations. The fused op is considerably more - efficient than performing the same calculation using standard TensorFlow ops. It supports gradients of arbitrary: - order. - x: Input tensor of the shape `[N, C, H, W]` or `[N, H, W, - C]`. - weight: Weight tensor of the shape `[filterH, filterW, inChannels, - outChannels]`. Grouped convolution can be performed by `inChannels = x.shape[0] // numGroups`. - kernel: FIR filter of the shape `[firH, firW]` or `[firN]` - (separable). The default is `[1] * factor`, which corresponds to nearest-neighbor upsampling. - factor: Integer upsampling factor (default: 2). 
gain: Scaling factor for signal magnitude (default: 1.0). - - Returns: - Tensor of the shape `[N, C, H * factor, W * factor]` or `[N, H * factor, W * factor, C]`, and same datatype as - `x`. - """ - - assert isinstance(factor, int) and factor >= 1 - - # Setup filter kernel. - if kernel is None: - kernel = [1] * factor - - # setup kernel - kernel = np.asarray(kernel, dtype=np.float32) - if kernel.ndim == 1: - kernel = np.outer(kernel, kernel) - kernel /= np.sum(kernel) - - kernel = kernel * (gain * (factor**2)) - - if self.use_conv: - convH = weight.shape[2] - convW = weight.shape[3] - inC = weight.shape[1] - - p = (kernel.shape[0] - factor) - (convW - 1) - - stride = (factor, factor) - # Determine data dimensions. - stride = [1, 1, factor, factor] - output_shape = ((x.shape[2] - 1) * factor + convH, (x.shape[3] - 1) * factor + convW) - output_padding = ( - output_shape[0] - (x.shape[2] - 1) * stride[0] - convH, - output_shape[1] - (x.shape[3] - 1) * stride[1] - convW, - ) - assert output_padding[0] >= 0 and output_padding[1] >= 0 - inC = weight.shape[1] - num_groups = x.shape[1] // inC - - # Transpose weights. - weight = torch.reshape(weight, (num_groups, -1, inC, convH, convW)) - weight = weight[..., ::-1, ::-1].permute(0, 2, 1, 3, 4) - weight = torch.reshape(weight, (num_groups * inC, -1, convH, convW)) - - x = F.conv_transpose2d(x, weight, stride=stride, output_padding=output_padding, padding=0) - - x = upfirdn2d_native(x, torch.tensor(kernel, device=x.device), pad=((p + 1) // 2 + factor - 1, p // 2 + 1)) - else: - p = kernel.shape[0] - factor - x = upfirdn2d_native( - x, torch.tensor(kernel, device=x.device), up=factor, pad=((p + 1) // 2 + factor - 1, p // 2) - ) - - return x - - def forward(self, x): - if self.use_conv: - height = self._upsample_2d(x, self.Conv2d_0.weight, kernel=self.fir_kernel) - height = height + self.Conv2d_0.bias.reshape(1, -1, 1, 1) - else: - height = self._upsample_2d(x, kernel=self.fir_kernel, factor=2) - - return height - - -class FirDownsample2D(nn.Module): - def __init__(self, channels=None, out_channels=None, use_conv=False, fir_kernel=(1, 3, 3, 1)): - super().__init__() - out_channels = out_channels if out_channels else channels - if use_conv: - self.Conv2d_0 = nn.Conv2d(channels, out_channels, kernel_size=3, stride=1, padding=1) - self.fir_kernel = fir_kernel - self.use_conv = use_conv - self.out_channels = out_channels - - def _downsample_2d(self, x, weight=None, kernel=None, factor=2, gain=1): - """Fused `Conv2d()` followed by `downsample_2d()`. - - Args: - Padding is performed only once at the beginning, not between the operations. The fused op is considerably more - efficient than performing the same calculation using standard TensorFlow ops. It supports gradients of arbitrary: - order. - x: Input tensor of the shape `[N, C, H, W]` or `[N, H, W, C]`. w: Weight tensor of the shape `[filterH, - filterW, inChannels, outChannels]`. Grouped convolution can be performed by `inChannels = x.shape[0] // - numGroups`. k: FIR filter of the shape `[firH, firW]` or `[firN]` (separable). The default is `[1] * - factor`, which corresponds to average pooling. factor: Integer downsampling factor (default: 2). gain: - Scaling factor for signal magnitude (default: 1.0). - - Returns: - Tensor of the shape `[N, C, H // factor, W // factor]` or `[N, H // factor, W // factor, C]`, and same - datatype as `x`. 
- """ - - assert isinstance(factor, int) and factor >= 1 - if kernel is None: - kernel = [1] * factor - - # setup kernel - kernel = np.asarray(kernel, dtype=np.float32) - if kernel.ndim == 1: - kernel = np.outer(kernel, kernel) - kernel /= np.sum(kernel) - - kernel = kernel * gain - - if self.use_conv: - _, _, convH, convW = weight.shape - p = (kernel.shape[0] - factor) + (convW - 1) - s = [factor, factor] - x = upfirdn2d_native(x, torch.tensor(kernel, device=x.device), pad=((p + 1) // 2, p // 2)) - x = F.conv2d(x, weight, stride=s, padding=0) - else: - p = kernel.shape[0] - factor - x = upfirdn2d_native(x, torch.tensor(kernel, device=x.device), down=factor, pad=((p + 1) // 2, p // 2)) - - return x - - def forward(self, x): - if self.use_conv: - x = self._downsample_2d(x, weight=self.Conv2d_0.weight, kernel=self.fir_kernel) - x = x + self.Conv2d_0.bias.reshape(1, -1, 1, 1) - else: - x = self._downsample_2d(x, kernel=self.fir_kernel, factor=2) - - return x - - -class ResnetBlock2D(nn.Module): - def __init__( - self, - *, - in_channels, - out_channels=None, - conv_shortcut=False, - dropout=0.0, - temb_channels=512, - groups=32, - groups_out=None, - pre_norm=True, - eps=1e-6, - non_linearity="swish", - time_embedding_norm="default", - kernel=None, - output_scale_factor=1.0, - use_nin_shortcut=None, - up=False, - down=False, - ): - super().__init__() - self.pre_norm = pre_norm - self.pre_norm = True - self.in_channels = in_channels - out_channels = in_channels if out_channels is None else out_channels - self.out_channels = out_channels - self.use_conv_shortcut = conv_shortcut - self.time_embedding_norm = time_embedding_norm - self.up = up - self.down = down - self.output_scale_factor = output_scale_factor - - if groups_out is None: - groups_out = groups - - self.norm1 = torch.nn.GroupNorm(num_groups=groups, num_channels=in_channels, eps=eps, affine=True) - - self.conv1 = torch.nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1) - - if temb_channels is not None: - self.time_emb_proj = torch.nn.Linear(temb_channels, out_channels) - else: - self.time_emb_proj = None - - self.norm2 = torch.nn.GroupNorm(num_groups=groups_out, num_channels=out_channels, eps=eps, affine=True) - self.dropout = torch.nn.Dropout(dropout) - self.conv2 = torch.nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1) - - if non_linearity == "swish": - self.nonlinearity = lambda x: F.silu(x) - elif non_linearity == "mish": - self.nonlinearity = Mish() - elif non_linearity == "silu": - self.nonlinearity = nn.SiLU() - - self.upsample = self.downsample = None - if self.up: - if kernel == "fir": - fir_kernel = (1, 3, 3, 1) - self.upsample = lambda x: upsample_2d(x, kernel=fir_kernel) - elif kernel == "sde_vp": - self.upsample = partial(F.interpolate, scale_factor=2.0, mode="nearest") - else: - self.upsample = Upsample2D(in_channels, use_conv=False) - elif self.down: - if kernel == "fir": - fir_kernel = (1, 3, 3, 1) - self.downsample = lambda x: downsample_2d(x, kernel=fir_kernel) - elif kernel == "sde_vp": - self.downsample = partial(F.avg_pool2d, kernel_size=2, stride=2) - else: - self.downsample = Downsample2D(in_channels, use_conv=False, padding=1, name="op") - - self.use_nin_shortcut = self.in_channels != self.out_channels if use_nin_shortcut is None else use_nin_shortcut - - self.conv_shortcut = None - if self.use_nin_shortcut: - self.conv_shortcut = torch.nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, padding=0) - - def forward(self, x, temb): - hidden_states = x - - # 
make sure hidden states is in float32 - # when running in half-precision - hidden_states = self.norm1(hidden_states.float()).type(hidden_states.dtype) - hidden_states = self.nonlinearity(hidden_states) - - if self.upsample is not None: - x = self.upsample(x) - hidden_states = self.upsample(hidden_states) - elif self.downsample is not None: - x = self.downsample(x) - hidden_states = self.downsample(hidden_states) - - hidden_states = self.conv1(hidden_states) - - if temb is not None: - temb = self.time_emb_proj(self.nonlinearity(temb))[:, :, None, None] - hidden_states = hidden_states + temb - - # make sure hidden states is in float32 - # when running in half-precision - hidden_states = self.norm2(hidden_states.float()).type(hidden_states.dtype) - hidden_states = self.nonlinearity(hidden_states) - - hidden_states = self.dropout(hidden_states) - hidden_states = self.conv2(hidden_states) - - if self.conv_shortcut is not None: - x = self.conv_shortcut(x) - - out = (x + hidden_states) / self.output_scale_factor - - return out - - -class Mish(torch.nn.Module): - def forward(self, x): - return x * torch.tanh(torch.nn.functional.softplus(x)) - - -def upsample_2d(x, kernel=None, factor=2, gain=1): - r"""Upsample2D a batch of 2D images with the given filter. - - Args: - Accepts a batch of 2D images of the shape `[N, C, H, W]` or `[N, H, W, C]` and upsamples each image with the given - filter. The filter is normalized so that if the input pixels are constant, they will be scaled by the specified - `gain`. Pixels outside the image are assumed to be zero, and the filter is padded with zeros so that its shape is a: - multiple of the upsampling factor. - x: Input tensor of the shape `[N, C, H, W]` or `[N, H, W, - C]`. - k: FIR filter of the shape `[firH, firW]` or `[firN]` - (separable). The default is `[1] * factor`, which corresponds to nearest-neighbor upsampling. - factor: Integer upsampling factor (default: 2). gain: Scaling factor for signal magnitude (default: 1.0). - - Returns: - Tensor of the shape `[N, C, H * factor, W * factor]` - """ - assert isinstance(factor, int) and factor >= 1 - if kernel is None: - kernel = [1] * factor - - kernel = np.asarray(kernel, dtype=np.float32) - if kernel.ndim == 1: - kernel = np.outer(kernel, kernel) - kernel /= np.sum(kernel) - - kernel = kernel * (gain * (factor**2)) - p = kernel.shape[0] - factor - return upfirdn2d_native( - x, torch.tensor(kernel, device=x.device), up=factor, pad=((p + 1) // 2 + factor - 1, p // 2) - ) - - -def downsample_2d(x, kernel=None, factor=2, gain=1): - r"""Downsample2D a batch of 2D images with the given filter. - - Args: - Accepts a batch of 2D images of the shape `[N, C, H, W]` or `[N, H, W, C]` and downsamples each image with the - given filter. The filter is normalized so that if the input pixels are constant, they will be scaled by the - specified `gain`. Pixels outside the image are assumed to be zero, and the filter is padded with zeros so that its - shape is a multiple of the downsampling factor. - x: Input tensor of the shape `[N, C, H, W]` or `[N, H, W, - C]`. - kernel: FIR filter of the shape `[firH, firW]` or `[firN]` - (separable). The default is `[1] * factor`, which corresponds to average pooling. - factor: Integer downsampling factor (default: 2). gain: Scaling factor for signal magnitude (default: 1.0). 
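As noted above, the default `[1] * factor` kernel makes this downsampling equivalent to average pooling; a quick check of that equivalence with a normalized 2x2 box filter:

import torch
import torch.nn.functional as F

x = torch.randn(1, 3, 8, 8)
box = torch.full((3, 1, 2, 2), 0.25)               # per-channel normalized box filter
down = F.conv2d(x, box, stride=2, groups=3)
print(torch.allclose(down, F.avg_pool2d(x, 2)))    # True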
- - Returns: - Tensor of the shape `[N, C, H // factor, W // factor]` - """ - - assert isinstance(factor, int) and factor >= 1 - if kernel is None: - kernel = [1] * factor - - kernel = np.asarray(kernel, dtype=np.float32) - if kernel.ndim == 1: - kernel = np.outer(kernel, kernel) - kernel /= np.sum(kernel) - - kernel = kernel * gain - p = kernel.shape[0] - factor - return upfirdn2d_native(x, torch.tensor(kernel, device=x.device), down=factor, pad=((p + 1) // 2, p // 2)) - - -def upfirdn2d_native(input, kernel, up=1, down=1, pad=(0, 0)): - up_x = up_y = up - down_x = down_y = down - pad_x0 = pad_y0 = pad[0] - pad_x1 = pad_y1 = pad[1] - - _, channel, in_h, in_w = input.shape - input = input.reshape(-1, in_h, in_w, 1) - - _, in_h, in_w, minor = input.shape - kernel_h, kernel_w = kernel.shape - - out = input.view(-1, in_h, 1, in_w, 1, minor) - - # Temporary workaround for mps specific issue: https://github.com/pytorch/pytorch/issues/84535 - if input.device.type == "mps": - out = out.to("cpu") - out = F.pad(out, [0, 0, 0, up_x - 1, 0, 0, 0, up_y - 1]) - out = out.view(-1, in_h * up_y, in_w * up_x, minor) - - out = F.pad(out, [0, 0, max(pad_x0, 0), max(pad_x1, 0), max(pad_y0, 0), max(pad_y1, 0)]) - out = out.to(input.device) # Move back to mps if necessary - out = out[ - :, - max(-pad_y0, 0) : out.shape[1] - max(-pad_y1, 0), - max(-pad_x0, 0) : out.shape[2] - max(-pad_x1, 0), - :, - ] - - out = out.permute(0, 3, 1, 2) - out = out.reshape([-1, 1, in_h * up_y + pad_y0 + pad_y1, in_w * up_x + pad_x0 + pad_x1]) - w = torch.flip(kernel, [0, 1]).view(1, 1, kernel_h, kernel_w) - out = F.conv2d(out, w) - out = out.reshape( - -1, - minor, - in_h * up_y + pad_y0 + pad_y1 - kernel_h + 1, - in_w * up_x + pad_x0 + pad_x1 - kernel_w + 1, - ) - out = out.permute(0, 2, 3, 1) - out = out[:, ::down_y, ::down_x, :] - - out_h = (in_h * up_y + pad_y0 + pad_y1 - kernel_h) // down_y + 1 - out_w = (in_w * up_x + pad_x0 + pad_x1 - kernel_w) // down_x + 1 - - return out.view(-1, channel, out_h, out_w) diff --git a/local_diffusers/models/unet_2d.py b/local_diffusers/models/unet_2d.py deleted file mode 100644 index c3ab621a2..000000000 --- a/local_diffusers/models/unet_2d.py +++ /dev/null @@ -1,246 +0,0 @@ -from dataclasses import dataclass -from typing import Optional, Tuple, Union - -import torch -import torch.nn as nn - -from ..configuration_utils import ConfigMixin, register_to_config -from ..modeling_utils import ModelMixin -from ..utils import BaseOutput -from .embeddings import GaussianFourierProjection, TimestepEmbedding, Timesteps -from .unet_blocks import UNetMidBlock2D, get_down_block, get_up_block - - -@dataclass -class UNet2DOutput(BaseOutput): - """ - Args: - sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`): - Hidden states output. Output of last layer of model. - """ - - sample: torch.FloatTensor - - -class UNet2DModel(ModelMixin, ConfigMixin): - r""" - UNet2DModel is a 2D UNet model that takes in a noisy sample and a timestep and returns sample shaped output. - - This model inherits from [`ModelMixin`]. Check the superclass documentation for the generic methods the library - implements for all the model (such as downloading or saving, etc.) - - Parameters: - sample_size (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`, *optional*): - Input sample size. - in_channels (`int`, *optional*, defaults to 3): Number of channels in the input image. - out_channels (`int`, *optional*, defaults to 3): Number of channels in the output. 
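A reduced-size smoke test of the unconditional UNet documented below; this assumes the equivalent upstream `diffusers.UNet2DModel` (which this local copy mirrors) and shrinks the constructor arguments so it runs quickly on CPU:

import torch
from diffusers import UNet2DModel

model = UNet2DModel(
    sample_size=32,
    in_channels=3,
    out_channels=3,
    block_out_channels=(32, 64),
    down_block_types=("DownBlock2D", "AttnDownBlock2D"),
    up_block_types=("AttnUpBlock2D", "UpBlock2D"),
    layers_per_block=1,
)
noisy = torch.randn(1, 3, 32, 32)
out = model(noisy, timestep=10).sample     # UNet2DOutput.sample
print(out.shape)                           # torch.Size([1, 3, 32, 32])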
- center_input_sample (`bool`, *optional*, defaults to `False`): Whether to center the input sample. - time_embedding_type (`str`, *optional*, defaults to `"positional"`): Type of time embedding to use. - freq_shift (`int`, *optional*, defaults to 0): Frequency shift for fourier time embedding. - flip_sin_to_cos (`bool`, *optional*, defaults to : - obj:`False`): Whether to flip sin to cos for fourier time embedding. - down_block_types (`Tuple[str]`, *optional*, defaults to : - obj:`("DownBlock2D", "AttnDownBlock2D", "AttnDownBlock2D", "AttnDownBlock2D")`): Tuple of downsample block - types. - up_block_types (`Tuple[str]`, *optional*, defaults to : - obj:`("AttnUpBlock2D", "AttnUpBlock2D", "AttnUpBlock2D", "UpBlock2D")`): Tuple of upsample block types. - block_out_channels (`Tuple[int]`, *optional*, defaults to : - obj:`(224, 448, 672, 896)`): Tuple of block output channels. - layers_per_block (`int`, *optional*, defaults to `2`): The number of layers per block. - mid_block_scale_factor (`float`, *optional*, defaults to `1`): The scale factor for the mid block. - downsample_padding (`int`, *optional*, defaults to `1`): The padding for the downsample convolution. - act_fn (`str`, *optional*, defaults to `"silu"`): The activation function to use. - attention_head_dim (`int`, *optional*, defaults to `8`): The attention head dimension. - norm_num_groups (`int`, *optional*, defaults to `32`): The number of groups for the normalization. - norm_eps (`float`, *optional*, defaults to `1e-5`): The epsilon for the normalization. - """ - - @register_to_config - def __init__( - self, - sample_size: Optional[int] = None, - in_channels: int = 3, - out_channels: int = 3, - center_input_sample: bool = False, - time_embedding_type: str = "positional", - freq_shift: int = 0, - flip_sin_to_cos: bool = True, - down_block_types: Tuple[str] = ("DownBlock2D", "AttnDownBlock2D", "AttnDownBlock2D", "AttnDownBlock2D"), - up_block_types: Tuple[str] = ("AttnUpBlock2D", "AttnUpBlock2D", "AttnUpBlock2D", "UpBlock2D"), - block_out_channels: Tuple[int] = (224, 448, 672, 896), - layers_per_block: int = 2, - mid_block_scale_factor: float = 1, - downsample_padding: int = 1, - act_fn: str = "silu", - attention_head_dim: int = 8, - norm_num_groups: int = 32, - norm_eps: float = 1e-5, - ): - super().__init__() - - self.sample_size = sample_size - time_embed_dim = block_out_channels[0] * 4 - - # input - self.conv_in = nn.Conv2d(in_channels, block_out_channels[0], kernel_size=3, padding=(1, 1)) - - # time - if time_embedding_type == "fourier": - self.time_proj = GaussianFourierProjection(embedding_size=block_out_channels[0], scale=16) - timestep_input_dim = 2 * block_out_channels[0] - elif time_embedding_type == "positional": - self.time_proj = Timesteps(block_out_channels[0], flip_sin_to_cos, freq_shift) - timestep_input_dim = block_out_channels[0] - - self.time_embedding = TimestepEmbedding(timestep_input_dim, time_embed_dim) - - self.down_blocks = nn.ModuleList([]) - self.mid_block = None - self.up_blocks = nn.ModuleList([]) - - # down - output_channel = block_out_channels[0] - for i, down_block_type in enumerate(down_block_types): - input_channel = output_channel - output_channel = block_out_channels[i] - is_final_block = i == len(block_out_channels) - 1 - - down_block = get_down_block( - down_block_type, - num_layers=layers_per_block, - in_channels=input_channel, - out_channels=output_channel, - temb_channels=time_embed_dim, - add_downsample=not is_final_block, - resnet_eps=norm_eps, - resnet_act_fn=act_fn, - 
attn_num_head_channels=attention_head_dim, - downsample_padding=downsample_padding, - ) - self.down_blocks.append(down_block) - - # mid - self.mid_block = UNetMidBlock2D( - in_channels=block_out_channels[-1], - temb_channels=time_embed_dim, - resnet_eps=norm_eps, - resnet_act_fn=act_fn, - output_scale_factor=mid_block_scale_factor, - resnet_time_scale_shift="default", - attn_num_head_channels=attention_head_dim, - resnet_groups=norm_num_groups, - ) - - # up - reversed_block_out_channels = list(reversed(block_out_channels)) - output_channel = reversed_block_out_channels[0] - for i, up_block_type in enumerate(up_block_types): - prev_output_channel = output_channel - output_channel = reversed_block_out_channels[i] - input_channel = reversed_block_out_channels[min(i + 1, len(block_out_channels) - 1)] - - is_final_block = i == len(block_out_channels) - 1 - - up_block = get_up_block( - up_block_type, - num_layers=layers_per_block + 1, - in_channels=input_channel, - out_channels=output_channel, - prev_output_channel=prev_output_channel, - temb_channels=time_embed_dim, - add_upsample=not is_final_block, - resnet_eps=norm_eps, - resnet_act_fn=act_fn, - attn_num_head_channels=attention_head_dim, - ) - self.up_blocks.append(up_block) - prev_output_channel = output_channel - - # out - num_groups_out = norm_num_groups if norm_num_groups is not None else min(block_out_channels[0] // 4, 32) - self.conv_norm_out = nn.GroupNorm(num_channels=block_out_channels[0], num_groups=num_groups_out, eps=norm_eps) - self.conv_act = nn.SiLU() - self.conv_out = nn.Conv2d(block_out_channels[0], out_channels, 3, padding=1) - - def forward( - self, - sample: torch.FloatTensor, - timestep: Union[torch.Tensor, float, int], - return_dict: bool = True, - ) -> Union[UNet2DOutput, Tuple]: - """r - Args: - sample (`torch.FloatTensor`): (batch, channel, height, width) noisy inputs tensor - timestep (`torch.FloatTensor` or `float` or `int): (batch) timesteps - return_dict (`bool`, *optional*, defaults to `True`): - Whether or not to return a [`~models.unet_2d.UNet2DOutput`] instead of a plain tuple. - - Returns: - [`~models.unet_2d.UNet2DOutput`] or `tuple`: [`~models.unet_2d.UNet2DOutput`] if `return_dict` is True, - otherwise a `tuple`. When returning a tuple, the first element is the sample tensor. - """ - # 0. center input if necessary - if self.config.center_input_sample: - sample = 2 * sample - 1.0 - - # 1. time - timesteps = timestep - if not torch.is_tensor(timesteps): - timesteps = torch.tensor([timesteps], dtype=torch.long, device=sample.device) - elif torch.is_tensor(timesteps) and len(timesteps.shape) == 0: - timesteps = timesteps[None].to(sample.device) - - # broadcast to batch dimension in a way that's compatible with ONNX/Core ML - timesteps = timesteps * torch.ones(sample.shape[0], dtype=timesteps.dtype, device=timesteps.device) - - t_emb = self.time_proj(timesteps) - emb = self.time_embedding(t_emb) - - # 2. pre-process - skip_sample = sample - sample = self.conv_in(sample) - - # 3. down - down_block_res_samples = (sample,) - for downsample_block in self.down_blocks: - if hasattr(downsample_block, "skip_conv"): - sample, res_samples, skip_sample = downsample_block( - hidden_states=sample, temb=emb, skip_sample=skip_sample - ) - else: - sample, res_samples = downsample_block(hidden_states=sample, temb=emb) - - down_block_res_samples += res_samples - - # 4. mid - sample = self.mid_block(sample, emb) - - # 5. 
up - skip_sample = None - for upsample_block in self.up_blocks: - res_samples = down_block_res_samples[-len(upsample_block.resnets) :] - down_block_res_samples = down_block_res_samples[: -len(upsample_block.resnets)] - - if hasattr(upsample_block, "skip_conv"): - sample, skip_sample = upsample_block(sample, res_samples, emb, skip_sample) - else: - sample = upsample_block(sample, res_samples, emb) - - # 6. post-process - # make sure hidden states is in float32 - # when running in half-precision - sample = self.conv_norm_out(sample.float()).type(sample.dtype) - sample = self.conv_act(sample) - sample = self.conv_out(sample) - - if skip_sample is not None: - sample += skip_sample - - if self.config.time_embedding_type == "fourier": - timesteps = timesteps.reshape((sample.shape[0], *([1] * len(sample.shape[1:])))) - sample = sample / timesteps - - if not return_dict: - return (sample,) - - return UNet2DOutput(sample=sample) diff --git a/local_diffusers/models/unet_2d_condition.py b/local_diffusers/models/unet_2d_condition.py deleted file mode 100644 index 92caaca92..000000000 --- a/local_diffusers/models/unet_2d_condition.py +++ /dev/null @@ -1,270 +0,0 @@ -from dataclasses import dataclass -from typing import Optional, Tuple, Union - -import torch -import torch.nn as nn - -from ..configuration_utils import ConfigMixin, register_to_config -from ..modeling_utils import ModelMixin -from ..utils import BaseOutput -from .embeddings import TimestepEmbedding, Timesteps -from .unet_blocks import UNetMidBlock2DCrossAttn, get_down_block, get_up_block - - -@dataclass -class UNet2DConditionOutput(BaseOutput): - """ - Args: - sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`): - Hidden states conditioned on `encoder_hidden_states` input. Output of last layer of model. - """ - - sample: torch.FloatTensor - - -class UNet2DConditionModel(ModelMixin, ConfigMixin): - r""" - UNet2DConditionModel is a conditional 2D UNet model that takes in a noisy sample, conditional state, and a timestep - and returns sample shaped output. - - This model inherits from [`ModelMixin`]. Check the superclass documentation for the generic methods the library - implements for all the model (such as downloading or saving, etc.) - - Parameters: - sample_size (`int`, *optional*): The size of the input sample. - in_channels (`int`, *optional*, defaults to 4): The number of channels in the input sample. - out_channels (`int`, *optional*, defaults to 4): The number of channels in the output. - center_input_sample (`bool`, *optional*, defaults to `False`): Whether to center the input sample. - flip_sin_to_cos (`bool`, *optional*, defaults to `False`): - Whether to flip the sin to cos in the time embedding. - freq_shift (`int`, *optional*, defaults to 0): The frequency shift to apply to the time embedding. - down_block_types (`Tuple[str]`, *optional*, defaults to `("CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "DownBlock2D")`): - The tuple of downsample blocks to use. - up_block_types (`Tuple[str]`, *optional*, defaults to `("UpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D",)`): - The tuple of upsample blocks to use. - block_out_channels (`Tuple[int]`, *optional*, defaults to `(320, 640, 1280, 1280)`): - The tuple of output channels for each block. - layers_per_block (`int`, *optional*, defaults to 2): The number of layers per block. - downsample_padding (`int`, *optional*, defaults to 1): The padding to use for the downsampling convolution. 
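The conditional variant documented here is the UNet that the Stable Diffusion pipeline drives at every denoising step; a reduced-size sketch of its forward signature (sample, timestep, encoder_hidden_states), again assuming the upstream `diffusers` class and illustrative sizes:

import torch
from diffusers import UNet2DConditionModel

model = UNet2DConditionModel(
    sample_size=16,
    in_channels=4,
    out_channels=4,
    block_out_channels=(32, 64),
    down_block_types=("CrossAttnDownBlock2D", "DownBlock2D"),
    up_block_types=("UpBlock2D", "CrossAttnUpBlock2D"),
    layers_per_block=1,
    cross_attention_dim=32,
    attention_head_dim=8,
)
latents = torch.randn(1, 4, 16, 16)     # noisy latent sample
text_emb = torch.randn(1, 77, 32)       # stand-in for CLIP text hidden states
out = model(latents, timestep=10, encoder_hidden_states=text_emb).sample
print(out.shape)                        # torch.Size([1, 4, 16, 16])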
- mid_block_scale_factor (`float`, *optional*, defaults to 1.0): The scale factor to use for the mid block. - act_fn (`str`, *optional*, defaults to `"silu"`): The activation function to use. - norm_num_groups (`int`, *optional*, defaults to 32): The number of groups to use for the normalization. - norm_eps (`float`, *optional*, defaults to 1e-5): The epsilon to use for the normalization. - cross_attention_dim (`int`, *optional*, defaults to 1280): The dimension of the cross attention features. - attention_head_dim (`int`, *optional*, defaults to 8): The dimension of the attention heads. - """ - - @register_to_config - def __init__( - self, - sample_size: Optional[int] = None, - in_channels: int = 4, - out_channels: int = 4, - center_input_sample: bool = False, - flip_sin_to_cos: bool = True, - freq_shift: int = 0, - down_block_types: Tuple[str] = ( - "CrossAttnDownBlock2D", - "CrossAttnDownBlock2D", - "CrossAttnDownBlock2D", - "DownBlock2D", - ), - up_block_types: Tuple[str] = ("UpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D"), - block_out_channels: Tuple[int] = (320, 640, 1280, 1280), - layers_per_block: int = 2, - downsample_padding: int = 1, - mid_block_scale_factor: float = 1, - act_fn: str = "silu", - norm_num_groups: int = 32, - norm_eps: float = 1e-5, - cross_attention_dim: int = 1280, - attention_head_dim: int = 8, - ): - super().__init__() - - self.sample_size = sample_size - time_embed_dim = block_out_channels[0] * 4 - - # input - self.conv_in = nn.Conv2d(in_channels, block_out_channels[0], kernel_size=3, padding=(1, 1)) - - # time - self.time_proj = Timesteps(block_out_channels[0], flip_sin_to_cos, freq_shift) - timestep_input_dim = block_out_channels[0] - - self.time_embedding = TimestepEmbedding(timestep_input_dim, time_embed_dim) - - self.down_blocks = nn.ModuleList([]) - self.mid_block = None - self.up_blocks = nn.ModuleList([]) - - # down - output_channel = block_out_channels[0] - for i, down_block_type in enumerate(down_block_types): - input_channel = output_channel - output_channel = block_out_channels[i] - is_final_block = i == len(block_out_channels) - 1 - - down_block = get_down_block( - down_block_type, - num_layers=layers_per_block, - in_channels=input_channel, - out_channels=output_channel, - temb_channels=time_embed_dim, - add_downsample=not is_final_block, - resnet_eps=norm_eps, - resnet_act_fn=act_fn, - cross_attention_dim=cross_attention_dim, - attn_num_head_channels=attention_head_dim, - downsample_padding=downsample_padding, - ) - self.down_blocks.append(down_block) - - # mid - self.mid_block = UNetMidBlock2DCrossAttn( - in_channels=block_out_channels[-1], - temb_channels=time_embed_dim, - resnet_eps=norm_eps, - resnet_act_fn=act_fn, - output_scale_factor=mid_block_scale_factor, - resnet_time_scale_shift="default", - cross_attention_dim=cross_attention_dim, - attn_num_head_channels=attention_head_dim, - resnet_groups=norm_num_groups, - ) - - # up - reversed_block_out_channels = list(reversed(block_out_channels)) - output_channel = reversed_block_out_channels[0] - for i, up_block_type in enumerate(up_block_types): - prev_output_channel = output_channel - output_channel = reversed_block_out_channels[i] - input_channel = reversed_block_out_channels[min(i + 1, len(block_out_channels) - 1)] - - is_final_block = i == len(block_out_channels) - 1 - - up_block = get_up_block( - up_block_type, - num_layers=layers_per_block + 1, - in_channels=input_channel, - out_channels=output_channel, - prev_output_channel=prev_output_channel, - 
temb_channels=time_embed_dim, - add_upsample=not is_final_block, - resnet_eps=norm_eps, - resnet_act_fn=act_fn, - cross_attention_dim=cross_attention_dim, - attn_num_head_channels=attention_head_dim, - ) - self.up_blocks.append(up_block) - prev_output_channel = output_channel - - # out - self.conv_norm_out = nn.GroupNorm(num_channels=block_out_channels[0], num_groups=norm_num_groups, eps=norm_eps) - self.conv_act = nn.SiLU() - self.conv_out = nn.Conv2d(block_out_channels[0], out_channels, 3, padding=1) - - def set_attention_slice(self, slice_size): - if slice_size is not None and self.config.attention_head_dim % slice_size != 0: - raise ValueError( - f"Make sure slice_size {slice_size} is a divisor of " - f"the number of heads used in cross_attention {self.config.attention_head_dim}" - ) - if slice_size is not None and slice_size > self.config.attention_head_dim: - raise ValueError( - f"Chunk_size {slice_size} has to be smaller or equal to " - f"the number of heads used in cross_attention {self.config.attention_head_dim}" - ) - - for block in self.down_blocks: - if hasattr(block, "attentions") and block.attentions is not None: - block.set_attention_slice(slice_size) - - self.mid_block.set_attention_slice(slice_size) - - for block in self.up_blocks: - if hasattr(block, "attentions") and block.attentions is not None: - block.set_attention_slice(slice_size) - - def forward( - self, - sample: torch.FloatTensor, - timestep: Union[torch.Tensor, float, int], - encoder_hidden_states: torch.Tensor, - return_dict: bool = True, - ) -> Union[UNet2DConditionOutput, Tuple]: - """r - Args: - sample (`torch.FloatTensor`): (batch, channel, height, width) noisy inputs tensor - timestep (`torch.FloatTensor` or `float` or `int): (batch) timesteps - encoder_hidden_states (`torch.FloatTensor`): (batch, channel, height, width) encoder hidden states - return_dict (`bool`, *optional*, defaults to `True`): - Whether or not to return a [`models.unet_2d_condition.UNet2DConditionOutput`] instead of a plain tuple. - - Returns: - [`~models.unet_2d_condition.UNet2DConditionOutput`] or `tuple`: - [`~models.unet_2d_condition.UNet2DConditionOutput`] if `return_dict` is True, otherwise a `tuple`. When - returning a tuple, the first element is the sample tensor. - """ - # 0. center input if necessary - if self.config.center_input_sample: - sample = 2 * sample - 1.0 - - # 1. time - timesteps = timestep - if not torch.is_tensor(timesteps): - timesteps = torch.tensor([timesteps], dtype=torch.long, device=sample.device) - elif torch.is_tensor(timesteps) and len(timesteps.shape) == 0: - timesteps = timesteps.to(dtype=torch.float32) - timesteps = timesteps[None].to(device=sample.device) - - # broadcast to batch dimension in a way that's compatible with ONNX/Core ML - timesteps = timesteps.expand(sample.shape[0]) - - t_emb = self.time_proj(timesteps) - emb = self.time_embedding(t_emb) - - # 2. pre-process - sample = self.conv_in(sample) - - # 3. down - down_block_res_samples = (sample,) - for downsample_block in self.down_blocks: - if hasattr(downsample_block, "attentions") and downsample_block.attentions is not None: - sample, res_samples = downsample_block( - hidden_states=sample, temb=emb, encoder_hidden_states=encoder_hidden_states - ) - else: - sample, res_samples = downsample_block(hidden_states=sample, temb=emb) - - down_block_res_samples += res_samples - - # 4. mid - sample = self.mid_block(sample, emb, encoder_hidden_states=encoder_hidden_states) - - # 5. 
up - for upsample_block in self.up_blocks: - res_samples = down_block_res_samples[-len(upsample_block.resnets) :] - down_block_res_samples = down_block_res_samples[: -len(upsample_block.resnets)] - - if hasattr(upsample_block, "attentions") and upsample_block.attentions is not None: - sample = upsample_block( - hidden_states=sample, - temb=emb, - res_hidden_states_tuple=res_samples, - encoder_hidden_states=encoder_hidden_states, - ) - else: - sample = upsample_block(hidden_states=sample, temb=emb, res_hidden_states_tuple=res_samples) - - # 6. post-process - # make sure hidden states is in float32 - # when running in half-precision - sample = self.conv_norm_out(sample.float()).type(sample.dtype) - sample = self.conv_act(sample) - sample = self.conv_out(sample) - - if not return_dict: - return (sample,) - - return UNet2DConditionOutput(sample=sample) diff --git a/local_diffusers/models/unet_blocks.py b/local_diffusers/models/unet_blocks.py deleted file mode 100644 index 9e0621653..000000000 --- a/local_diffusers/models/unet_blocks.py +++ /dev/null @@ -1,1481 +0,0 @@ -# Copyright 2022 The HuggingFace Team. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and - -import numpy as np - -# limitations under the License. -import torch -from torch import nn - -from .attention import AttentionBlock, SpatialTransformer -from .resnet import Downsample2D, FirDownsample2D, FirUpsample2D, ResnetBlock2D, Upsample2D - - -def get_down_block( - down_block_type, - num_layers, - in_channels, - out_channels, - temb_channels, - add_downsample, - resnet_eps, - resnet_act_fn, - attn_num_head_channels, - cross_attention_dim=None, - downsample_padding=None, -): - down_block_type = down_block_type[7:] if down_block_type.startswith("UNetRes") else down_block_type - if down_block_type == "DownBlock2D": - return DownBlock2D( - num_layers=num_layers, - in_channels=in_channels, - out_channels=out_channels, - temb_channels=temb_channels, - add_downsample=add_downsample, - resnet_eps=resnet_eps, - resnet_act_fn=resnet_act_fn, - downsample_padding=downsample_padding, - ) - elif down_block_type == "AttnDownBlock2D": - return AttnDownBlock2D( - num_layers=num_layers, - in_channels=in_channels, - out_channels=out_channels, - temb_channels=temb_channels, - add_downsample=add_downsample, - resnet_eps=resnet_eps, - resnet_act_fn=resnet_act_fn, - downsample_padding=downsample_padding, - attn_num_head_channels=attn_num_head_channels, - ) - elif down_block_type == "CrossAttnDownBlock2D": - if cross_attention_dim is None: - raise ValueError("cross_attention_dim must be specified for CrossAttnDownBlock2D") - return CrossAttnDownBlock2D( - num_layers=num_layers, - in_channels=in_channels, - out_channels=out_channels, - temb_channels=temb_channels, - add_downsample=add_downsample, - resnet_eps=resnet_eps, - resnet_act_fn=resnet_act_fn, - downsample_padding=downsample_padding, - cross_attention_dim=cross_attention_dim, - attn_num_head_channels=attn_num_head_channels, - ) - elif down_block_type == "SkipDownBlock2D": - return SkipDownBlock2D( - 
num_layers=num_layers, - in_channels=in_channels, - out_channels=out_channels, - temb_channels=temb_channels, - add_downsample=add_downsample, - resnet_eps=resnet_eps, - resnet_act_fn=resnet_act_fn, - downsample_padding=downsample_padding, - ) - elif down_block_type == "AttnSkipDownBlock2D": - return AttnSkipDownBlock2D( - num_layers=num_layers, - in_channels=in_channels, - out_channels=out_channels, - temb_channels=temb_channels, - add_downsample=add_downsample, - resnet_eps=resnet_eps, - resnet_act_fn=resnet_act_fn, - downsample_padding=downsample_padding, - attn_num_head_channels=attn_num_head_channels, - ) - elif down_block_type == "DownEncoderBlock2D": - return DownEncoderBlock2D( - num_layers=num_layers, - in_channels=in_channels, - out_channels=out_channels, - add_downsample=add_downsample, - resnet_eps=resnet_eps, - resnet_act_fn=resnet_act_fn, - downsample_padding=downsample_padding, - ) - - -def get_up_block( - up_block_type, - num_layers, - in_channels, - out_channels, - prev_output_channel, - temb_channels, - add_upsample, - resnet_eps, - resnet_act_fn, - attn_num_head_channels, - cross_attention_dim=None, -): - up_block_type = up_block_type[7:] if up_block_type.startswith("UNetRes") else up_block_type - if up_block_type == "UpBlock2D": - return UpBlock2D( - num_layers=num_layers, - in_channels=in_channels, - out_channels=out_channels, - prev_output_channel=prev_output_channel, - temb_channels=temb_channels, - add_upsample=add_upsample, - resnet_eps=resnet_eps, - resnet_act_fn=resnet_act_fn, - ) - elif up_block_type == "CrossAttnUpBlock2D": - if cross_attention_dim is None: - raise ValueError("cross_attention_dim must be specified for CrossAttnUpBlock2D") - return CrossAttnUpBlock2D( - num_layers=num_layers, - in_channels=in_channels, - out_channels=out_channels, - prev_output_channel=prev_output_channel, - temb_channels=temb_channels, - add_upsample=add_upsample, - resnet_eps=resnet_eps, - resnet_act_fn=resnet_act_fn, - cross_attention_dim=cross_attention_dim, - attn_num_head_channels=attn_num_head_channels, - ) - elif up_block_type == "AttnUpBlock2D": - return AttnUpBlock2D( - num_layers=num_layers, - in_channels=in_channels, - out_channels=out_channels, - prev_output_channel=prev_output_channel, - temb_channels=temb_channels, - add_upsample=add_upsample, - resnet_eps=resnet_eps, - resnet_act_fn=resnet_act_fn, - attn_num_head_channels=attn_num_head_channels, - ) - elif up_block_type == "SkipUpBlock2D": - return SkipUpBlock2D( - num_layers=num_layers, - in_channels=in_channels, - out_channels=out_channels, - prev_output_channel=prev_output_channel, - temb_channels=temb_channels, - add_upsample=add_upsample, - resnet_eps=resnet_eps, - resnet_act_fn=resnet_act_fn, - ) - elif up_block_type == "AttnSkipUpBlock2D": - return AttnSkipUpBlock2D( - num_layers=num_layers, - in_channels=in_channels, - out_channels=out_channels, - prev_output_channel=prev_output_channel, - temb_channels=temb_channels, - add_upsample=add_upsample, - resnet_eps=resnet_eps, - resnet_act_fn=resnet_act_fn, - attn_num_head_channels=attn_num_head_channels, - ) - elif up_block_type == "UpDecoderBlock2D": - return UpDecoderBlock2D( - num_layers=num_layers, - in_channels=in_channels, - out_channels=out_channels, - add_upsample=add_upsample, - resnet_eps=resnet_eps, - resnet_act_fn=resnet_act_fn, - ) - raise ValueError(f"{up_block_type} does not exist.") - - -class UNetMidBlock2D(nn.Module): - def __init__( - self, - in_channels: int, - temb_channels: int, - dropout: float = 0.0, - num_layers: int = 1, - resnet_eps: 
float = 1e-6, - resnet_time_scale_shift: str = "default", - resnet_act_fn: str = "swish", - resnet_groups: int = 32, - resnet_pre_norm: bool = True, - attn_num_head_channels=1, - attention_type="default", - output_scale_factor=1.0, - **kwargs, - ): - super().__init__() - - self.attention_type = attention_type - resnet_groups = resnet_groups if resnet_groups is not None else min(in_channels // 4, 32) - - # there is always at least one resnet - resnets = [ - ResnetBlock2D( - in_channels=in_channels, - out_channels=in_channels, - temb_channels=temb_channels, - eps=resnet_eps, - groups=resnet_groups, - dropout=dropout, - time_embedding_norm=resnet_time_scale_shift, - non_linearity=resnet_act_fn, - output_scale_factor=output_scale_factor, - pre_norm=resnet_pre_norm, - ) - ] - attentions = [] - - for _ in range(num_layers): - attentions.append( - AttentionBlock( - in_channels, - num_head_channels=attn_num_head_channels, - rescale_output_factor=output_scale_factor, - eps=resnet_eps, - num_groups=resnet_groups, - ) - ) - resnets.append( - ResnetBlock2D( - in_channels=in_channels, - out_channels=in_channels, - temb_channels=temb_channels, - eps=resnet_eps, - groups=resnet_groups, - dropout=dropout, - time_embedding_norm=resnet_time_scale_shift, - non_linearity=resnet_act_fn, - output_scale_factor=output_scale_factor, - pre_norm=resnet_pre_norm, - ) - ) - - self.attentions = nn.ModuleList(attentions) - self.resnets = nn.ModuleList(resnets) - - def forward(self, hidden_states, temb=None, encoder_states=None): - hidden_states = self.resnets[0](hidden_states, temb) - for attn, resnet in zip(self.attentions, self.resnets[1:]): - if self.attention_type == "default": - hidden_states = attn(hidden_states) - else: - hidden_states = attn(hidden_states, encoder_states) - hidden_states = resnet(hidden_states, temb) - - return hidden_states - - -class UNetMidBlock2DCrossAttn(nn.Module): - def __init__( - self, - in_channels: int, - temb_channels: int, - dropout: float = 0.0, - num_layers: int = 1, - resnet_eps: float = 1e-6, - resnet_time_scale_shift: str = "default", - resnet_act_fn: str = "swish", - resnet_groups: int = 32, - resnet_pre_norm: bool = True, - attn_num_head_channels=1, - attention_type="default", - output_scale_factor=1.0, - cross_attention_dim=1280, - **kwargs, - ): - super().__init__() - - self.attention_type = attention_type - self.attn_num_head_channels = attn_num_head_channels - resnet_groups = resnet_groups if resnet_groups is not None else min(in_channels // 4, 32) - - # there is always at least one resnet - resnets = [ - ResnetBlock2D( - in_channels=in_channels, - out_channels=in_channels, - temb_channels=temb_channels, - eps=resnet_eps, - groups=resnet_groups, - dropout=dropout, - time_embedding_norm=resnet_time_scale_shift, - non_linearity=resnet_act_fn, - output_scale_factor=output_scale_factor, - pre_norm=resnet_pre_norm, - ) - ] - attentions = [] - - for _ in range(num_layers): - attentions.append( - SpatialTransformer( - in_channels, - attn_num_head_channels, - in_channels // attn_num_head_channels, - depth=1, - context_dim=cross_attention_dim, - ) - ) - resnets.append( - ResnetBlock2D( - in_channels=in_channels, - out_channels=in_channels, - temb_channels=temb_channels, - eps=resnet_eps, - groups=resnet_groups, - dropout=dropout, - time_embedding_norm=resnet_time_scale_shift, - non_linearity=resnet_act_fn, - output_scale_factor=output_scale_factor, - pre_norm=resnet_pre_norm, - ) - ) - - self.attentions = nn.ModuleList(attentions) - self.resnets = nn.ModuleList(resnets) - - def 
set_attention_slice(self, slice_size): - if slice_size is not None and self.attn_num_head_channels % slice_size != 0: - raise ValueError( - f"Make sure slice_size {slice_size} is a divisor of " - f"the number of heads used in cross_attention {self.attn_num_head_channels}" - ) - if slice_size is not None and slice_size > self.attn_num_head_channels: - raise ValueError( - f"Chunk_size {slice_size} has to be smaller or equal to " - f"the number of heads used in cross_attention {self.attn_num_head_channels}" - ) - - for attn in self.attentions: - attn._set_attention_slice(slice_size) - - def forward(self, hidden_states, temb=None, encoder_hidden_states=None): - hidden_states = self.resnets[0](hidden_states, temb) - for attn, resnet in zip(self.attentions, self.resnets[1:]): - hidden_states = attn(hidden_states, encoder_hidden_states) - hidden_states = resnet(hidden_states, temb) - - return hidden_states - - -class AttnDownBlock2D(nn.Module): - def __init__( - self, - in_channels: int, - out_channels: int, - temb_channels: int, - dropout: float = 0.0, - num_layers: int = 1, - resnet_eps: float = 1e-6, - resnet_time_scale_shift: str = "default", - resnet_act_fn: str = "swish", - resnet_groups: int = 32, - resnet_pre_norm: bool = True, - attn_num_head_channels=1, - attention_type="default", - output_scale_factor=1.0, - downsample_padding=1, - add_downsample=True, - ): - super().__init__() - resnets = [] - attentions = [] - - self.attention_type = attention_type - - for i in range(num_layers): - in_channels = in_channels if i == 0 else out_channels - resnets.append( - ResnetBlock2D( - in_channels=in_channels, - out_channels=out_channels, - temb_channels=temb_channels, - eps=resnet_eps, - groups=resnet_groups, - dropout=dropout, - time_embedding_norm=resnet_time_scale_shift, - non_linearity=resnet_act_fn, - output_scale_factor=output_scale_factor, - pre_norm=resnet_pre_norm, - ) - ) - attentions.append( - AttentionBlock( - out_channels, - num_head_channels=attn_num_head_channels, - rescale_output_factor=output_scale_factor, - eps=resnet_eps, - ) - ) - - self.attentions = nn.ModuleList(attentions) - self.resnets = nn.ModuleList(resnets) - - if add_downsample: - self.downsamplers = nn.ModuleList( - [ - Downsample2D( - in_channels, use_conv=True, out_channels=out_channels, padding=downsample_padding, name="op" - ) - ] - ) - else: - self.downsamplers = None - - def forward(self, hidden_states, temb=None): - output_states = () - - for resnet, attn in zip(self.resnets, self.attentions): - hidden_states = resnet(hidden_states, temb) - hidden_states = attn(hidden_states) - output_states += (hidden_states,) - - if self.downsamplers is not None: - for downsampler in self.downsamplers: - hidden_states = downsampler(hidden_states) - - output_states += (hidden_states,) - - return hidden_states, output_states - - -class CrossAttnDownBlock2D(nn.Module): - def __init__( - self, - in_channels: int, - out_channels: int, - temb_channels: int, - dropout: float = 0.0, - num_layers: int = 1, - resnet_eps: float = 1e-6, - resnet_time_scale_shift: str = "default", - resnet_act_fn: str = "swish", - resnet_groups: int = 32, - resnet_pre_norm: bool = True, - attn_num_head_channels=1, - cross_attention_dim=1280, - attention_type="default", - output_scale_factor=1.0, - downsample_padding=1, - add_downsample=True, - ): - super().__init__() - resnets = [] - attentions = [] - - self.attention_type = attention_type - self.attn_num_head_channels = attn_num_head_channels - - for i in range(num_layers): - in_channels = in_channels if 
i == 0 else out_channels - resnets.append( - ResnetBlock2D( - in_channels=in_channels, - out_channels=out_channels, - temb_channels=temb_channels, - eps=resnet_eps, - groups=resnet_groups, - dropout=dropout, - time_embedding_norm=resnet_time_scale_shift, - non_linearity=resnet_act_fn, - output_scale_factor=output_scale_factor, - pre_norm=resnet_pre_norm, - ) - ) - attentions.append( - SpatialTransformer( - out_channels, - attn_num_head_channels, - out_channels // attn_num_head_channels, - depth=1, - context_dim=cross_attention_dim, - ) - ) - self.attentions = nn.ModuleList(attentions) - self.resnets = nn.ModuleList(resnets) - - if add_downsample: - self.downsamplers = nn.ModuleList( - [ - Downsample2D( - in_channels, use_conv=True, out_channels=out_channels, padding=downsample_padding, name="op" - ) - ] - ) - else: - self.downsamplers = None - - def set_attention_slice(self, slice_size): - if slice_size is not None and self.attn_num_head_channels % slice_size != 0: - raise ValueError( - f"Make sure slice_size {slice_size} is a divisor of " - f"the number of heads used in cross_attention {self.attn_num_head_channels}" - ) - if slice_size is not None and slice_size > self.attn_num_head_channels: - raise ValueError( - f"Chunk_size {slice_size} has to be smaller or equal to " - f"the number of heads used in cross_attention {self.attn_num_head_channels}" - ) - - for attn in self.attentions: - attn._set_attention_slice(slice_size) - - def forward(self, hidden_states, temb=None, encoder_hidden_states=None): - output_states = () - - for resnet, attn in zip(self.resnets, self.attentions): - hidden_states = resnet(hidden_states, temb) - hidden_states = attn(hidden_states, context=encoder_hidden_states) - output_states += (hidden_states,) - - if self.downsamplers is not None: - for downsampler in self.downsamplers: - hidden_states = downsampler(hidden_states) - - output_states += (hidden_states,) - - return hidden_states, output_states - - -class DownBlock2D(nn.Module): - def __init__( - self, - in_channels: int, - out_channels: int, - temb_channels: int, - dropout: float = 0.0, - num_layers: int = 1, - resnet_eps: float = 1e-6, - resnet_time_scale_shift: str = "default", - resnet_act_fn: str = "swish", - resnet_groups: int = 32, - resnet_pre_norm: bool = True, - output_scale_factor=1.0, - add_downsample=True, - downsample_padding=1, - ): - super().__init__() - resnets = [] - - for i in range(num_layers): - in_channels = in_channels if i == 0 else out_channels - resnets.append( - ResnetBlock2D( - in_channels=in_channels, - out_channels=out_channels, - temb_channels=temb_channels, - eps=resnet_eps, - groups=resnet_groups, - dropout=dropout, - time_embedding_norm=resnet_time_scale_shift, - non_linearity=resnet_act_fn, - output_scale_factor=output_scale_factor, - pre_norm=resnet_pre_norm, - ) - ) - - self.resnets = nn.ModuleList(resnets) - - if add_downsample: - self.downsamplers = nn.ModuleList( - [ - Downsample2D( - in_channels, use_conv=True, out_channels=out_channels, padding=downsample_padding, name="op" - ) - ] - ) - else: - self.downsamplers = None - - def forward(self, hidden_states, temb=None): - output_states = () - - for resnet in self.resnets: - hidden_states = resnet(hidden_states, temb) - output_states += (hidden_states,) - - if self.downsamplers is not None: - for downsampler in self.downsamplers: - hidden_states = downsampler(hidden_states) - - output_states += (hidden_states,) - - return hidden_states, output_states - - -class DownEncoderBlock2D(nn.Module): - def __init__( - self, - 
in_channels: int, - out_channels: int, - dropout: float = 0.0, - num_layers: int = 1, - resnet_eps: float = 1e-6, - resnet_time_scale_shift: str = "default", - resnet_act_fn: str = "swish", - resnet_groups: int = 32, - resnet_pre_norm: bool = True, - output_scale_factor=1.0, - add_downsample=True, - downsample_padding=1, - ): - super().__init__() - resnets = [] - - for i in range(num_layers): - in_channels = in_channels if i == 0 else out_channels - resnets.append( - ResnetBlock2D( - in_channels=in_channels, - out_channels=out_channels, - temb_channels=None, - eps=resnet_eps, - groups=resnet_groups, - dropout=dropout, - time_embedding_norm=resnet_time_scale_shift, - non_linearity=resnet_act_fn, - output_scale_factor=output_scale_factor, - pre_norm=resnet_pre_norm, - ) - ) - - self.resnets = nn.ModuleList(resnets) - - if add_downsample: - self.downsamplers = nn.ModuleList( - [ - Downsample2D( - in_channels, use_conv=True, out_channels=out_channels, padding=downsample_padding, name="op" - ) - ] - ) - else: - self.downsamplers = None - - def forward(self, hidden_states): - for resnet in self.resnets: - hidden_states = resnet(hidden_states, temb=None) - - if self.downsamplers is not None: - for downsampler in self.downsamplers: - hidden_states = downsampler(hidden_states) - - return hidden_states - - -class AttnDownEncoderBlock2D(nn.Module): - def __init__( - self, - in_channels: int, - out_channels: int, - dropout: float = 0.0, - num_layers: int = 1, - resnet_eps: float = 1e-6, - resnet_time_scale_shift: str = "default", - resnet_act_fn: str = "swish", - resnet_groups: int = 32, - resnet_pre_norm: bool = True, - attn_num_head_channels=1, - output_scale_factor=1.0, - add_downsample=True, - downsample_padding=1, - ): - super().__init__() - resnets = [] - attentions = [] - - for i in range(num_layers): - in_channels = in_channels if i == 0 else out_channels - resnets.append( - ResnetBlock2D( - in_channels=in_channels, - out_channels=out_channels, - temb_channels=None, - eps=resnet_eps, - groups=resnet_groups, - dropout=dropout, - time_embedding_norm=resnet_time_scale_shift, - non_linearity=resnet_act_fn, - output_scale_factor=output_scale_factor, - pre_norm=resnet_pre_norm, - ) - ) - attentions.append( - AttentionBlock( - out_channels, - num_head_channels=attn_num_head_channels, - rescale_output_factor=output_scale_factor, - eps=resnet_eps, - num_groups=resnet_groups, - ) - ) - - self.attentions = nn.ModuleList(attentions) - self.resnets = nn.ModuleList(resnets) - - if add_downsample: - self.downsamplers = nn.ModuleList( - [ - Downsample2D( - in_channels, use_conv=True, out_channels=out_channels, padding=downsample_padding, name="op" - ) - ] - ) - else: - self.downsamplers = None - - def forward(self, hidden_states): - for resnet, attn in zip(self.resnets, self.attentions): - hidden_states = resnet(hidden_states, temb=None) - hidden_states = attn(hidden_states) - - if self.downsamplers is not None: - for downsampler in self.downsamplers: - hidden_states = downsampler(hidden_states) - - return hidden_states - - -class AttnSkipDownBlock2D(nn.Module): - def __init__( - self, - in_channels: int, - out_channels: int, - temb_channels: int, - dropout: float = 0.0, - num_layers: int = 1, - resnet_eps: float = 1e-6, - resnet_time_scale_shift: str = "default", - resnet_act_fn: str = "swish", - resnet_pre_norm: bool = True, - attn_num_head_channels=1, - attention_type="default", - output_scale_factor=np.sqrt(2.0), - downsample_padding=1, - add_downsample=True, - ): - super().__init__() - self.attentions = 
nn.ModuleList([]) - self.resnets = nn.ModuleList([]) - - self.attention_type = attention_type - - for i in range(num_layers): - in_channels = in_channels if i == 0 else out_channels - self.resnets.append( - ResnetBlock2D( - in_channels=in_channels, - out_channels=out_channels, - temb_channels=temb_channels, - eps=resnet_eps, - groups=min(in_channels // 4, 32), - groups_out=min(out_channels // 4, 32), - dropout=dropout, - time_embedding_norm=resnet_time_scale_shift, - non_linearity=resnet_act_fn, - output_scale_factor=output_scale_factor, - pre_norm=resnet_pre_norm, - ) - ) - self.attentions.append( - AttentionBlock( - out_channels, - num_head_channels=attn_num_head_channels, - rescale_output_factor=output_scale_factor, - eps=resnet_eps, - ) - ) - - if add_downsample: - self.resnet_down = ResnetBlock2D( - in_channels=out_channels, - out_channels=out_channels, - temb_channels=temb_channels, - eps=resnet_eps, - groups=min(out_channels // 4, 32), - dropout=dropout, - time_embedding_norm=resnet_time_scale_shift, - non_linearity=resnet_act_fn, - output_scale_factor=output_scale_factor, - pre_norm=resnet_pre_norm, - use_nin_shortcut=True, - down=True, - kernel="fir", - ) - self.downsamplers = nn.ModuleList([FirDownsample2D(in_channels, out_channels=out_channels)]) - self.skip_conv = nn.Conv2d(3, out_channels, kernel_size=(1, 1), stride=(1, 1)) - else: - self.resnet_down = None - self.downsamplers = None - self.skip_conv = None - - def forward(self, hidden_states, temb=None, skip_sample=None): - output_states = () - - for resnet, attn in zip(self.resnets, self.attentions): - hidden_states = resnet(hidden_states, temb) - hidden_states = attn(hidden_states) - output_states += (hidden_states,) - - if self.downsamplers is not None: - hidden_states = self.resnet_down(hidden_states, temb) - for downsampler in self.downsamplers: - skip_sample = downsampler(skip_sample) - - hidden_states = self.skip_conv(skip_sample) + hidden_states - - output_states += (hidden_states,) - - return hidden_states, output_states, skip_sample - - -class SkipDownBlock2D(nn.Module): - def __init__( - self, - in_channels: int, - out_channels: int, - temb_channels: int, - dropout: float = 0.0, - num_layers: int = 1, - resnet_eps: float = 1e-6, - resnet_time_scale_shift: str = "default", - resnet_act_fn: str = "swish", - resnet_pre_norm: bool = True, - output_scale_factor=np.sqrt(2.0), - add_downsample=True, - downsample_padding=1, - ): - super().__init__() - self.resnets = nn.ModuleList([]) - - for i in range(num_layers): - in_channels = in_channels if i == 0 else out_channels - self.resnets.append( - ResnetBlock2D( - in_channels=in_channels, - out_channels=out_channels, - temb_channels=temb_channels, - eps=resnet_eps, - groups=min(in_channels // 4, 32), - groups_out=min(out_channels // 4, 32), - dropout=dropout, - time_embedding_norm=resnet_time_scale_shift, - non_linearity=resnet_act_fn, - output_scale_factor=output_scale_factor, - pre_norm=resnet_pre_norm, - ) - ) - - if add_downsample: - self.resnet_down = ResnetBlock2D( - in_channels=out_channels, - out_channels=out_channels, - temb_channels=temb_channels, - eps=resnet_eps, - groups=min(out_channels // 4, 32), - dropout=dropout, - time_embedding_norm=resnet_time_scale_shift, - non_linearity=resnet_act_fn, - output_scale_factor=output_scale_factor, - pre_norm=resnet_pre_norm, - use_nin_shortcut=True, - down=True, - kernel="fir", - ) - self.downsamplers = nn.ModuleList([FirDownsample2D(in_channels, out_channels=out_channels)]) - self.skip_conv = nn.Conv2d(3, out_channels, 
kernel_size=(1, 1), stride=(1, 1)) - else: - self.resnet_down = None - self.downsamplers = None - self.skip_conv = None - - def forward(self, hidden_states, temb=None, skip_sample=None): - output_states = () - - for resnet in self.resnets: - hidden_states = resnet(hidden_states, temb) - output_states += (hidden_states,) - - if self.downsamplers is not None: - hidden_states = self.resnet_down(hidden_states, temb) - for downsampler in self.downsamplers: - skip_sample = downsampler(skip_sample) - - hidden_states = self.skip_conv(skip_sample) + hidden_states - - output_states += (hidden_states,) - - return hidden_states, output_states, skip_sample - - -class AttnUpBlock2D(nn.Module): - def __init__( - self, - in_channels: int, - prev_output_channel: int, - out_channels: int, - temb_channels: int, - dropout: float = 0.0, - num_layers: int = 1, - resnet_eps: float = 1e-6, - resnet_time_scale_shift: str = "default", - resnet_act_fn: str = "swish", - resnet_groups: int = 32, - resnet_pre_norm: bool = True, - attention_type="default", - attn_num_head_channels=1, - output_scale_factor=1.0, - add_upsample=True, - ): - super().__init__() - resnets = [] - attentions = [] - - self.attention_type = attention_type - - for i in range(num_layers): - res_skip_channels = in_channels if (i == num_layers - 1) else out_channels - resnet_in_channels = prev_output_channel if i == 0 else out_channels - - resnets.append( - ResnetBlock2D( - in_channels=resnet_in_channels + res_skip_channels, - out_channels=out_channels, - temb_channels=temb_channels, - eps=resnet_eps, - groups=resnet_groups, - dropout=dropout, - time_embedding_norm=resnet_time_scale_shift, - non_linearity=resnet_act_fn, - output_scale_factor=output_scale_factor, - pre_norm=resnet_pre_norm, - ) - ) - attentions.append( - AttentionBlock( - out_channels, - num_head_channels=attn_num_head_channels, - rescale_output_factor=output_scale_factor, - eps=resnet_eps, - ) - ) - - self.attentions = nn.ModuleList(attentions) - self.resnets = nn.ModuleList(resnets) - - if add_upsample: - self.upsamplers = nn.ModuleList([Upsample2D(out_channels, use_conv=True, out_channels=out_channels)]) - else: - self.upsamplers = None - - def forward(self, hidden_states, res_hidden_states_tuple, temb=None): - for resnet, attn in zip(self.resnets, self.attentions): - - # pop res hidden states - res_hidden_states = res_hidden_states_tuple[-1] - res_hidden_states_tuple = res_hidden_states_tuple[:-1] - hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1) - - hidden_states = resnet(hidden_states, temb) - hidden_states = attn(hidden_states) - - if self.upsamplers is not None: - for upsampler in self.upsamplers: - hidden_states = upsampler(hidden_states) - - return hidden_states - - -class CrossAttnUpBlock2D(nn.Module): - def __init__( - self, - in_channels: int, - out_channels: int, - prev_output_channel: int, - temb_channels: int, - dropout: float = 0.0, - num_layers: int = 1, - resnet_eps: float = 1e-6, - resnet_time_scale_shift: str = "default", - resnet_act_fn: str = "swish", - resnet_groups: int = 32, - resnet_pre_norm: bool = True, - attn_num_head_channels=1, - cross_attention_dim=1280, - attention_type="default", - output_scale_factor=1.0, - downsample_padding=1, - add_upsample=True, - ): - super().__init__() - resnets = [] - attentions = [] - - self.attention_type = attention_type - self.attn_num_head_channels = attn_num_head_channels - - for i in range(num_layers): - res_skip_channels = in_channels if (i == num_layers - 1) else out_channels - resnet_in_channels 
= prev_output_channel if i == 0 else out_channels - - resnets.append( - ResnetBlock2D( - in_channels=resnet_in_channels + res_skip_channels, - out_channels=out_channels, - temb_channels=temb_channels, - eps=resnet_eps, - groups=resnet_groups, - dropout=dropout, - time_embedding_norm=resnet_time_scale_shift, - non_linearity=resnet_act_fn, - output_scale_factor=output_scale_factor, - pre_norm=resnet_pre_norm, - ) - ) - attentions.append( - SpatialTransformer( - out_channels, - attn_num_head_channels, - out_channels // attn_num_head_channels, - depth=1, - context_dim=cross_attention_dim, - ) - ) - self.attentions = nn.ModuleList(attentions) - self.resnets = nn.ModuleList(resnets) - - if add_upsample: - self.upsamplers = nn.ModuleList([Upsample2D(out_channels, use_conv=True, out_channels=out_channels)]) - else: - self.upsamplers = None - - def set_attention_slice(self, slice_size): - if slice_size is not None and self.attn_num_head_channels % slice_size != 0: - raise ValueError( - f"Make sure slice_size {slice_size} is a divisor of " - f"the number of heads used in cross_attention {self.attn_num_head_channels}" - ) - if slice_size is not None and slice_size > self.attn_num_head_channels: - raise ValueError( - f"Chunk_size {slice_size} has to be smaller or equal to " - f"the number of heads used in cross_attention {self.attn_num_head_channels}" - ) - - for attn in self.attentions: - attn._set_attention_slice(slice_size) - - def forward(self, hidden_states, res_hidden_states_tuple, temb=None, encoder_hidden_states=None): - for resnet, attn in zip(self.resnets, self.attentions): - - # pop res hidden states - res_hidden_states = res_hidden_states_tuple[-1] - res_hidden_states_tuple = res_hidden_states_tuple[:-1] - hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1) - - hidden_states = resnet(hidden_states, temb) - hidden_states = attn(hidden_states, context=encoder_hidden_states) - - if self.upsamplers is not None: - for upsampler in self.upsamplers: - hidden_states = upsampler(hidden_states) - - return hidden_states - - -class UpBlock2D(nn.Module): - def __init__( - self, - in_channels: int, - prev_output_channel: int, - out_channels: int, - temb_channels: int, - dropout: float = 0.0, - num_layers: int = 1, - resnet_eps: float = 1e-6, - resnet_time_scale_shift: str = "default", - resnet_act_fn: str = "swish", - resnet_groups: int = 32, - resnet_pre_norm: bool = True, - output_scale_factor=1.0, - add_upsample=True, - ): - super().__init__() - resnets = [] - - for i in range(num_layers): - res_skip_channels = in_channels if (i == num_layers - 1) else out_channels - resnet_in_channels = prev_output_channel if i == 0 else out_channels - - resnets.append( - ResnetBlock2D( - in_channels=resnet_in_channels + res_skip_channels, - out_channels=out_channels, - temb_channels=temb_channels, - eps=resnet_eps, - groups=resnet_groups, - dropout=dropout, - time_embedding_norm=resnet_time_scale_shift, - non_linearity=resnet_act_fn, - output_scale_factor=output_scale_factor, - pre_norm=resnet_pre_norm, - ) - ) - - self.resnets = nn.ModuleList(resnets) - - if add_upsample: - self.upsamplers = nn.ModuleList([Upsample2D(out_channels, use_conv=True, out_channels=out_channels)]) - else: - self.upsamplers = None - - def forward(self, hidden_states, res_hidden_states_tuple, temb=None): - for resnet in self.resnets: - - # pop res hidden states - res_hidden_states = res_hidden_states_tuple[-1] - res_hidden_states_tuple = res_hidden_states_tuple[:-1] - hidden_states = torch.cat([hidden_states, 
res_hidden_states], dim=1) - - hidden_states = resnet(hidden_states, temb) - - if self.upsamplers is not None: - for upsampler in self.upsamplers: - hidden_states = upsampler(hidden_states) - - return hidden_states - - -class UpDecoderBlock2D(nn.Module): - def __init__( - self, - in_channels: int, - out_channels: int, - dropout: float = 0.0, - num_layers: int = 1, - resnet_eps: float = 1e-6, - resnet_time_scale_shift: str = "default", - resnet_act_fn: str = "swish", - resnet_groups: int = 32, - resnet_pre_norm: bool = True, - output_scale_factor=1.0, - add_upsample=True, - ): - super().__init__() - resnets = [] - - for i in range(num_layers): - input_channels = in_channels if i == 0 else out_channels - - resnets.append( - ResnetBlock2D( - in_channels=input_channels, - out_channels=out_channels, - temb_channels=None, - eps=resnet_eps, - groups=resnet_groups, - dropout=dropout, - time_embedding_norm=resnet_time_scale_shift, - non_linearity=resnet_act_fn, - output_scale_factor=output_scale_factor, - pre_norm=resnet_pre_norm, - ) - ) - - self.resnets = nn.ModuleList(resnets) - - if add_upsample: - self.upsamplers = nn.ModuleList([Upsample2D(out_channels, use_conv=True, out_channels=out_channels)]) - else: - self.upsamplers = None - - def forward(self, hidden_states): - for resnet in self.resnets: - hidden_states = resnet(hidden_states, temb=None) - - if self.upsamplers is not None: - for upsampler in self.upsamplers: - hidden_states = upsampler(hidden_states) - - return hidden_states - - -class AttnUpDecoderBlock2D(nn.Module): - def __init__( - self, - in_channels: int, - out_channels: int, - dropout: float = 0.0, - num_layers: int = 1, - resnet_eps: float = 1e-6, - resnet_time_scale_shift: str = "default", - resnet_act_fn: str = "swish", - resnet_groups: int = 32, - resnet_pre_norm: bool = True, - attn_num_head_channels=1, - output_scale_factor=1.0, - add_upsample=True, - ): - super().__init__() - resnets = [] - attentions = [] - - for i in range(num_layers): - input_channels = in_channels if i == 0 else out_channels - - resnets.append( - ResnetBlock2D( - in_channels=input_channels, - out_channels=out_channels, - temb_channels=None, - eps=resnet_eps, - groups=resnet_groups, - dropout=dropout, - time_embedding_norm=resnet_time_scale_shift, - non_linearity=resnet_act_fn, - output_scale_factor=output_scale_factor, - pre_norm=resnet_pre_norm, - ) - ) - attentions.append( - AttentionBlock( - out_channels, - num_head_channels=attn_num_head_channels, - rescale_output_factor=output_scale_factor, - eps=resnet_eps, - num_groups=resnet_groups, - ) - ) - - self.attentions = nn.ModuleList(attentions) - self.resnets = nn.ModuleList(resnets) - - if add_upsample: - self.upsamplers = nn.ModuleList([Upsample2D(out_channels, use_conv=True, out_channels=out_channels)]) - else: - self.upsamplers = None - - def forward(self, hidden_states): - for resnet, attn in zip(self.resnets, self.attentions): - hidden_states = resnet(hidden_states, temb=None) - hidden_states = attn(hidden_states) - - if self.upsamplers is not None: - for upsampler in self.upsamplers: - hidden_states = upsampler(hidden_states) - - return hidden_states - - -class AttnSkipUpBlock2D(nn.Module): - def __init__( - self, - in_channels: int, - prev_output_channel: int, - out_channels: int, - temb_channels: int, - dropout: float = 0.0, - num_layers: int = 1, - resnet_eps: float = 1e-6, - resnet_time_scale_shift: str = "default", - resnet_act_fn: str = "swish", - resnet_pre_norm: bool = True, - attn_num_head_channels=1, - attention_type="default", - 
output_scale_factor=np.sqrt(2.0), - upsample_padding=1, - add_upsample=True, - ): - super().__init__() - self.attentions = nn.ModuleList([]) - self.resnets = nn.ModuleList([]) - - self.attention_type = attention_type - - for i in range(num_layers): - res_skip_channels = in_channels if (i == num_layers - 1) else out_channels - resnet_in_channels = prev_output_channel if i == 0 else out_channels - - self.resnets.append( - ResnetBlock2D( - in_channels=resnet_in_channels + res_skip_channels, - out_channels=out_channels, - temb_channels=temb_channels, - eps=resnet_eps, - groups=min(resnet_in_channels + res_skip_channels // 4, 32), - groups_out=min(out_channels // 4, 32), - dropout=dropout, - time_embedding_norm=resnet_time_scale_shift, - non_linearity=resnet_act_fn, - output_scale_factor=output_scale_factor, - pre_norm=resnet_pre_norm, - ) - ) - - self.attentions.append( - AttentionBlock( - out_channels, - num_head_channels=attn_num_head_channels, - rescale_output_factor=output_scale_factor, - eps=resnet_eps, - ) - ) - - self.upsampler = FirUpsample2D(in_channels, out_channels=out_channels) - if add_upsample: - self.resnet_up = ResnetBlock2D( - in_channels=out_channels, - out_channels=out_channels, - temb_channels=temb_channels, - eps=resnet_eps, - groups=min(out_channels // 4, 32), - groups_out=min(out_channels // 4, 32), - dropout=dropout, - time_embedding_norm=resnet_time_scale_shift, - non_linearity=resnet_act_fn, - output_scale_factor=output_scale_factor, - pre_norm=resnet_pre_norm, - use_nin_shortcut=True, - up=True, - kernel="fir", - ) - self.skip_conv = nn.Conv2d(out_channels, 3, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) - self.skip_norm = torch.nn.GroupNorm( - num_groups=min(out_channels // 4, 32), num_channels=out_channels, eps=resnet_eps, affine=True - ) - self.act = nn.SiLU() - else: - self.resnet_up = None - self.skip_conv = None - self.skip_norm = None - self.act = None - - def forward(self, hidden_states, res_hidden_states_tuple, temb=None, skip_sample=None): - for resnet in self.resnets: - # pop res hidden states - res_hidden_states = res_hidden_states_tuple[-1] - res_hidden_states_tuple = res_hidden_states_tuple[:-1] - hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1) - - hidden_states = resnet(hidden_states, temb) - - hidden_states = self.attentions[0](hidden_states) - - if skip_sample is not None: - skip_sample = self.upsampler(skip_sample) - else: - skip_sample = 0 - - if self.resnet_up is not None: - skip_sample_states = self.skip_norm(hidden_states) - skip_sample_states = self.act(skip_sample_states) - skip_sample_states = self.skip_conv(skip_sample_states) - - skip_sample = skip_sample + skip_sample_states - - hidden_states = self.resnet_up(hidden_states, temb) - - return hidden_states, skip_sample - - -class SkipUpBlock2D(nn.Module): - def __init__( - self, - in_channels: int, - prev_output_channel: int, - out_channels: int, - temb_channels: int, - dropout: float = 0.0, - num_layers: int = 1, - resnet_eps: float = 1e-6, - resnet_time_scale_shift: str = "default", - resnet_act_fn: str = "swish", - resnet_pre_norm: bool = True, - output_scale_factor=np.sqrt(2.0), - add_upsample=True, - upsample_padding=1, - ): - super().__init__() - self.resnets = nn.ModuleList([]) - - for i in range(num_layers): - res_skip_channels = in_channels if (i == num_layers - 1) else out_channels - resnet_in_channels = prev_output_channel if i == 0 else out_channels - - self.resnets.append( - ResnetBlock2D( - in_channels=resnet_in_channels + res_skip_channels, - 
out_channels=out_channels, - temb_channels=temb_channels, - eps=resnet_eps, - groups=min((resnet_in_channels + res_skip_channels) // 4, 32), - groups_out=min(out_channels // 4, 32), - dropout=dropout, - time_embedding_norm=resnet_time_scale_shift, - non_linearity=resnet_act_fn, - output_scale_factor=output_scale_factor, - pre_norm=resnet_pre_norm, - ) - ) - - self.upsampler = FirUpsample2D(in_channels, out_channels=out_channels) - if add_upsample: - self.resnet_up = ResnetBlock2D( - in_channels=out_channels, - out_channels=out_channels, - temb_channels=temb_channels, - eps=resnet_eps, - groups=min(out_channels // 4, 32), - groups_out=min(out_channels // 4, 32), - dropout=dropout, - time_embedding_norm=resnet_time_scale_shift, - non_linearity=resnet_act_fn, - output_scale_factor=output_scale_factor, - pre_norm=resnet_pre_norm, - use_nin_shortcut=True, - up=True, - kernel="fir", - ) - self.skip_conv = nn.Conv2d(out_channels, 3, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) - self.skip_norm = torch.nn.GroupNorm( - num_groups=min(out_channels // 4, 32), num_channels=out_channels, eps=resnet_eps, affine=True - ) - self.act = nn.SiLU() - else: - self.resnet_up = None - self.skip_conv = None - self.skip_norm = None - self.act = None - - def forward(self, hidden_states, res_hidden_states_tuple, temb=None, skip_sample=None): - for resnet in self.resnets: - # pop res hidden states - res_hidden_states = res_hidden_states_tuple[-1] - res_hidden_states_tuple = res_hidden_states_tuple[:-1] - hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1) - - hidden_states = resnet(hidden_states, temb) - - if skip_sample is not None: - skip_sample = self.upsampler(skip_sample) - else: - skip_sample = 0 - - if self.resnet_up is not None: - skip_sample_states = self.skip_norm(hidden_states) - skip_sample_states = self.act(skip_sample_states) - skip_sample_states = self.skip_conv(skip_sample_states) - - skip_sample = skip_sample + skip_sample_states - - hidden_states = self.resnet_up(hidden_states, temb) - - return hidden_states, skip_sample diff --git a/local_diffusers/models/vae.py b/local_diffusers/models/vae.py deleted file mode 100644 index 82748cb5b..000000000 --- a/local_diffusers/models/vae.py +++ /dev/null @@ -1,581 +0,0 @@ -from dataclasses import dataclass -from typing import Optional, Tuple, Union - -import numpy as np -import torch -import torch.nn as nn - -from ..configuration_utils import ConfigMixin, register_to_config -from ..modeling_utils import ModelMixin -from ..utils import BaseOutput -from .unet_blocks import UNetMidBlock2D, get_down_block, get_up_block - - -@dataclass -class DecoderOutput(BaseOutput): - """ - Output of decoding method. - - Args: - sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`): - Decoded output sample of the model. Output of the last layer of the model. - """ - - sample: torch.FloatTensor - - -@dataclass -class VQEncoderOutput(BaseOutput): - """ - Output of VQModel encoding method. - - Args: - latents (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`): - Encoded output sample of the model. Output of the last layer of the model. - """ - - latents: torch.FloatTensor - - -@dataclass -class AutoencoderKLOutput(BaseOutput): - """ - Output of AutoencoderKL encoding method. - - Args: - latent_dist (`DiagonalGaussianDistribution`): - Encoded outputs of `Encoder` represented as the mean and logvar of `DiagonalGaussianDistribution`. 
- `DiagonalGaussianDistribution` allows for sampling latents from the distribution. - """ - - latent_dist: "DiagonalGaussianDistribution" - - -class Encoder(nn.Module): - def __init__( - self, - in_channels=3, - out_channels=3, - down_block_types=("DownEncoderBlock2D",), - block_out_channels=(64,), - layers_per_block=2, - act_fn="silu", - double_z=True, - ): - super().__init__() - self.layers_per_block = layers_per_block - - self.conv_in = torch.nn.Conv2d(in_channels, block_out_channels[0], kernel_size=3, stride=1, padding=1) - - self.mid_block = None - self.down_blocks = nn.ModuleList([]) - - # down - output_channel = block_out_channels[0] - for i, down_block_type in enumerate(down_block_types): - input_channel = output_channel - output_channel = block_out_channels[i] - is_final_block = i == len(block_out_channels) - 1 - - down_block = get_down_block( - down_block_type, - num_layers=self.layers_per_block, - in_channels=input_channel, - out_channels=output_channel, - add_downsample=not is_final_block, - resnet_eps=1e-6, - downsample_padding=0, - resnet_act_fn=act_fn, - attn_num_head_channels=None, - temb_channels=None, - ) - self.down_blocks.append(down_block) - - # mid - self.mid_block = UNetMidBlock2D( - in_channels=block_out_channels[-1], - resnet_eps=1e-6, - resnet_act_fn=act_fn, - output_scale_factor=1, - resnet_time_scale_shift="default", - attn_num_head_channels=None, - resnet_groups=32, - temb_channels=None, - ) - - # out - num_groups_out = 32 - self.conv_norm_out = nn.GroupNorm(num_channels=block_out_channels[-1], num_groups=num_groups_out, eps=1e-6) - self.conv_act = nn.SiLU() - - conv_out_channels = 2 * out_channels if double_z else out_channels - self.conv_out = nn.Conv2d(block_out_channels[-1], conv_out_channels, 3, padding=1) - - def forward(self, x): - sample = x - sample = self.conv_in(sample) - - # down - for down_block in self.down_blocks: - sample = down_block(sample) - - # middle - sample = self.mid_block(sample) - - # post-process - sample = self.conv_norm_out(sample) - sample = self.conv_act(sample) - sample = self.conv_out(sample) - - return sample - - -class Decoder(nn.Module): - def __init__( - self, - in_channels=3, - out_channels=3, - up_block_types=("UpDecoderBlock2D",), - block_out_channels=(64,), - layers_per_block=2, - act_fn="silu", - ): - super().__init__() - self.layers_per_block = layers_per_block - - self.conv_in = nn.Conv2d(in_channels, block_out_channels[-1], kernel_size=3, stride=1, padding=1) - - self.mid_block = None - self.up_blocks = nn.ModuleList([]) - - # mid - self.mid_block = UNetMidBlock2D( - in_channels=block_out_channels[-1], - resnet_eps=1e-6, - resnet_act_fn=act_fn, - output_scale_factor=1, - resnet_time_scale_shift="default", - attn_num_head_channels=None, - resnet_groups=32, - temb_channels=None, - ) - - # up - reversed_block_out_channels = list(reversed(block_out_channels)) - output_channel = reversed_block_out_channels[0] - for i, up_block_type in enumerate(up_block_types): - prev_output_channel = output_channel - output_channel = reversed_block_out_channels[i] - - is_final_block = i == len(block_out_channels) - 1 - - up_block = get_up_block( - up_block_type, - num_layers=self.layers_per_block + 1, - in_channels=prev_output_channel, - out_channels=output_channel, - prev_output_channel=None, - add_upsample=not is_final_block, - resnet_eps=1e-6, - resnet_act_fn=act_fn, - attn_num_head_channels=None, - temb_channels=None, - ) - self.up_blocks.append(up_block) - prev_output_channel = output_channel - - # out - num_groups_out = 32 - 
self.conv_norm_out = nn.GroupNorm(num_channels=block_out_channels[0], num_groups=num_groups_out, eps=1e-6) - self.conv_act = nn.SiLU() - self.conv_out = nn.Conv2d(block_out_channels[0], out_channels, 3, padding=1) - - def forward(self, z): - sample = z - sample = self.conv_in(sample) - - # middle - sample = self.mid_block(sample) - - # up - for up_block in self.up_blocks: - sample = up_block(sample) - - # post-process - sample = self.conv_norm_out(sample) - sample = self.conv_act(sample) - sample = self.conv_out(sample) - - return sample - - -class VectorQuantizer(nn.Module): - """ - Improved version over VectorQuantizer, can be used as a drop-in replacement. Mostly avoids costly matrix - multiplications and allows for post-hoc remapping of indices. - """ - - # NOTE: due to a bug the beta term was applied to the wrong term. for - # backwards compatibility we use the buggy version by default, but you can - # specify legacy=False to fix it. - def __init__(self, n_e, e_dim, beta, remap=None, unknown_index="random", sane_index_shape=False, legacy=True): - super().__init__() - self.n_e = n_e - self.e_dim = e_dim - self.beta = beta - self.legacy = legacy - - self.embedding = nn.Embedding(self.n_e, self.e_dim) - self.embedding.weight.data.uniform_(-1.0 / self.n_e, 1.0 / self.n_e) - - self.remap = remap - if self.remap is not None: - self.register_buffer("used", torch.tensor(np.load(self.remap))) - self.re_embed = self.used.shape[0] - self.unknown_index = unknown_index # "random" or "extra" or integer - if self.unknown_index == "extra": - self.unknown_index = self.re_embed - self.re_embed = self.re_embed + 1 - print( - f"Remapping {self.n_e} indices to {self.re_embed} indices. " - f"Using {self.unknown_index} for unknown indices." - ) - else: - self.re_embed = n_e - - self.sane_index_shape = sane_index_shape - - def remap_to_used(self, inds): - ishape = inds.shape - assert len(ishape) > 1 - inds = inds.reshape(ishape[0], -1) - used = self.used.to(inds) - match = (inds[:, :, None] == used[None, None, ...]).long() - new = match.argmax(-1) - unknown = match.sum(2) < 1 - if self.unknown_index == "random": - new[unknown] = torch.randint(0, self.re_embed, size=new[unknown].shape).to(device=new.device) - else: - new[unknown] = self.unknown_index - return new.reshape(ishape) - - def unmap_to_all(self, inds): - ishape = inds.shape - assert len(ishape) > 1 - inds = inds.reshape(ishape[0], -1) - used = self.used.to(inds) - if self.re_embed > self.used.shape[0]: # extra token - inds[inds >= self.used.shape[0]] = 0 # simply set to zero - back = torch.gather(used[None, :][inds.shape[0] * [0], :], 1, inds) - return back.reshape(ishape) - - def forward(self, z): - # reshape z -> (batch, height, width, channel) and flatten - z = z.permute(0, 2, 3, 1).contiguous() - z_flattened = z.view(-1, self.e_dim) - # distances from z to embeddings e_j (z - e)^2 = z^2 + e^2 - 2 e * z - - d = ( - torch.sum(z_flattened**2, dim=1, keepdim=True) - + torch.sum(self.embedding.weight**2, dim=1) - - 2 * torch.einsum("bd,dn->bn", z_flattened, self.embedding.weight.t()) - ) - - min_encoding_indices = torch.argmin(d, dim=1) - z_q = self.embedding(min_encoding_indices).view(z.shape) - perplexity = None - min_encodings = None - - # compute loss for embedding - if not self.legacy: - loss = self.beta * torch.mean((z_q.detach() - z) ** 2) + torch.mean((z_q - z.detach()) ** 2) - else: - loss = torch.mean((z_q.detach() - z) ** 2) + self.beta * torch.mean((z_q - z.detach()) ** 2) - - # preserve gradients - z_q = z + (z_q - z).detach() - - # 
reshape back to match original input shape - z_q = z_q.permute(0, 3, 1, 2).contiguous() - - if self.remap is not None: - min_encoding_indices = min_encoding_indices.reshape(z.shape[0], -1) # add batch axis - min_encoding_indices = self.remap_to_used(min_encoding_indices) - min_encoding_indices = min_encoding_indices.reshape(-1, 1) # flatten - - if self.sane_index_shape: - min_encoding_indices = min_encoding_indices.reshape(z_q.shape[0], z_q.shape[2], z_q.shape[3]) - - return z_q, loss, (perplexity, min_encodings, min_encoding_indices) - - def get_codebook_entry(self, indices, shape): - # shape specifying (batch, height, width, channel) - if self.remap is not None: - indices = indices.reshape(shape[0], -1) # add batch axis - indices = self.unmap_to_all(indices) - indices = indices.reshape(-1) # flatten again - - # get quantized latent vectors - z_q = self.embedding(indices) - - if shape is not None: - z_q = z_q.view(shape) - # reshape back to match original input shape - z_q = z_q.permute(0, 3, 1, 2).contiguous() - - return z_q - - -class DiagonalGaussianDistribution(object): - def __init__(self, parameters, deterministic=False): - self.parameters = parameters - self.mean, self.logvar = torch.chunk(parameters, 2, dim=1) - self.logvar = torch.clamp(self.logvar, -30.0, 20.0) - self.deterministic = deterministic - self.std = torch.exp(0.5 * self.logvar) - self.var = torch.exp(self.logvar) - if self.deterministic: - self.var = self.std = torch.zeros_like(self.mean).to(device=self.parameters.device) - - def sample(self, generator: Optional[torch.Generator] = None) -> torch.FloatTensor: - device = self.parameters.device - sample_device = "cpu" if device.type == "mps" else device - sample = torch.randn(self.mean.shape, generator=generator, device=sample_device).to(device) - x = self.mean + self.std * sample - return x - - def kl(self, other=None): - if self.deterministic: - return torch.Tensor([0.0]) - else: - if other is None: - return 0.5 * torch.sum(torch.pow(self.mean, 2) + self.var - 1.0 - self.logvar, dim=[1, 2, 3]) - else: - return 0.5 * torch.sum( - torch.pow(self.mean - other.mean, 2) / other.var - + self.var / other.var - - 1.0 - - self.logvar - + other.logvar, - dim=[1, 2, 3], - ) - - def nll(self, sample, dims=[1, 2, 3]): - if self.deterministic: - return torch.Tensor([0.0]) - logtwopi = np.log(2.0 * np.pi) - return 0.5 * torch.sum(logtwopi + self.logvar + torch.pow(sample - self.mean, 2) / self.var, dim=dims) - - def mode(self): - return self.mean - - -class VQModel(ModelMixin, ConfigMixin): - r"""VQ-VAE model from the paper Neural Discrete Representation Learning by Aaron van den Oord, Oriol Vinyals and Koray - Kavukcuoglu. - - This model inherits from [`ModelMixin`]. Check the superclass documentation for the generic methods the library - implements for all the model (such as downloading or saving, etc.) - - Parameters: - in_channels (int, *optional*, defaults to 3): Number of channels in the input image. - out_channels (int, *optional*, defaults to 3): Number of channels in the output. - down_block_types (`Tuple[str]`, *optional*, defaults to : - obj:`("DownEncoderBlock2D",)`): Tuple of downsample block types. - up_block_types (`Tuple[str]`, *optional*, defaults to : - obj:`("UpDecoderBlock2D",)`): Tuple of upsample block types. - block_out_channels (`Tuple[int]`, *optional*, defaults to : - obj:`(64,)`): Tuple of block output channels. - act_fn (`str`, *optional*, defaults to `"silu"`): The activation function to use. 
- latent_channels (`int`, *optional*, defaults to `3`): Number of channels in the latent space. - sample_size (`int`, *optional*, defaults to `32`): TODO - num_vq_embeddings (`int`, *optional*, defaults to `256`): Number of codebook vectors in the VQ-VAE. - """ - - @register_to_config - def __init__( - self, - in_channels: int = 3, - out_channels: int = 3, - down_block_types: Tuple[str] = ("DownEncoderBlock2D",), - up_block_types: Tuple[str] = ("UpDecoderBlock2D",), - block_out_channels: Tuple[int] = (64,), - layers_per_block: int = 1, - act_fn: str = "silu", - latent_channels: int = 3, - sample_size: int = 32, - num_vq_embeddings: int = 256, - ): - super().__init__() - - # pass init params to Encoder - self.encoder = Encoder( - in_channels=in_channels, - out_channels=latent_channels, - down_block_types=down_block_types, - block_out_channels=block_out_channels, - layers_per_block=layers_per_block, - act_fn=act_fn, - double_z=False, - ) - - self.quant_conv = torch.nn.Conv2d(latent_channels, latent_channels, 1) - self.quantize = VectorQuantizer( - num_vq_embeddings, latent_channels, beta=0.25, remap=None, sane_index_shape=False - ) - self.post_quant_conv = torch.nn.Conv2d(latent_channels, latent_channels, 1) - - # pass init params to Decoder - self.decoder = Decoder( - in_channels=latent_channels, - out_channels=out_channels, - up_block_types=up_block_types, - block_out_channels=block_out_channels, - layers_per_block=layers_per_block, - act_fn=act_fn, - ) - - def encode(self, x: torch.FloatTensor, return_dict: bool = True) -> VQEncoderOutput: - h = self.encoder(x) - h = self.quant_conv(h) - - if not return_dict: - return (h,) - - return VQEncoderOutput(latents=h) - - def decode( - self, h: torch.FloatTensor, force_not_quantize: bool = False, return_dict: bool = True - ) -> Union[DecoderOutput, torch.FloatTensor]: - # also go through quantization layer - if not force_not_quantize: - quant, emb_loss, info = self.quantize(h) - else: - quant = h - quant = self.post_quant_conv(quant) - dec = self.decoder(quant) - - if not return_dict: - return (dec,) - - return DecoderOutput(sample=dec) - - def forward(self, sample: torch.FloatTensor, return_dict: bool = True) -> Union[DecoderOutput, torch.FloatTensor]: - r""" - Args: - sample (`torch.FloatTensor`): Input sample. - return_dict (`bool`, *optional*, defaults to `True`): - Whether or not to return a [`DecoderOutput`] instead of a plain tuple. - """ - x = sample - h = self.encode(x).latents - dec = self.decode(h).sample - - if not return_dict: - return (dec,) - - return DecoderOutput(sample=dec) - - -class AutoencoderKL(ModelMixin, ConfigMixin): - r"""Variational Autoencoder (VAE) model with KL loss from the paper Auto-Encoding Variational Bayes by Diederik P. Kingma - and Max Welling. - - This model inherits from [`ModelMixin`]. Check the superclass documentation for the generic methods the library - implements for all the model (such as downloading or saving, etc.) - - Parameters: - in_channels (int, *optional*, defaults to 3): Number of channels in the input image. - out_channels (int, *optional*, defaults to 3): Number of channels in the output. - down_block_types (`Tuple[str]`, *optional*, defaults to : - obj:`("DownEncoderBlock2D",)`): Tuple of downsample block types. - up_block_types (`Tuple[str]`, *optional*, defaults to : - obj:`("UpDecoderBlock2D",)`): Tuple of upsample block types. - block_out_channels (`Tuple[int]`, *optional*, defaults to : - obj:`(64,)`): Tuple of block output channels. 
- act_fn (`str`, *optional*, defaults to `"silu"`): The activation function to use. - latent_channels (`int`, *optional*, defaults to `4`): Number of channels in the latent space. - sample_size (`int`, *optional*, defaults to `32`): TODO - """ - - @register_to_config - def __init__( - self, - in_channels: int = 3, - out_channels: int = 3, - down_block_types: Tuple[str] = ("DownEncoderBlock2D",), - up_block_types: Tuple[str] = ("UpDecoderBlock2D",), - block_out_channels: Tuple[int] = (64,), - layers_per_block: int = 1, - act_fn: str = "silu", - latent_channels: int = 4, - sample_size: int = 32, - ): - super().__init__() - - # pass init params to Encoder - self.encoder = Encoder( - in_channels=in_channels, - out_channels=latent_channels, - down_block_types=down_block_types, - block_out_channels=block_out_channels, - layers_per_block=layers_per_block, - act_fn=act_fn, - double_z=True, - ) - - # pass init params to Decoder - self.decoder = Decoder( - in_channels=latent_channels, - out_channels=out_channels, - up_block_types=up_block_types, - block_out_channels=block_out_channels, - layers_per_block=layers_per_block, - act_fn=act_fn, - ) - - self.quant_conv = torch.nn.Conv2d(2 * latent_channels, 2 * latent_channels, 1) - self.post_quant_conv = torch.nn.Conv2d(latent_channels, latent_channels, 1) - - def encode(self, x: torch.FloatTensor, return_dict: bool = True) -> AutoencoderKLOutput: - h = self.encoder(x) - moments = self.quant_conv(h) - posterior = DiagonalGaussianDistribution(moments) - - if not return_dict: - return (posterior,) - - return AutoencoderKLOutput(latent_dist=posterior) - - def decode(self, z: torch.FloatTensor, return_dict: bool = True) -> Union[DecoderOutput, torch.FloatTensor]: - z = self.post_quant_conv(z) - dec = self.decoder(z) - - if not return_dict: - return (dec,) - - return DecoderOutput(sample=dec) - - def forward( - self, sample: torch.FloatTensor, sample_posterior: bool = False, return_dict: bool = True - ) -> Union[DecoderOutput, torch.FloatTensor]: - r""" - Args: - sample (`torch.FloatTensor`): Input sample. - sample_posterior (`bool`, *optional*, defaults to `False`): - Whether to sample from the posterior. - return_dict (`bool`, *optional*, defaults to `True`): - Whether or not to return a [`DecoderOutput`] instead of a plain tuple. - """ - x = sample - posterior = self.encode(x).latent_dist - if sample_posterior: - z = posterior.sample() - else: - z = posterior.mode() - dec = self.decode(z).sample - - if not return_dict: - return (dec,) - - return DecoderOutput(sample=dec) diff --git a/local_diffusers/onnx_utils.py b/local_diffusers/onnx_utils.py deleted file mode 100644 index e840565dd..000000000 --- a/local_diffusers/onnx_utils.py +++ /dev/null @@ -1,189 +0,0 @@ -# coding=utf-8 -# Copyright 2022 The HuggingFace Inc. team. -# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
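For reference, the AutoencoderKL removed above is driven through its encode/decode API as sketched below. This is a rough usage sketch against the default constructor arguments shown in the deleted code (block_out_channels=(64,), latent_channels=4); the input resolution and the import of the upstream diffusers class, which the deleted local copy mirrors, are assumptions rather than part of this code.

import torch
from diffusers import AutoencoderKL  # upstream class mirrored by the deleted local copy

# Random weights: this only illustrates shapes and the encode -> sample -> decode flow.
vae = AutoencoderKL()
images = torch.randn(1, 3, 64, 64)  # illustrative input resolution

posterior = vae.encode(images).latent_dist   # DiagonalGaussianDistribution
latents = posterior.sample()                 # stochastic latents, here (1, 4, 64, 64)
reconstruction = vae.decode(latents).sample  # back to (1, 3, 64, 64)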
- - -import os -import shutil -from pathlib import Path -from typing import Optional, Union - -import numpy as np - -from huggingface_hub import hf_hub_download - -from .utils import is_onnx_available, logging - - -if is_onnx_available(): - import onnxruntime as ort - - -ONNX_WEIGHTS_NAME = "model.onnx" - - -logger = logging.get_logger(__name__) - - -class OnnxRuntimeModel: - base_model_prefix = "onnx_model" - - def __init__(self, model=None, **kwargs): - logger.info("`diffusers.OnnxRuntimeModel` is experimental and might change in the future.") - self.model = model - self.model_save_dir = kwargs.get("model_save_dir", None) - self.latest_model_name = kwargs.get("latest_model_name", "model.onnx") - - def __call__(self, **kwargs): - inputs = {k: np.array(v) for k, v in kwargs.items()} - return self.model.run(None, inputs) - - @staticmethod - def load_model(path: Union[str, Path], provider=None): - """ - Loads an ONNX Inference session with an ExecutionProvider. Default provider is `CPUExecutionProvider` - - Arguments: - path (`str` or `Path`): - Directory from which to load - provider(`str`, *optional*): - Onnxruntime execution provider to use for loading the model, defaults to `CPUExecutionProvider` - """ - if provider is None: - logger.info("No onnxruntime provider specified, using CPUExecutionProvider") - provider = "CPUExecutionProvider" - - return ort.InferenceSession(path, providers=[provider]) - - def _save_pretrained(self, save_directory: Union[str, Path], file_name: Optional[str] = None, **kwargs): - """ - Save a model and its configuration file to a directory, so that it can be re-loaded using the - [`~optimum.onnxruntime.modeling_ort.ORTModel.from_pretrained`] class method. It will always save the - latest_model_name. - - Arguments: - save_directory (`str` or `Path`): - Directory where to save the model file. - file_name(`str`, *optional*): - Overwrites the default model file name from `"model.onnx"` to `file_name`. This allows you to save the - model with a different name. - """ - model_file_name = file_name if file_name is not None else ONNX_WEIGHTS_NAME - - src_path = self.model_save_dir.joinpath(self.latest_model_name) - dst_path = Path(save_directory).joinpath(model_file_name) - if not src_path.samefile(dst_path): - shutil.copyfile(src_path, dst_path) - - def save_pretrained( - self, - save_directory: Union[str, os.PathLike], - **kwargs, - ): - """ - Save a model to a directory, so that it can be re-loaded using the [`~OnnxModel.from_pretrained`] class - method.: - - Arguments: - save_directory (`str` or `os.PathLike`): - Directory to which to save. Will be created if it doesn't exist. - """ - if os.path.isfile(save_directory): - logger.error(f"Provided path ({save_directory}) should be a directory, not a file") - return - - os.makedirs(save_directory, exist_ok=True) - - # saving model weights/files - self._save_pretrained(save_directory, **kwargs) - - @classmethod - def _from_pretrained( - cls, - model_id: Union[str, Path], - use_auth_token: Optional[Union[bool, str, None]] = None, - revision: Optional[Union[str, None]] = None, - force_download: bool = False, - cache_dir: Optional[str] = None, - file_name: Optional[str] = None, - provider: Optional[str] = None, - **kwargs, - ): - """ - Load a model from a directory or the HF Hub. - - Arguments: - model_id (`str` or `Path`): - Directory from which to load - use_auth_token (`str` or `bool`): - Is needed to load models from a private or gated repository - revision (`str`): - Revision is the specific model version to use. 
It can be a branch name, a tag name, or a commit id - cache_dir (`Union[str, Path]`, *optional*): - Path to a directory in which a downloaded pretrained model configuration should be cached if the - standard cache should not be used. - force_download (`bool`, *optional*, defaults to `False`): - Whether or not to force the (re-)download of the model weights and configuration files, overriding the - cached versions if they exist. - file_name(`str`): - Overwrites the default model file name from `"model.onnx"` to `file_name`. This allows you to load - different model files from the same repository or directory. - provider(`str`): - The ONNX runtime provider, e.g. `CPUExecutionProvider` or `CUDAExecutionProvider`. - kwargs (`Dict`, *optional*): - kwargs will be passed to the model during initialization - """ - model_file_name = file_name if file_name is not None else ONNX_WEIGHTS_NAME - # load model from local directory - if os.path.isdir(model_id): - model = OnnxRuntimeModel.load_model(os.path.join(model_id, model_file_name), provider=provider) - kwargs["model_save_dir"] = Path(model_id) - # load model from hub - else: - # download model - model_cache_path = hf_hub_download( - repo_id=model_id, - filename=model_file_name, - use_auth_token=use_auth_token, - revision=revision, - cache_dir=cache_dir, - force_download=force_download, - ) - kwargs["model_save_dir"] = Path(model_cache_path).parent - kwargs["latest_model_name"] = Path(model_cache_path).name - model = OnnxRuntimeModel.load_model(model_cache_path, provider=provider) - return cls(model=model, **kwargs) - - @classmethod - def from_pretrained( - cls, - model_id: Union[str, Path], - force_download: bool = True, - use_auth_token: Optional[str] = None, - cache_dir: Optional[str] = None, - **model_kwargs, - ): - revision = None - if len(str(model_id).split("@")) == 2: - model_id, revision = model_id.split("@") - - return cls._from_pretrained( - model_id=model_id, - revision=revision, - cache_dir=cache_dir, - force_download=force_download, - use_auth_token=use_auth_token, - **model_kwargs, - ) diff --git a/local_diffusers/optimization.py b/local_diffusers/optimization.py deleted file mode 100644 index e7b836b4a..000000000 --- a/local_diffusers/optimization.py +++ /dev/null @@ -1,275 +0,0 @@ -# coding=utf-8 -# Copyright 2022 The HuggingFace Inc. team. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
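# The OnnxRuntimeModel wrapper deleted above is a thin layer over onnxruntime itself;
# used directly, the same load/run cycle looks like this. "model.onnx", the provider
# choice and the input shape are placeholders for whatever graph you actually export.
import numpy as np
import onnxruntime as ort

session = ort.InferenceSession("model.onnx", providers=["CPUExecutionProvider"])
input_name = session.get_inputs()[0].name                 # the graph's first input
dummy = np.zeros((1, 3, 64, 64), dtype=np.float32)        # shape depends on the exported model
outputs = session.run(None, {input_name: dummy})          # list of output arrays, as in __call__ above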
-"""PyTorch optimization for diffusion models.""" - -import math -from enum import Enum -from typing import Optional, Union - -from torch.optim import Optimizer -from torch.optim.lr_scheduler import LambdaLR - -from .utils import logging - - -logger = logging.get_logger(__name__) - - -class SchedulerType(Enum): - LINEAR = "linear" - COSINE = "cosine" - COSINE_WITH_RESTARTS = "cosine_with_restarts" - POLYNOMIAL = "polynomial" - CONSTANT = "constant" - CONSTANT_WITH_WARMUP = "constant_with_warmup" - - -def get_constant_schedule(optimizer: Optimizer, last_epoch: int = -1): - """ - Create a schedule with a constant learning rate, using the learning rate set in optimizer. - - Args: - optimizer ([`~torch.optim.Optimizer`]): - The optimizer for which to schedule the learning rate. - last_epoch (`int`, *optional*, defaults to -1): - The index of the last epoch when resuming training. - - Return: - `torch.optim.lr_scheduler.LambdaLR` with the appropriate schedule. - """ - return LambdaLR(optimizer, lambda _: 1, last_epoch=last_epoch) - - -def get_constant_schedule_with_warmup(optimizer: Optimizer, num_warmup_steps: int, last_epoch: int = -1): - """ - Create a schedule with a constant learning rate preceded by a warmup period during which the learning rate - increases linearly between 0 and the initial lr set in the optimizer. - - Args: - optimizer ([`~torch.optim.Optimizer`]): - The optimizer for which to schedule the learning rate. - num_warmup_steps (`int`): - The number of steps for the warmup phase. - last_epoch (`int`, *optional*, defaults to -1): - The index of the last epoch when resuming training. - - Return: - `torch.optim.lr_scheduler.LambdaLR` with the appropriate schedule. - """ - - def lr_lambda(current_step: int): - if current_step < num_warmup_steps: - return float(current_step) / float(max(1.0, num_warmup_steps)) - return 1.0 - - return LambdaLR(optimizer, lr_lambda, last_epoch=last_epoch) - - -def get_linear_schedule_with_warmup(optimizer, num_warmup_steps, num_training_steps, last_epoch=-1): - """ - Create a schedule with a learning rate that decreases linearly from the initial lr set in the optimizer to 0, after - a warmup period during which it increases linearly from 0 to the initial lr set in the optimizer. - - Args: - optimizer ([`~torch.optim.Optimizer`]): - The optimizer for which to schedule the learning rate. - num_warmup_steps (`int`): - The number of steps for the warmup phase. - num_training_steps (`int`): - The total number of training steps. - last_epoch (`int`, *optional*, defaults to -1): - The index of the last epoch when resuming training. - - Return: - `torch.optim.lr_scheduler.LambdaLR` with the appropriate schedule. - """ - - def lr_lambda(current_step: int): - if current_step < num_warmup_steps: - return float(current_step) / float(max(1, num_warmup_steps)) - return max( - 0.0, float(num_training_steps - current_step) / float(max(1, num_training_steps - num_warmup_steps)) - ) - - return LambdaLR(optimizer, lr_lambda, last_epoch) - - -def get_cosine_schedule_with_warmup( - optimizer: Optimizer, num_warmup_steps: int, num_training_steps: int, num_cycles: float = 0.5, last_epoch: int = -1 -): - """ - Create a schedule with a learning rate that decreases following the values of the cosine function between the - initial lr set in the optimizer to 0, after a warmup period during which it increases linearly between 0 and the - initial lr set in the optimizer. - - Args: - optimizer ([`~torch.optim.Optimizer`]): - The optimizer for which to schedule the learning rate. 
- num_warmup_steps (`int`): - The number of steps for the warmup phase. - num_training_steps (`int`): - The total number of training steps. - num_cycles (`float`, *optional*, defaults to 0.5): - The number of waves in the cosine schedule (the defaults is to just decrease from the max value to 0 - following a half-cosine). - last_epoch (`int`, *optional*, defaults to -1): - The index of the last epoch when resuming training. - - Return: - `torch.optim.lr_scheduler.LambdaLR` with the appropriate schedule. - """ - - def lr_lambda(current_step): - if current_step < num_warmup_steps: - return float(current_step) / float(max(1, num_warmup_steps)) - progress = float(current_step - num_warmup_steps) / float(max(1, num_training_steps - num_warmup_steps)) - return max(0.0, 0.5 * (1.0 + math.cos(math.pi * float(num_cycles) * 2.0 * progress))) - - return LambdaLR(optimizer, lr_lambda, last_epoch) - - -def get_cosine_with_hard_restarts_schedule_with_warmup( - optimizer: Optimizer, num_warmup_steps: int, num_training_steps: int, num_cycles: int = 1, last_epoch: int = -1 -): - """ - Create a schedule with a learning rate that decreases following the values of the cosine function between the - initial lr set in the optimizer to 0, with several hard restarts, after a warmup period during which it increases - linearly between 0 and the initial lr set in the optimizer. - - Args: - optimizer ([`~torch.optim.Optimizer`]): - The optimizer for which to schedule the learning rate. - num_warmup_steps (`int`): - The number of steps for the warmup phase. - num_training_steps (`int`): - The total number of training steps. - num_cycles (`int`, *optional*, defaults to 1): - The number of hard restarts to use. - last_epoch (`int`, *optional*, defaults to -1): - The index of the last epoch when resuming training. - - Return: - `torch.optim.lr_scheduler.LambdaLR` with the appropriate schedule. - """ - - def lr_lambda(current_step): - if current_step < num_warmup_steps: - return float(current_step) / float(max(1, num_warmup_steps)) - progress = float(current_step - num_warmup_steps) / float(max(1, num_training_steps - num_warmup_steps)) - if progress >= 1.0: - return 0.0 - return max(0.0, 0.5 * (1.0 + math.cos(math.pi * ((float(num_cycles) * progress) % 1.0)))) - - return LambdaLR(optimizer, lr_lambda, last_epoch) - - -def get_polynomial_decay_schedule_with_warmup( - optimizer, num_warmup_steps, num_training_steps, lr_end=1e-7, power=1.0, last_epoch=-1 -): - """ - Create a schedule with a learning rate that decreases as a polynomial decay from the initial lr set in the - optimizer to end lr defined by *lr_end*, after a warmup period during which it increases linearly from 0 to the - initial lr set in the optimizer. - - Args: - optimizer ([`~torch.optim.Optimizer`]): - The optimizer for which to schedule the learning rate. - num_warmup_steps (`int`): - The number of steps for the warmup phase. - num_training_steps (`int`): - The total number of training steps. - lr_end (`float`, *optional*, defaults to 1e-7): - The end LR. - power (`float`, *optional*, defaults to 1.0): - Power factor. - last_epoch (`int`, *optional*, defaults to -1): - The index of the last epoch when resuming training. - - Note: *power* defaults to 1.0 as in the fairseq implementation, which in turn is based on the original BERT - implementation at - https://github.com/google-research/bert/blob/f39e881b169b9d53bea03d2d341b31707a6c052b/optimization.py#L37 - - Return: - `torch.optim.lr_scheduler.LambdaLR` with the appropriate schedule. 
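# A small end-to-end sketch of how the LambdaLR factories in this module are consumed:
# the returned lambda is a multiplier on the optimizer's initial lr, here linear warmup
# followed by linear decay as in get_linear_schedule_with_warmup above. The toy model,
# SGD optimizer and step counts are illustrative assumptions.
import torch
from torch.optim import SGD
from torch.optim.lr_scheduler import LambdaLR

num_warmup_steps, num_training_steps = 10, 100
model = torch.nn.Linear(4, 4)
optimizer = SGD(model.parameters(), lr=1e-3)

def lr_lambda(step: int) -> float:
    if step < num_warmup_steps:
        return step / max(1, num_warmup_steps)                                   # warmup: 0 -> 1
    remaining = num_training_steps - step
    return max(0.0, remaining / max(1, num_training_steps - num_warmup_steps))   # decay: 1 -> 0

scheduler = LambdaLR(optimizer, lr_lambda)
for _ in range(num_training_steps):
    optimizer.step()
    scheduler.step()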
- - """ - - lr_init = optimizer.defaults["lr"] - if not (lr_init > lr_end): - raise ValueError(f"lr_end ({lr_end}) must be be smaller than initial lr ({lr_init})") - - def lr_lambda(current_step: int): - if current_step < num_warmup_steps: - return float(current_step) / float(max(1, num_warmup_steps)) - elif current_step > num_training_steps: - return lr_end / lr_init # as LambdaLR multiplies by lr_init - else: - lr_range = lr_init - lr_end - decay_steps = num_training_steps - num_warmup_steps - pct_remaining = 1 - (current_step - num_warmup_steps) / decay_steps - decay = lr_range * pct_remaining**power + lr_end - return decay / lr_init # as LambdaLR multiplies by lr_init - - return LambdaLR(optimizer, lr_lambda, last_epoch) - - -TYPE_TO_SCHEDULER_FUNCTION = { - SchedulerType.LINEAR: get_linear_schedule_with_warmup, - SchedulerType.COSINE: get_cosine_schedule_with_warmup, - SchedulerType.COSINE_WITH_RESTARTS: get_cosine_with_hard_restarts_schedule_with_warmup, - SchedulerType.POLYNOMIAL: get_polynomial_decay_schedule_with_warmup, - SchedulerType.CONSTANT: get_constant_schedule, - SchedulerType.CONSTANT_WITH_WARMUP: get_constant_schedule_with_warmup, -} - - -def get_scheduler( - name: Union[str, SchedulerType], - optimizer: Optimizer, - num_warmup_steps: Optional[int] = None, - num_training_steps: Optional[int] = None, -): - """ - Unified API to get any scheduler from its name. - - Args: - name (`str` or `SchedulerType`): - The name of the scheduler to use. - optimizer (`torch.optim.Optimizer`): - The optimizer that will be used during training. - num_warmup_steps (`int`, *optional*): - The number of warmup steps to do. This is not required by all schedulers (hence the argument being - optional), the function will raise an error if it's unset and the scheduler type requires it. - num_training_steps (`int``, *optional*): - The number of training steps to do. This is not required by all schedulers (hence the argument being - optional), the function will raise an error if it's unset and the scheduler type requires it. - """ - name = SchedulerType(name) - schedule_func = TYPE_TO_SCHEDULER_FUNCTION[name] - if name == SchedulerType.CONSTANT: - return schedule_func(optimizer) - - # All other schedulers require `num_warmup_steps` - if num_warmup_steps is None: - raise ValueError(f"{name} requires `num_warmup_steps`, please provide that argument.") - - if name == SchedulerType.CONSTANT_WITH_WARMUP: - return schedule_func(optimizer, num_warmup_steps=num_warmup_steps) - - # All other schedulers require `num_training_steps` - if num_training_steps is None: - raise ValueError(f"{name} requires `num_training_steps`, please provide that argument.") - - return schedule_func(optimizer, num_warmup_steps=num_warmup_steps, num_training_steps=num_training_steps) diff --git a/local_diffusers/pipeline_utils.py b/local_diffusers/pipeline_utils.py deleted file mode 100644 index 84ee9e20f..000000000 --- a/local_diffusers/pipeline_utils.py +++ /dev/null @@ -1,417 +0,0 @@ -# coding=utf-8 -# Copyright 2022 The HuggingFace Inc. team. -# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. 
-# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import importlib -import inspect -import os -from dataclasses import dataclass -from typing import List, Optional, Union - -import numpy as np -import torch - -import diffusers -import PIL -from huggingface_hub import snapshot_download -from PIL import Image -from tqdm.auto import tqdm - -from .configuration_utils import ConfigMixin -from .utils import DIFFUSERS_CACHE, BaseOutput, logging - - -INDEX_FILE = "diffusion_pytorch_model.bin" - - -logger = logging.get_logger(__name__) - - -LOADABLE_CLASSES = { - "diffusers": { - "ModelMixin": ["save_pretrained", "from_pretrained"], - "SchedulerMixin": ["save_config", "from_config"], - "DiffusionPipeline": ["save_pretrained", "from_pretrained"], - "OnnxRuntimeModel": ["save_pretrained", "from_pretrained"], - }, - "transformers": { - "PreTrainedTokenizer": ["save_pretrained", "from_pretrained"], - "PreTrainedTokenizerFast": ["save_pretrained", "from_pretrained"], - "PreTrainedModel": ["save_pretrained", "from_pretrained"], - "FeatureExtractionMixin": ["save_pretrained", "from_pretrained"], - }, -} - -ALL_IMPORTABLE_CLASSES = {} -for library in LOADABLE_CLASSES: - ALL_IMPORTABLE_CLASSES.update(LOADABLE_CLASSES[library]) - - -@dataclass -class ImagePipelineOutput(BaseOutput): - """ - Output class for image pipelines. - - Args: - images (`List[PIL.Image.Image]` or `np.ndarray`) - List of denoised PIL images of length `batch_size` or numpy array of shape `(batch_size, height, width, - num_channels)`. PIL images or numpy array present the denoised images of the diffusion pipeline. - """ - - images: Union[List[PIL.Image.Image], np.ndarray] - - -class DiffusionPipeline(ConfigMixin): - r""" - Base class for all models. - - [`DiffusionPipeline`] takes care of storing all components (models, schedulers, processors) for diffusion pipelines - and handles methods for loading, downloading and saving models as well as a few methods common to all pipelines to: - - - move all PyTorch modules to the device of your choice - - enabling/disabling the progress bar for the denoising iteration - - Class attributes: - - - **config_name** ([`str`]) -- name of the config file that will store the class and module names of all - compenents of the diffusion pipeline. - """ - config_name = "model_index.json" - - def register_modules(self, **kwargs): - # import it here to avoid circular import - from diffusers import pipelines - - for name, module in kwargs.items(): - # retrive library - library = module.__module__.split(".")[0] - - # check if the module is a pipeline module - pipeline_dir = module.__module__.split(".")[-2] - path = module.__module__.split(".") - is_pipeline_module = pipeline_dir in path and hasattr(pipelines, pipeline_dir) - - # if library is not in LOADABLE_CLASSES, then it is a custom module. - # Or if it's a pipeline module, then the module is inside the pipeline - # folder so we set the library to module name. 
- if library not in LOADABLE_CLASSES or is_pipeline_module: - library = pipeline_dir - - # retrive class_name - class_name = module.__class__.__name__ - - register_dict = {name: (library, class_name)} - - # save model index config - self.register_to_config(**register_dict) - - # set models - setattr(self, name, module) - - def save_pretrained(self, save_directory: Union[str, os.PathLike]): - """ - Save all variables of the pipeline that can be saved and loaded as well as the pipelines configuration file to - a directory. A pipeline variable can be saved and loaded if its class implements both a save and loading - method. The pipeline can easily be re-loaded using the `[`~DiffusionPipeline.from_pretrained`]` class method. - - Arguments: - save_directory (`str` or `os.PathLike`): - Directory to which to save. Will be created if it doesn't exist. - """ - self.save_config(save_directory) - - model_index_dict = dict(self.config) - model_index_dict.pop("_class_name") - model_index_dict.pop("_diffusers_version") - model_index_dict.pop("_module", None) - - for pipeline_component_name in model_index_dict.keys(): - sub_model = getattr(self, pipeline_component_name) - model_cls = sub_model.__class__ - - save_method_name = None - # search for the model's base class in LOADABLE_CLASSES - for library_name, library_classes in LOADABLE_CLASSES.items(): - library = importlib.import_module(library_name) - for base_class, save_load_methods in library_classes.items(): - class_candidate = getattr(library, base_class) - if issubclass(model_cls, class_candidate): - # if we found a suitable base class in LOADABLE_CLASSES then grab its save method - save_method_name = save_load_methods[0] - break - if save_method_name is not None: - break - - save_method = getattr(sub_model, save_method_name) - save_method(os.path.join(save_directory, pipeline_component_name)) - - def to(self, torch_device: Optional[Union[str, torch.device]] = None): - if torch_device is None: - return self - - module_names, _ = self.extract_init_dict(dict(self.config)) - for name in module_names.keys(): - module = getattr(self, name) - if isinstance(module, torch.nn.Module): - module.to(torch_device) - return self - - @property - def device(self) -> torch.device: - r""" - Returns: - `torch.device`: The torch device on which the pipeline is located. - """ - module_names, _ = self.extract_init_dict(dict(self.config)) - for name in module_names.keys(): - module = getattr(self, name) - if isinstance(module, torch.nn.Module): - return module.device - return torch.device("cpu") - - @classmethod - def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.PathLike]], **kwargs): - r""" - Instantiate a PyTorch diffusion pipeline from pre-trained pipeline weights. - - The pipeline is set in evaluation mode by default using `model.eval()` (Dropout modules are deactivated). - - The warning *Weights from XXX not initialized from pretrained model* means that the weights of XXX do not come - pretrained with the rest of the model. It is up to you to train those weights with a downstream fine-tuning - task. - - The warning *Weights from XXX not used in YYY* means that the layer XXX is not used by YYY, therefore those - weights are discarded. 
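# The register_modules()/to()/save_pretrained() plumbing above is what the following
# pattern relies on. The checkpoint id and target directory are the ones used in the
# from_pretrained() docstring further below; downloading the checkpoint requires the
# usual Hugging Face Hub access.
import torch
from diffusers import DiffusionPipeline

pipe = DiffusionPipeline.from_pretrained("CompVis/ldm-text2im-large-256")
pipe = pipe.to("cuda" if torch.cuda.is_available() else "cpu")   # moves every torch.nn.Module component
print(pipe.device)                                               # reported by the device property above
pipe.save_pretrained("./my_pipeline_directory")                  # one subfolder per registered component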
- - Parameters: - pretrained_model_name_or_path (`str` or `os.PathLike`, *optional*): - Can be either: - - - A string, the *repo id* of a pretrained pipeline hosted inside a model repo on - https://huggingface.co/ Valid repo ids have to be located under a user or organization name, like - `CompVis/ldm-text2im-large-256`. - - A path to a *directory* containing pipeline weights saved using - [`~DiffusionPipeline.save_pretrained`], e.g., `./my_pipeline_directory/`. - torch_dtype (`str` or `torch.dtype`, *optional*): - Override the default `torch.dtype` and load the model under this dtype. If `"auto"` is passed the dtype - will be automatically derived from the model's weights. - force_download (`bool`, *optional*, defaults to `False`): - Whether or not to force the (re-)download of the model weights and configuration files, overriding the - cached versions if they exist. - resume_download (`bool`, *optional*, defaults to `False`): - Whether or not to delete incompletely received files. Will attempt to resume the download if such a - file exists. - proxies (`Dict[str, str]`, *optional*): - A dictionary of proxy servers to use by protocol or endpoint, e.g., `{'http': 'foo.bar:3128', - 'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request. - output_loading_info(`bool`, *optional*, defaults to `False`): - Whether ot not to also return a dictionary containing missing keys, unexpected keys and error messages. - local_files_only(`bool`, *optional*, defaults to `False`): - Whether or not to only look at local files (i.e., do not try to download the model). - use_auth_token (`str` or *bool*, *optional*): - The token to use as HTTP bearer authorization for remote files. If `True`, will use the token generated - when running `huggingface-cli login` (stored in `~/.huggingface`). - revision (`str`, *optional*, defaults to `"main"`): - The specific model version to use. It can be a branch name, a tag name, or a commit id, since we use a - git-based system for storing models and other artifacts on huggingface.co, so `revision` can be any - identifier allowed by git. - mirror (`str`, *optional*): - Mirror source to accelerate downloads in China. If you are from China and have an accessibility - problem, you can set this option to resolve it. Note that we do not guarantee the timeliness or safety. - Please refer to the mirror site for more information. specify the folder name here. - - kwargs (remaining dictionary of keyword arguments, *optional*): - Can be used to overwrite load - and saveable variables - *i.e.* the pipeline components - of the - speficic pipeline class. The overritten components are then directly passed to the pipelines `__init__` - method. See example below for more information. - - - - Passing `use_auth_token=True`` is required when you want to use a private model, *e.g.* - `"CompVis/stable-diffusion-v1-4"` - - - - - - Activate the special ["offline-mode"](https://huggingface.co/diffusers/installation.html#offline-mode) to use - this method in a firewalled environment. - - - - Examples: - - ```py - >>> from diffusers import DiffusionPipeline - - >>> # Download pipeline from huggingface.co and cache. 
- >>> pipeline = DiffusionPipeline.from_pretrained("CompVis/ldm-text2im-large-256") - - >>> # Download pipeline that requires an authorization token - >>> # For more information on access tokens, please refer to this section - >>> # of the documentation](https://huggingface.co/docs/hub/security-tokens) - >>> pipeline = DiffusionPipeline.from_pretrained("CompVis/stable-diffusion-v1-4", use_auth_token=True) - - >>> # Download pipeline, but overwrite scheduler - >>> from diffusers import LMSDiscreteScheduler - - >>> scheduler = LMSDiscreteScheduler(beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear") - >>> pipeline = DiffusionPipeline.from_pretrained( - ... "CompVis/stable-diffusion-v1-4", scheduler=scheduler, use_auth_token=True - ... ) - ``` - """ - cache_dir = kwargs.pop("cache_dir", DIFFUSERS_CACHE) - resume_download = kwargs.pop("resume_download", False) - proxies = kwargs.pop("proxies", None) - local_files_only = kwargs.pop("local_files_only", False) - use_auth_token = kwargs.pop("use_auth_token", None) - revision = kwargs.pop("revision", None) - torch_dtype = kwargs.pop("torch_dtype", None) - provider = kwargs.pop("provider", None) - - # 1. Download the checkpoints and configs - # use snapshot download here to get it working from from_pretrained - if not os.path.isdir(pretrained_model_name_or_path): - cached_folder = snapshot_download( - pretrained_model_name_or_path, - cache_dir=cache_dir, - resume_download=resume_download, - proxies=proxies, - local_files_only=local_files_only, - use_auth_token=use_auth_token, - revision=revision, - ) - else: - cached_folder = pretrained_model_name_or_path - - config_dict = cls.get_config_dict(cached_folder) - - # 2. Load the pipeline class, if using custom module then load it from the hub - # if we load from explicit class, let's use it - if cls != DiffusionPipeline: - pipeline_class = cls - else: - diffusers_module = importlib.import_module(cls.__module__.split(".")[0]) - pipeline_class = getattr(diffusers_module, config_dict["_class_name"]) - - # some modules can be passed directly to the init - # in this case they are already instantiated in `kwargs` - # extract them here - expected_modules = set(inspect.signature(pipeline_class.__init__).parameters.keys()) - passed_class_obj = {k: kwargs.pop(k) for k in expected_modules if k in kwargs} - - init_dict, _ = pipeline_class.extract_init_dict(config_dict, **kwargs) - - init_kwargs = {} - - # import it here to avoid circular import - from diffusers import pipelines - - # 3. Load each module in the pipeline - for name, (library_name, class_name) in init_dict.items(): - is_pipeline_module = hasattr(pipelines, library_name) - loaded_sub_model = None - - # if the model is in a pipeline module, then we load it from the pipeline - if name in passed_class_obj: - # 1. 
check that passed_class_obj has correct parent class - if not is_pipeline_module: - library = importlib.import_module(library_name) - class_obj = getattr(library, class_name) - importable_classes = LOADABLE_CLASSES[library_name] - class_candidates = {c: getattr(library, c) for c in importable_classes.keys()} - - expected_class_obj = None - for class_name, class_candidate in class_candidates.items(): - if issubclass(class_obj, class_candidate): - expected_class_obj = class_candidate - - if not issubclass(passed_class_obj[name].__class__, expected_class_obj): - raise ValueError( - f"{passed_class_obj[name]} is of type: {type(passed_class_obj[name])}, but should be" - f" {expected_class_obj}" - ) - else: - logger.warn( - f"You have passed a non-standard module {passed_class_obj[name]}. We cannot verify whether it" - " has the correct type" - ) - - # set passed class object - loaded_sub_model = passed_class_obj[name] - elif is_pipeline_module: - pipeline_module = getattr(pipelines, library_name) - class_obj = getattr(pipeline_module, class_name) - importable_classes = ALL_IMPORTABLE_CLASSES - class_candidates = {c: class_obj for c in importable_classes.keys()} - else: - # else we just import it from the library. - library = importlib.import_module(library_name) - class_obj = getattr(library, class_name) - importable_classes = LOADABLE_CLASSES[library_name] - class_candidates = {c: getattr(library, c) for c in importable_classes.keys()} - - if loaded_sub_model is None: - load_method_name = None - for class_name, class_candidate in class_candidates.items(): - if issubclass(class_obj, class_candidate): - load_method_name = importable_classes[class_name][1] - - load_method = getattr(class_obj, load_method_name) - - loading_kwargs = {} - if issubclass(class_obj, torch.nn.Module): - loading_kwargs["torch_dtype"] = torch_dtype - if issubclass(class_obj, diffusers.OnnxRuntimeModel): - loading_kwargs["provider"] = provider - - # check if the module is in a subdirectory - if os.path.isdir(os.path.join(cached_folder, name)): - loaded_sub_model = load_method(os.path.join(cached_folder, name), **loading_kwargs) - else: - # else load from the root directory - loaded_sub_model = load_method(cached_folder, **loading_kwargs) - - init_kwargs[name] = loaded_sub_model # UNet(...), # DiffusionSchedule(...) - - # 4. Instantiate the pipeline - model = pipeline_class(**init_kwargs) - return model - - @staticmethod - def numpy_to_pil(images): - """ - Convert a numpy image or a batch of images to a PIL image. - """ - if images.ndim == 3: - images = images[None, ...] - images = (images * 255).round().astype("uint8") - pil_images = [Image.fromarray(image) for image in images] - - return pil_images - - def progress_bar(self, iterable): - if not hasattr(self, "_progress_bar_config"): - self._progress_bar_config = {} - elif not isinstance(self._progress_bar_config, dict): - raise ValueError( - f"`self._progress_bar_config` should be of type `dict`, but is {type(self._progress_bar_config)}." 
- ) - - return tqdm(iterable, **self._progress_bar_config) - - def set_progress_bar_config(self, **kwargs): - self._progress_bar_config = kwargs diff --git a/local_diffusers/pipelines/__init__.py b/local_diffusers/pipelines/__init__.py deleted file mode 100644 index 3e2aeb4fb..000000000 --- a/local_diffusers/pipelines/__init__.py +++ /dev/null @@ -1,19 +0,0 @@ -from ..utils import is_onnx_available, is_transformers_available -from .ddim import DDIMPipeline -from .ddpm import DDPMPipeline -from .latent_diffusion_uncond import LDMPipeline -from .pndm import PNDMPipeline -from .score_sde_ve import ScoreSdeVePipeline -from .stochastic_karras_ve import KarrasVePipeline - - -if is_transformers_available(): - from .latent_diffusion import LDMTextToImagePipeline - from .stable_diffusion import ( - StableDiffusionImg2ImgPipeline, - StableDiffusionInpaintPipeline, - StableDiffusionPipeline, - ) - -if is_transformers_available() and is_onnx_available(): - from .stable_diffusion import StableDiffusionOnnxPipeline diff --git a/local_diffusers/pipelines/ddim/__init__.py b/local_diffusers/pipelines/ddim/__init__.py deleted file mode 100644 index 8fd31868a..000000000 --- a/local_diffusers/pipelines/ddim/__init__.py +++ /dev/null @@ -1,2 +0,0 @@ -# flake8: noqa -from .pipeline_ddim import DDIMPipeline diff --git a/local_diffusers/pipelines/ddim/pipeline_ddim.py b/local_diffusers/pipelines/ddim/pipeline_ddim.py deleted file mode 100644 index 33f6064db..000000000 --- a/local_diffusers/pipelines/ddim/pipeline_ddim.py +++ /dev/null @@ -1,117 +0,0 @@ -# Copyright 2022 The HuggingFace Team. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and - -# limitations under the License. - - -import warnings -from typing import Optional, Tuple, Union - -import torch - -from ...pipeline_utils import DiffusionPipeline, ImagePipelineOutput - - -class DDIMPipeline(DiffusionPipeline): - r""" - This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the - library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.) - - Parameters: - unet ([`UNet2DModel`]): U-Net architecture to denoise the encoded image. - scheduler ([`SchedulerMixin`]): - A scheduler to be used in combination with `unet` to denoise the encoded image. Can be one of - [`DDPMScheduler`], or [`DDIMScheduler`]. - """ - - def __init__(self, unet, scheduler): - super().__init__() - scheduler = scheduler.set_format("pt") - self.register_modules(unet=unet, scheduler=scheduler) - - @torch.no_grad() - def __call__( - self, - batch_size: int = 1, - generator: Optional[torch.Generator] = None, - eta: float = 0.0, - num_inference_steps: int = 50, - output_type: Optional[str] = "pil", - return_dict: bool = True, - **kwargs, - ) -> Union[ImagePipelineOutput, Tuple]: - r""" - Args: - batch_size (`int`, *optional*, defaults to 1): - The number of images to generate. 
- generator (`torch.Generator`, *optional*): - A [torch generator](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make generation - deterministic. - eta (`float`, *optional*, defaults to 0.0): - The eta parameter which controls the scale of the variance (0 is DDIM and 1 is one type of DDPM). - num_inference_steps (`int`, *optional*, defaults to 50): - The number of denoising steps. More denoising steps usually lead to a higher quality image at the - expense of slower inference. - output_type (`str`, *optional*, defaults to `"pil"`): - The output format of the generate image. Choose between - [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `nd.array`. - return_dict (`bool`, *optional*, defaults to `True`): - Whether or not to return a [`~pipeline_utils.ImagePipelineOutput`] instead of a plain tuple. - - Returns: - [`~pipeline_utils.ImagePipelineOutput`] or `tuple`: [`~pipelines.utils.ImagePipelineOutput`] if - `return_dict` is True, otherwise a `tuple. When returning a tuple, the first element is a list with the - generated images. - """ - - if "torch_device" in kwargs: - device = kwargs.pop("torch_device") - warnings.warn( - "`torch_device` is deprecated as an input argument to `__call__` and will be removed in v0.3.0." - " Consider using `pipe.to(torch_device)` instead." - ) - - # Set device as before (to be removed in 0.3.0) - if device is None: - device = "cuda" if torch.cuda.is_available() else "cpu" - self.to(device) - - # eta corresponds to η in paper and should be between [0, 1] - - # Sample gaussian noise to begin loop - image = torch.randn( - (batch_size, self.unet.in_channels, self.unet.sample_size, self.unet.sample_size), - generator=generator, - ) - image = image.to(self.device) - - # set step values - self.scheduler.set_timesteps(num_inference_steps) - - for t in self.progress_bar(self.scheduler.timesteps): - # 1. predict noise model_output - model_output = self.unet(image, t).sample - - # 2. predict previous mean of image x_t-1 and add variance depending on eta - # do x_t -> x_t-1 - image = self.scheduler.step(model_output, t, image, eta).prev_sample - - image = (image / 2 + 0.5).clamp(0, 1) - image = image.cpu().permute(0, 2, 3, 1).numpy() - if output_type == "pil": - image = self.numpy_to_pil(image) - - if not return_dict: - return (image,) - - return ImagePipelineOutput(images=image) diff --git a/local_diffusers/pipelines/ddpm/__init__.py b/local_diffusers/pipelines/ddpm/__init__.py deleted file mode 100644 index 8889bdae1..000000000 --- a/local_diffusers/pipelines/ddpm/__init__.py +++ /dev/null @@ -1,2 +0,0 @@ -# flake8: noqa -from .pipeline_ddpm import DDPMPipeline diff --git a/local_diffusers/pipelines/ddpm/pipeline_ddpm.py b/local_diffusers/pipelines/ddpm/pipeline_ddpm.py deleted file mode 100644 index 71103bbe4..000000000 --- a/local_diffusers/pipelines/ddpm/pipeline_ddpm.py +++ /dev/null @@ -1,106 +0,0 @@ -# Copyright 2022 The HuggingFace Team. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and - -# limitations under the License. 
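# End-to-end use of the DDIMPipeline whose __call__ is defined above, including the
# progress-bar configuration hook from pipeline_utils. The checkpoint id is an assumed
# example of an unconditional diffusion checkpoint, not something this patch specifies.
import torch
from diffusers import DDIMPipeline

pipe = DDIMPipeline.from_pretrained("google/ddpm-cifar10-32")     # assumed checkpoint id
pipe = pipe.to("cuda" if torch.cuda.is_available() else "cpu")
pipe.set_progress_bar_config(disable=True)                        # kwargs are forwarded to tqdm

generator = torch.Generator().manual_seed(0)                      # deterministic noise, as documented
result = pipe(batch_size=2, num_inference_steps=50, eta=0.0, generator=generator)
result.images[0].save("ddim_sample.png")                          # ImagePipelineOutput.images: list of PIL images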
- - -import warnings -from typing import Optional, Tuple, Union - -import torch - -from ...pipeline_utils import DiffusionPipeline, ImagePipelineOutput - - -class DDPMPipeline(DiffusionPipeline): - r""" - This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the - library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.) - - Parameters: - unet ([`UNet2DModel`]): U-Net architecture to denoise the encoded image. - scheduler ([`SchedulerMixin`]): - A scheduler to be used in combination with `unet` to denoise the encoded image. Can be one of - [`DDPMScheduler`], or [`DDIMScheduler`]. - """ - - def __init__(self, unet, scheduler): - super().__init__() - scheduler = scheduler.set_format("pt") - self.register_modules(unet=unet, scheduler=scheduler) - - @torch.no_grad() - def __call__( - self, - batch_size: int = 1, - generator: Optional[torch.Generator] = None, - output_type: Optional[str] = "pil", - return_dict: bool = True, - **kwargs, - ) -> Union[ImagePipelineOutput, Tuple]: - r""" - Args: - batch_size (`int`, *optional*, defaults to 1): - The number of images to generate. - generator (`torch.Generator`, *optional*): - A [torch generator](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make generation - deterministic. - output_type (`str`, *optional*, defaults to `"pil"`): - The output format of the generate image. Choose between - [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `nd.array`. - return_dict (`bool`, *optional*, defaults to `True`): - Whether or not to return a [`~pipeline_utils.ImagePipelineOutput`] instead of a plain tuple. - - Returns: - [`~pipeline_utils.ImagePipelineOutput`] or `tuple`: [`~pipelines.utils.ImagePipelineOutput`] if - `return_dict` is True, otherwise a `tuple. When returning a tuple, the first element is a list with the - generated images. - """ - if "torch_device" in kwargs: - device = kwargs.pop("torch_device") - warnings.warn( - "`torch_device` is deprecated as an input argument to `__call__` and will be removed in v0.3.0." - " Consider using `pipe.to(torch_device)` instead." - ) - - # Set device as before (to be removed in 0.3.0) - if device is None: - device = "cuda" if torch.cuda.is_available() else "cpu" - self.to(device) - - # Sample gaussian noise to begin loop - image = torch.randn( - (batch_size, self.unet.in_channels, self.unet.sample_size, self.unet.sample_size), - generator=generator, - ) - image = image.to(self.device) - - # set step values - self.scheduler.set_timesteps(1000) - - for t in self.progress_bar(self.scheduler.timesteps): - # 1. predict noise model_output - model_output = self.unet(image, t).sample - - # 2. 
compute previous image: x_t -> t_t-1 - image = self.scheduler.step(model_output, t, image, generator=generator).prev_sample - - image = (image / 2 + 0.5).clamp(0, 1) - image = image.cpu().permute(0, 2, 3, 1).numpy() - if output_type == "pil": - image = self.numpy_to_pil(image) - - if not return_dict: - return (image,) - - return ImagePipelineOutput(images=image) diff --git a/local_diffusers/pipelines/latent_diffusion/__init__.py b/local_diffusers/pipelines/latent_diffusion/__init__.py deleted file mode 100644 index c481b38cf..000000000 --- a/local_diffusers/pipelines/latent_diffusion/__init__.py +++ /dev/null @@ -1,6 +0,0 @@ -# flake8: noqa -from ...utils import is_transformers_available - - -if is_transformers_available(): - from .pipeline_latent_diffusion import LDMBertModel, LDMTextToImagePipeline diff --git a/local_diffusers/pipelines/latent_diffusion/pipeline_latent_diffusion.py b/local_diffusers/pipelines/latent_diffusion/pipeline_latent_diffusion.py deleted file mode 100644 index b39840f24..000000000 --- a/local_diffusers/pipelines/latent_diffusion/pipeline_latent_diffusion.py +++ /dev/null @@ -1,705 +0,0 @@ -import inspect -import warnings -from typing import List, Optional, Tuple, Union - -import torch -import torch.nn as nn -import torch.utils.checkpoint - -from transformers.activations import ACT2FN -from transformers.configuration_utils import PretrainedConfig -from transformers.modeling_outputs import BaseModelOutput -from transformers.modeling_utils import PreTrainedModel -from transformers.tokenization_utils import PreTrainedTokenizer -from transformers.utils import logging - -from ...models import AutoencoderKL, UNet2DConditionModel, UNet2DModel, VQModel -from ...pipeline_utils import DiffusionPipeline, ImagePipelineOutput -from ...schedulers import DDIMScheduler, LMSDiscreteScheduler, PNDMScheduler - - -class LDMTextToImagePipeline(DiffusionPipeline): - r""" - This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the - library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.) - - Parameters: - vqvae ([`VQModel`]): - Vector-quantized (VQ) Model to encode and decode images to and from latent representations. - bert ([`LDMBertModel`]): - Text-encoder model based on [BERT](ttps://huggingface.co/docs/transformers/model_doc/bert) architecture. - tokenizer (`transformers.BertTokenizer`): - Tokenizer of class - [BertTokenizer](https://huggingface.co/docs/transformers/model_doc/bert#transformers.BertTokenizer). - unet ([`UNet2DConditionModel`]): Conditional U-Net architecture to denoise the encoded image latents. - scheduler ([`SchedulerMixin`]): - A scheduler to be used in combination with `unet` to denoise the encoded image latens. Can be one of - [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`]. 
- """ - - def __init__( - self, - vqvae: Union[VQModel, AutoencoderKL], - bert: PreTrainedModel, - tokenizer: PreTrainedTokenizer, - unet: Union[UNet2DModel, UNet2DConditionModel], - scheduler: Union[DDIMScheduler, PNDMScheduler, LMSDiscreteScheduler], - ): - super().__init__() - scheduler = scheduler.set_format("pt") - self.register_modules(vqvae=vqvae, bert=bert, tokenizer=tokenizer, unet=unet, scheduler=scheduler) - - @torch.no_grad() - def __call__( - self, - prompt: Union[str, List[str]], - height: Optional[int] = 256, - width: Optional[int] = 256, - num_inference_steps: Optional[int] = 50, - guidance_scale: Optional[float] = 1.0, - eta: Optional[float] = 0.0, - generator: Optional[torch.Generator] = None, - output_type: Optional[str] = "pil", - return_dict: bool = True, - **kwargs, - ) -> Union[Tuple, ImagePipelineOutput]: - r""" - Args: - prompt (`str` or `List[str]`): - The prompt or prompts to guide the image generation. - height (`int`, *optional*, defaults to 256): - The height in pixels of the generated image. - width (`int`, *optional*, defaults to 256): - The width in pixels of the generated image. - num_inference_steps (`int`, *optional*, defaults to 50): - The number of denoising steps. More denoising steps usually lead to a higher quality image at the - expense of slower inference. - guidance_scale (`float`, *optional*, defaults to 1.0): - Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598). - `guidance_scale` is defined as `w` of equation 2. of [Imagen - Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale > - 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt` at - the, usually at the expense of lower image quality. - generator (`torch.Generator`, *optional*): - A [torch generator](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make generation - deterministic. - output_type (`str`, *optional*, defaults to `"pil"`): - The output format of the generate image. Choose between - [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `nd.array`. - return_dict (`bool`, *optional*): - Whether or not to return a [`~pipeline_utils.ImagePipelineOutput`] instead of a plain tuple. - - Returns: - [`~pipeline_utils.ImagePipelineOutput`] or `tuple`: [`~pipelines.utils.ImagePipelineOutput`] if - `return_dict` is True, otherwise a `tuple. When returning a tuple, the first element is a list with the - generated images. - """ - if "torch_device" in kwargs: - device = kwargs.pop("torch_device") - warnings.warn( - "`torch_device` is deprecated as an input argument to `__call__` and will be removed in v0.3.0." - " Consider using `pipe.to(torch_device)` instead." 
- ) - - # Set device as before (to be removed in 0.3.0) - if device is None: - device = "cuda" if torch.cuda.is_available() else "cpu" - self.to(device) - - if isinstance(prompt, str): - batch_size = 1 - elif isinstance(prompt, list): - batch_size = len(prompt) - else: - raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") - - if height % 8 != 0 or width % 8 != 0: - raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.") - - # get unconditional embeddings for classifier free guidance - if guidance_scale != 1.0: - uncond_input = self.tokenizer([""] * batch_size, padding="max_length", max_length=77, return_tensors="pt") - uncond_embeddings = self.bert(uncond_input.input_ids.to(self.device))[0] - - # get prompt text embeddings - text_input = self.tokenizer(prompt, padding="max_length", max_length=77, return_tensors="pt") - text_embeddings = self.bert(text_input.input_ids.to(self.device))[0] - - latents = torch.randn( - (batch_size, self.unet.in_channels, height // 8, width // 8), - generator=generator, - ) - latents = latents.to(self.device) - - self.scheduler.set_timesteps(num_inference_steps) - - # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature - accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys()) - - extra_kwargs = {} - if accepts_eta: - extra_kwargs["eta"] = eta - - for t in self.progress_bar(self.scheduler.timesteps): - if guidance_scale == 1.0: - # guidance_scale of 1 means no guidance - latents_input = latents - context = text_embeddings - else: - # For classifier free guidance, we need to do two forward passes. - # Here we concatenate the unconditional and text embeddings into a single batch - # to avoid doing two forward passes - latents_input = torch.cat([latents] * 2) - context = torch.cat([uncond_embeddings, text_embeddings]) - - # predict the noise residual - noise_pred = self.unet(latents_input, t, encoder_hidden_states=context).sample - # perform guidance - if guidance_scale != 1.0: - noise_pred_uncond, noise_prediction_text = noise_pred.chunk(2) - noise_pred = noise_pred_uncond + guidance_scale * (noise_prediction_text - noise_pred_uncond) - - # compute the previous noisy sample x_t -> x_t-1 - latents = self.scheduler.step(noise_pred, t, latents, **extra_kwargs).prev_sample - - # scale and decode the image latents with vae - latents = 1 / 0.18215 * latents - image = self.vqvae.decode(latents).sample - - image = (image / 2 + 0.5).clamp(0, 1) - image = image.cpu().permute(0, 2, 3, 1).numpy() - if output_type == "pil": - image = self.numpy_to_pil(image) - - if not return_dict: - return (image,) - - return ImagePipelineOutput(images=image) - - -################################################################################ -# Code for the text transformer model -################################################################################ -""" PyTorch LDMBERT model.""" - - -logger = logging.get_logger(__name__) - -LDMBERT_PRETRAINED_MODEL_ARCHIVE_LIST = [ - "ldm-bert", - # See all LDMBert models at https://huggingface.co/models?filter=ldmbert -] - - -LDMBERT_PRETRAINED_CONFIG_ARCHIVE_MAP = { - "ldm-bert": "https://huggingface.co/ldm-bert/resolve/main/config.json", -} - - -""" LDMBERT model configuration""" - - -class LDMBertConfig(PretrainedConfig): - model_type = "ldmbert" - keys_to_ignore_at_inference = ["past_key_values"] - attribute_map = {"num_attention_heads": "encoder_attention_heads", "hidden_size": "d_model"} - - 
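# The classifier-free guidance step inside LDMTextToImagePipeline.__call__ above,
# isolated: the unconditional and text-conditioned noise predictions are mixed as
# eps = eps_uncond + guidance_scale * (eps_text - eps_uncond). The random tensors
# stand in for real UNet outputs and are purely illustrative.
import torch

guidance_scale = 7.5
noise_pred_uncond = torch.randn(1, 4, 32, 32)     # UNet output for the empty prompt
noise_pred_text = torch.randn(1, 4, 32, 32)       # UNet output for the actual prompt
noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)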
def __init__( - self, - vocab_size=30522, - max_position_embeddings=77, - encoder_layers=32, - encoder_ffn_dim=5120, - encoder_attention_heads=8, - head_dim=64, - encoder_layerdrop=0.0, - activation_function="gelu", - d_model=1280, - dropout=0.1, - attention_dropout=0.0, - activation_dropout=0.0, - init_std=0.02, - classifier_dropout=0.0, - scale_embedding=False, - use_cache=True, - pad_token_id=0, - **kwargs, - ): - self.vocab_size = vocab_size - self.max_position_embeddings = max_position_embeddings - self.d_model = d_model - self.encoder_ffn_dim = encoder_ffn_dim - self.encoder_layers = encoder_layers - self.encoder_attention_heads = encoder_attention_heads - self.head_dim = head_dim - self.dropout = dropout - self.attention_dropout = attention_dropout - self.activation_dropout = activation_dropout - self.activation_function = activation_function - self.init_std = init_std - self.encoder_layerdrop = encoder_layerdrop - self.classifier_dropout = classifier_dropout - self.use_cache = use_cache - self.num_hidden_layers = encoder_layers - self.scale_embedding = scale_embedding # scale factor will be sqrt(d_model) if True - - super().__init__(pad_token_id=pad_token_id, **kwargs) - - -def _expand_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int] = None): - """ - Expands attention_mask from `[bsz, seq_len]` to `[bsz, 1, tgt_seq_len, src_seq_len]`. - """ - bsz, src_len = mask.size() - tgt_len = tgt_len if tgt_len is not None else src_len - - expanded_mask = mask[:, None, None, :].expand(bsz, 1, tgt_len, src_len).to(dtype) - - inverted_mask = 1.0 - expanded_mask - - return inverted_mask.masked_fill(inverted_mask.to(torch.bool), torch.finfo(dtype).min) - - -# Copied from transformers.models.bart.modeling_bart.BartAttention with Bart->LDMBert -class LDMBertAttention(nn.Module): - """Multi-headed attention from 'Attention Is All You Need' paper""" - - def __init__( - self, - embed_dim: int, - num_heads: int, - head_dim: int, - dropout: float = 0.0, - is_decoder: bool = False, - bias: bool = False, - ): - super().__init__() - self.embed_dim = embed_dim - self.num_heads = num_heads - self.dropout = dropout - self.head_dim = head_dim - self.inner_dim = head_dim * num_heads - - self.scaling = self.head_dim**-0.5 - self.is_decoder = is_decoder - - self.k_proj = nn.Linear(embed_dim, self.inner_dim, bias=bias) - self.v_proj = nn.Linear(embed_dim, self.inner_dim, bias=bias) - self.q_proj = nn.Linear(embed_dim, self.inner_dim, bias=bias) - self.out_proj = nn.Linear(self.inner_dim, embed_dim) - - def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int): - return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous() - - def forward( - self, - hidden_states: torch.Tensor, - key_value_states: Optional[torch.Tensor] = None, - past_key_value: Optional[Tuple[torch.Tensor]] = None, - attention_mask: Optional[torch.Tensor] = None, - layer_head_mask: Optional[torch.Tensor] = None, - output_attentions: bool = False, - ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: - """Input shape: Batch x Time x Channel""" - - # if key_value_states are provided this layer is used as a cross-attention layer - # for the decoder - is_cross_attention = key_value_states is not None - - bsz, tgt_len, _ = hidden_states.size() - - # get query proj - query_states = self.q_proj(hidden_states) * self.scaling - # get key, value proj - if is_cross_attention and past_key_value is not None: - # reuse k,v, cross_attentions - key_states = past_key_value[0] - 
value_states = past_key_value[1] - elif is_cross_attention: - # cross_attentions - key_states = self._shape(self.k_proj(key_value_states), -1, bsz) - value_states = self._shape(self.v_proj(key_value_states), -1, bsz) - elif past_key_value is not None: - # reuse k, v, self_attention - key_states = self._shape(self.k_proj(hidden_states), -1, bsz) - value_states = self._shape(self.v_proj(hidden_states), -1, bsz) - key_states = torch.cat([past_key_value[0], key_states], dim=2) - value_states = torch.cat([past_key_value[1], value_states], dim=2) - else: - # self_attention - key_states = self._shape(self.k_proj(hidden_states), -1, bsz) - value_states = self._shape(self.v_proj(hidden_states), -1, bsz) - - if self.is_decoder: - # if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states. - # Further calls to cross_attention layer can then reuse all cross-attention - # key/value_states (first "if" case) - # if uni-directional self-attention (decoder) save Tuple(torch.Tensor, torch.Tensor) of - # all previous decoder key/value_states. Further calls to uni-directional self-attention - # can concat previous decoder key/value_states to current projected key/value_states (third "elif" case) - # if encoder bi-directional self-attention `past_key_value` is always `None` - past_key_value = (key_states, value_states) - - proj_shape = (bsz * self.num_heads, -1, self.head_dim) - query_states = self._shape(query_states, tgt_len, bsz).view(*proj_shape) - key_states = key_states.view(*proj_shape) - value_states = value_states.view(*proj_shape) - - src_len = key_states.size(1) - attn_weights = torch.bmm(query_states, key_states.transpose(1, 2)) - - if attn_weights.size() != (bsz * self.num_heads, tgt_len, src_len): - raise ValueError( - f"Attention weights should be of size {(bsz * self.num_heads, tgt_len, src_len)}, but is" - f" {attn_weights.size()}" - ) - - if attention_mask is not None: - if attention_mask.size() != (bsz, 1, tgt_len, src_len): - raise ValueError( - f"Attention mask should be of size {(bsz, 1, tgt_len, src_len)}, but is {attention_mask.size()}" - ) - attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len) + attention_mask - attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len) - - attn_weights = nn.functional.softmax(attn_weights, dim=-1) - - if layer_head_mask is not None: - if layer_head_mask.size() != (self.num_heads,): - raise ValueError( - f"Head mask for a single layer should be of size {(self.num_heads,)}, but is" - f" {layer_head_mask.size()}" - ) - attn_weights = layer_head_mask.view(1, -1, 1, 1) * attn_weights.view(bsz, self.num_heads, tgt_len, src_len) - attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len) - - if output_attentions: - # this operation is a bit awkward, but it's required to - # make sure that attn_weights keeps its gradient. 
- # In order to do so, attn_weights have to be reshaped - # twice and have to be reused in the following - attn_weights_reshaped = attn_weights.view(bsz, self.num_heads, tgt_len, src_len) - attn_weights = attn_weights_reshaped.view(bsz * self.num_heads, tgt_len, src_len) - else: - attn_weights_reshaped = None - - attn_probs = nn.functional.dropout(attn_weights, p=self.dropout, training=self.training) - - attn_output = torch.bmm(attn_probs, value_states) - - if attn_output.size() != (bsz * self.num_heads, tgt_len, self.head_dim): - raise ValueError( - f"`attn_output` should be of size {(bsz, self.num_heads, tgt_len, self.head_dim)}, but is" - f" {attn_output.size()}" - ) - - attn_output = attn_output.view(bsz, self.num_heads, tgt_len, self.head_dim) - attn_output = attn_output.transpose(1, 2) - - # Use the `embed_dim` from the config (stored in the class) rather than `hidden_state` because `attn_output` can be - # partitioned aross GPUs when using tensor-parallelism. - attn_output = attn_output.reshape(bsz, tgt_len, self.inner_dim) - - attn_output = self.out_proj(attn_output) - - return attn_output, attn_weights_reshaped, past_key_value - - -class LDMBertEncoderLayer(nn.Module): - def __init__(self, config: LDMBertConfig): - super().__init__() - self.embed_dim = config.d_model - self.self_attn = LDMBertAttention( - embed_dim=self.embed_dim, - num_heads=config.encoder_attention_heads, - head_dim=config.head_dim, - dropout=config.attention_dropout, - ) - self.self_attn_layer_norm = nn.LayerNorm(self.embed_dim) - self.dropout = config.dropout - self.activation_fn = ACT2FN[config.activation_function] - self.activation_dropout = config.activation_dropout - self.fc1 = nn.Linear(self.embed_dim, config.encoder_ffn_dim) - self.fc2 = nn.Linear(config.encoder_ffn_dim, self.embed_dim) - self.final_layer_norm = nn.LayerNorm(self.embed_dim) - - def forward( - self, - hidden_states: torch.FloatTensor, - attention_mask: torch.FloatTensor, - layer_head_mask: torch.FloatTensor, - output_attentions: Optional[bool] = False, - ) -> Tuple[torch.FloatTensor, Optional[torch.FloatTensor]]: - """ - Args: - hidden_states (`torch.FloatTensor`): input to the layer of shape `(seq_len, batch, embed_dim)` - attention_mask (`torch.FloatTensor`): attention mask of size - `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values. - layer_head_mask (`torch.FloatTensor`): mask for attention heads in a given layer of size - `(encoder_attention_heads,)`. - output_attentions (`bool`, *optional*): - Whether or not to return the attentions tensors of all attention layers. See `attentions` under - returned tensors for more detail. 
- """ - residual = hidden_states - hidden_states = self.self_attn_layer_norm(hidden_states) - hidden_states, attn_weights, _ = self.self_attn( - hidden_states=hidden_states, - attention_mask=attention_mask, - layer_head_mask=layer_head_mask, - output_attentions=output_attentions, - ) - hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training) - hidden_states = residual + hidden_states - - residual = hidden_states - hidden_states = self.final_layer_norm(hidden_states) - hidden_states = self.activation_fn(self.fc1(hidden_states)) - hidden_states = nn.functional.dropout(hidden_states, p=self.activation_dropout, training=self.training) - hidden_states = self.fc2(hidden_states) - hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training) - hidden_states = residual + hidden_states - - if hidden_states.dtype == torch.float16 and ( - torch.isinf(hidden_states).any() or torch.isnan(hidden_states).any() - ): - clamp_value = torch.finfo(hidden_states.dtype).max - 1000 - hidden_states = torch.clamp(hidden_states, min=-clamp_value, max=clamp_value) - - outputs = (hidden_states,) - - if output_attentions: - outputs += (attn_weights,) - - return outputs - - -# Copied from transformers.models.bart.modeling_bart.BartPretrainedModel with Bart->LDMBert -class LDMBertPreTrainedModel(PreTrainedModel): - config_class = LDMBertConfig - base_model_prefix = "model" - supports_gradient_checkpointing = True - _keys_to_ignore_on_load_unexpected = [r"encoder\.version", r"decoder\.version"] - - def _init_weights(self, module): - std = self.config.init_std - if isinstance(module, nn.Linear): - module.weight.data.normal_(mean=0.0, std=std) - if module.bias is not None: - module.bias.data.zero_() - elif isinstance(module, nn.Embedding): - module.weight.data.normal_(mean=0.0, std=std) - if module.padding_idx is not None: - module.weight.data[module.padding_idx].zero_() - - def _set_gradient_checkpointing(self, module, value=False): - if isinstance(module, (LDMBertEncoder,)): - module.gradient_checkpointing = value - - @property - def dummy_inputs(self): - pad_token = self.config.pad_token_id - input_ids = torch.tensor([[0, 6, 10, 4, 2], [0, 8, 12, 2, pad_token]], device=self.device) - dummy_inputs = { - "attention_mask": input_ids.ne(pad_token), - "input_ids": input_ids, - } - return dummy_inputs - - -class LDMBertEncoder(LDMBertPreTrainedModel): - """ - Transformer encoder consisting of *config.encoder_layers* self attention layers. Each layer is a - [`LDMBertEncoderLayer`]. 
- - Args: - config: LDMBertConfig - embed_tokens (nn.Embedding): output embedding - """ - - def __init__(self, config: LDMBertConfig): - super().__init__(config) - - self.dropout = config.dropout - - embed_dim = config.d_model - self.padding_idx = config.pad_token_id - self.max_source_positions = config.max_position_embeddings - - self.embed_tokens = nn.Embedding(config.vocab_size, embed_dim) - self.embed_positions = nn.Embedding(config.max_position_embeddings, embed_dim) - self.layers = nn.ModuleList([LDMBertEncoderLayer(config) for _ in range(config.encoder_layers)]) - self.layer_norm = nn.LayerNorm(embed_dim) - - self.gradient_checkpointing = False - # Initialize weights and apply final processing - self.post_init() - - def get_input_embeddings(self): - return self.embed_tokens - - def set_input_embeddings(self, value): - self.embed_tokens = value - - def forward( - self, - input_ids: torch.LongTensor = None, - attention_mask: Optional[torch.Tensor] = None, - position_ids: Optional[torch.LongTensor] = None, - head_mask: Optional[torch.Tensor] = None, - inputs_embeds: Optional[torch.FloatTensor] = None, - output_attentions: Optional[bool] = None, - output_hidden_states: Optional[bool] = None, - return_dict: Optional[bool] = None, - ) -> Union[Tuple, BaseModelOutput]: - r""" - Args: - input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`): - Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you - provide it. - - Indices can be obtained using [`BartTokenizer`]. See [`PreTrainedTokenizer.encode`] and - [`PreTrainedTokenizer.__call__`] for details. - - [What are input IDs?](../glossary#input-ids) - attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*): - Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: - - - 1 for tokens that are **not masked**, - - 0 for tokens that are **masked**. - - [What are attention masks?](../glossary#attention-mask) - head_mask (`torch.Tensor` of shape `(encoder_layers, encoder_attention_heads)`, *optional*): - Mask to nullify selected heads of the attention modules. Mask values selected in `[0, 1]`: - - - 1 indicates the head is **not masked**, - - 0 indicates the head is **masked**. - - inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*): - Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. - This is useful if you want more control over how to convert `input_ids` indices into associated vectors - than the model's internal embedding lookup matrix. - output_attentions (`bool`, *optional*): - Whether or not to return the attentions tensors of all attention layers. See `attentions` under - returned tensors for more detail. - output_hidden_states (`bool`, *optional*): - Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors - for more detail. - return_dict (`bool`, *optional*): - Whether or not to return a [`~utils.BaseModelOutput`] instead of a plain tuple. 
- """ - output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions - output_hidden_states = ( - output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states - ) - return_dict = return_dict if return_dict is not None else self.config.use_return_dict - - # retrieve input_ids and inputs_embeds - if input_ids is not None and inputs_embeds is not None: - raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time") - elif input_ids is not None: - input_shape = input_ids.size() - input_ids = input_ids.view(-1, input_shape[-1]) - elif inputs_embeds is not None: - input_shape = inputs_embeds.size()[:-1] - else: - raise ValueError("You have to specify either input_ids or inputs_embeds") - - if inputs_embeds is None: - inputs_embeds = self.embed_tokens(input_ids) - - seq_len = input_shape[1] - if position_ids is None: - position_ids = torch.arange(seq_len, dtype=torch.long, device=inputs_embeds.device).expand((1, -1)) - embed_pos = self.embed_positions(position_ids) - - hidden_states = inputs_embeds + embed_pos - hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training) - - # expand attention_mask - if attention_mask is not None: - # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] - attention_mask = _expand_mask(attention_mask, inputs_embeds.dtype) - - encoder_states = () if output_hidden_states else None - all_attentions = () if output_attentions else None - - # check if head_mask has a correct number of layers specified if desired - if head_mask is not None: - if head_mask.size()[0] != (len(self.layers)): - raise ValueError( - f"The head_mask should be specified for {len(self.layers)} layers, but it is for" - f" {head_mask.size()[0]}." 
- ) - - for idx, encoder_layer in enumerate(self.layers): - if output_hidden_states: - encoder_states = encoder_states + (hidden_states,) - if self.gradient_checkpointing and self.training: - - def create_custom_forward(module): - def custom_forward(*inputs): - return module(*inputs, output_attentions) - - return custom_forward - - layer_outputs = torch.utils.checkpoint.checkpoint( - create_custom_forward(encoder_layer), - hidden_states, - attention_mask, - (head_mask[idx] if head_mask is not None else None), - ) - else: - layer_outputs = encoder_layer( - hidden_states, - attention_mask, - layer_head_mask=(head_mask[idx] if head_mask is not None else None), - output_attentions=output_attentions, - ) - - hidden_states = layer_outputs[0] - - if output_attentions: - all_attentions = all_attentions + (layer_outputs[1],) - - hidden_states = self.layer_norm(hidden_states) - - if output_hidden_states: - encoder_states = encoder_states + (hidden_states,) - - if not return_dict: - return tuple(v for v in [hidden_states, encoder_states, all_attentions] if v is not None) - return BaseModelOutput( - last_hidden_state=hidden_states, hidden_states=encoder_states, attentions=all_attentions - ) - - -class LDMBertModel(LDMBertPreTrainedModel): - def __init__(self, config: LDMBertConfig): - super().__init__(config) - self.model = LDMBertEncoder(config) - self.to_logits = nn.Linear(config.hidden_size, config.vocab_size) - - def forward( - self, - input_ids=None, - attention_mask=None, - position_ids=None, - head_mask=None, - inputs_embeds=None, - output_attentions=None, - output_hidden_states=None, - return_dict=None, - ): - - outputs = self.model( - input_ids, - attention_mask=attention_mask, - position_ids=position_ids, - head_mask=head_mask, - inputs_embeds=inputs_embeds, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - ) - return outputs diff --git a/local_diffusers/pipelines/latent_diffusion_uncond/__init__.py b/local_diffusers/pipelines/latent_diffusion_uncond/__init__.py deleted file mode 100644 index 0826ca753..000000000 --- a/local_diffusers/pipelines/latent_diffusion_uncond/__init__.py +++ /dev/null @@ -1,2 +0,0 @@ -# flake8: noqa -from .pipeline_latent_diffusion_uncond import LDMPipeline diff --git a/local_diffusers/pipelines/latent_diffusion_uncond/pipeline_latent_diffusion_uncond.py b/local_diffusers/pipelines/latent_diffusion_uncond/pipeline_latent_diffusion_uncond.py deleted file mode 100644 index 4979d88fe..000000000 --- a/local_diffusers/pipelines/latent_diffusion_uncond/pipeline_latent_diffusion_uncond.py +++ /dev/null @@ -1,108 +0,0 @@ -import inspect -import warnings -from typing import Optional, Tuple, Union - -import torch - -from ...models import UNet2DModel, VQModel -from ...pipeline_utils import DiffusionPipeline, ImagePipelineOutput -from ...schedulers import DDIMScheduler - - -class LDMPipeline(DiffusionPipeline): - r""" - This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the - library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.) - - Parameters: - vqvae ([`VQModel`]): - Vector-quantized (VQ) Model to encode and decode images to and from latent representations. - unet ([`UNet2DModel`]): U-Net architecture to denoise the encoded image latents. - scheduler ([`SchedulerMixin`]): - [`DDIMScheduler`] is to be used in combination with `unet` to denoise the encoded image latens. 
- """ - - def __init__(self, vqvae: VQModel, unet: UNet2DModel, scheduler: DDIMScheduler): - super().__init__() - scheduler = scheduler.set_format("pt") - self.register_modules(vqvae=vqvae, unet=unet, scheduler=scheduler) - - @torch.no_grad() - def __call__( - self, - batch_size: int = 1, - generator: Optional[torch.Generator] = None, - eta: float = 0.0, - num_inference_steps: int = 50, - output_type: Optional[str] = "pil", - return_dict: bool = True, - **kwargs, - ) -> Union[Tuple, ImagePipelineOutput]: - - r""" - Args: - batch_size (`int`, *optional*, defaults to 1): - Number of images to generate. - generator (`torch.Generator`, *optional*): - A [torch generator](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make generation - deterministic. - num_inference_steps (`int`, *optional*, defaults to 50): - The number of denoising steps. More denoising steps usually lead to a higher quality image at the - expense of slower inference. - output_type (`str`, *optional*, defaults to `"pil"`): - The output format of the generate image. Choose between - [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `nd.array`. - return_dict (`bool`, *optional*, defaults to `True`): - Whether or not to return a [`~pipeline_utils.ImagePipelineOutput`] instead of a plain tuple. - - Returns: - [`~pipeline_utils.ImagePipelineOutput`] or `tuple`: [`~pipelines.utils.ImagePipelineOutput`] if - `return_dict` is True, otherwise a `tuple. When returning a tuple, the first element is a list with the - generated images. - """ - - if "torch_device" in kwargs: - device = kwargs.pop("torch_device") - warnings.warn( - "`torch_device` is deprecated as an input argument to `__call__` and will be removed in v0.3.0." - " Consider using `pipe.to(torch_device)` instead." 
- ) - - # Set device as before (to be removed in 0.3.0) - if device is None: - device = "cuda" if torch.cuda.is_available() else "cpu" - self.to(device) - - latents = torch.randn( - (batch_size, self.unet.in_channels, self.unet.sample_size, self.unet.sample_size), - generator=generator, - ) - latents = latents.to(self.device) - - self.scheduler.set_timesteps(num_inference_steps) - - # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature - accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys()) - - extra_kwargs = {} - if accepts_eta: - extra_kwargs["eta"] = eta - - for t in self.progress_bar(self.scheduler.timesteps): - # predict the noise residual - noise_prediction = self.unet(latents, t).sample - # compute the previous noisy sample x_t -> x_t-1 - latents = self.scheduler.step(noise_prediction, t, latents, **extra_kwargs).prev_sample - - # decode the image latents with the VAE - image = self.vqvae.decode(latents).sample - - image = (image / 2 + 0.5).clamp(0, 1) - image = image.cpu().permute(0, 2, 3, 1).numpy() - if output_type == "pil": - image = self.numpy_to_pil(image) - - if not return_dict: - return (image,) - - return ImagePipelineOutput(images=image) diff --git a/local_diffusers/pipelines/pndm/__init__.py b/local_diffusers/pipelines/pndm/__init__.py deleted file mode 100644 index 6fc46aaab..000000000 --- a/local_diffusers/pipelines/pndm/__init__.py +++ /dev/null @@ -1,2 +0,0 @@ -# flake8: noqa -from .pipeline_pndm import PNDMPipeline diff --git a/local_diffusers/pipelines/pndm/pipeline_pndm.py b/local_diffusers/pipelines/pndm/pipeline_pndm.py deleted file mode 100644 index f3dff1a9a..000000000 --- a/local_diffusers/pipelines/pndm/pipeline_pndm.py +++ /dev/null @@ -1,111 +0,0 @@ -# Copyright 2022 The HuggingFace Team. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and - -# limitations under the License. - - -import warnings -from typing import Optional, Tuple, Union - -import torch - -from ...models import UNet2DModel -from ...pipeline_utils import DiffusionPipeline, ImagePipelineOutput -from ...schedulers import PNDMScheduler - - -class PNDMPipeline(DiffusionPipeline): - r""" - This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the - library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.) - - Parameters: - unet (`UNet2DModel`): U-Net architecture to denoise the encoded image latents. - scheduler ([`SchedulerMixin`]): - The `PNDMScheduler` to be used in combination with `unet` to denoise the encoded image. 
- """ - - unet: UNet2DModel - scheduler: PNDMScheduler - - def __init__(self, unet: UNet2DModel, scheduler: PNDMScheduler): - super().__init__() - scheduler = scheduler.set_format("pt") - self.register_modules(unet=unet, scheduler=scheduler) - - @torch.no_grad() - def __call__( - self, - batch_size: int = 1, - num_inference_steps: int = 50, - generator: Optional[torch.Generator] = None, - output_type: Optional[str] = "pil", - return_dict: bool = True, - **kwargs, - ) -> Union[ImagePipelineOutput, Tuple]: - r""" - Args: - batch_size (`int`, `optional`, defaults to 1): The number of images to generate. - num_inference_steps (`int`, `optional`, defaults to 50): - The number of denoising steps. More denoising steps usually lead to a higher quality image at the - expense of slower inference. - generator (`torch.Generator`, `optional`): A [torch - generator](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make generation - deterministic. - output_type (`str`, `optional`, defaults to `"pil"`): The output format of the generate image. Choose - between [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `nd.array`. - return_dict (`bool`, `optional`, defaults to `True`): Whether or not to return a - [`~pipeline_utils.ImagePipelineOutput`] instead of a plain tuple. - - Returns: - [`~pipeline_utils.ImagePipelineOutput`] or `tuple`: [`~pipelines.utils.ImagePipelineOutput`] if - `return_dict` is True, otherwise a `tuple. When returning a tuple, the first element is a list with the - generated images. - """ - # For more information on the sampling method you can take a look at Algorithm 2 of - # the official paper: https://arxiv.org/pdf/2202.09778.pdf - - if "torch_device" in kwargs: - device = kwargs.pop("torch_device") - warnings.warn( - "`torch_device` is deprecated as an input argument to `__call__` and will be removed in v0.3.0." - " Consider using `pipe.to(torch_device)` instead." 
- ) - - # Set device as before (to be removed in 0.3.0) - if device is None: - device = "cuda" if torch.cuda.is_available() else "cpu" - self.to(device) - - # Sample gaussian noise to begin loop - image = torch.randn( - (batch_size, self.unet.in_channels, self.unet.sample_size, self.unet.sample_size), - generator=generator, - ) - image = image.to(self.device) - - self.scheduler.set_timesteps(num_inference_steps) - for t in self.progress_bar(self.scheduler.timesteps): - model_output = self.unet(image, t).sample - - image = self.scheduler.step(model_output, t, image).prev_sample - - image = (image / 2 + 0.5).clamp(0, 1) - image = image.cpu().permute(0, 2, 3, 1).numpy() - if output_type == "pil": - image = self.numpy_to_pil(image) - - if not return_dict: - return (image,) - - return ImagePipelineOutput(images=image) diff --git a/local_diffusers/pipelines/score_sde_ve/__init__.py b/local_diffusers/pipelines/score_sde_ve/__init__.py deleted file mode 100644 index 000d61f6e..000000000 --- a/local_diffusers/pipelines/score_sde_ve/__init__.py +++ /dev/null @@ -1,2 +0,0 @@ -# flake8: noqa -from .pipeline_score_sde_ve import ScoreSdeVePipeline diff --git a/local_diffusers/pipelines/score_sde_ve/pipeline_score_sde_ve.py b/local_diffusers/pipelines/score_sde_ve/pipeline_score_sde_ve.py deleted file mode 100644 index 604e2b54c..000000000 --- a/local_diffusers/pipelines/score_sde_ve/pipeline_score_sde_ve.py +++ /dev/null @@ -1,101 +0,0 @@ -#!/usr/bin/env python3 -import warnings -from typing import Optional, Tuple, Union - -import torch - -from ...models import UNet2DModel -from ...pipeline_utils import DiffusionPipeline, ImagePipelineOutput -from ...schedulers import ScoreSdeVeScheduler - - -class ScoreSdeVePipeline(DiffusionPipeline): - r""" - Parameters: - This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the - library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.) - unet ([`UNet2DModel`]): U-Net architecture to denoise the encoded image. scheduler ([`SchedulerMixin`]): - The [`ScoreSdeVeScheduler`] scheduler to be used in combination with `unet` to denoise the encoded image. - """ - unet: UNet2DModel - scheduler: ScoreSdeVeScheduler - - def __init__(self, unet: UNet2DModel, scheduler: DiffusionPipeline): - super().__init__() - self.register_modules(unet=unet, scheduler=scheduler) - - @torch.no_grad() - def __call__( - self, - batch_size: int = 1, - num_inference_steps: int = 2000, - generator: Optional[torch.Generator] = None, - output_type: Optional[str] = "pil", - return_dict: bool = True, - **kwargs, - ) -> Union[ImagePipelineOutput, Tuple]: - r""" - Args: - batch_size (`int`, *optional*, defaults to 1): - The number of images to generate. - generator (`torch.Generator`, *optional*): - A [torch generator](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make generation - deterministic. - output_type (`str`, *optional*, defaults to `"pil"`): - The output format of the generate image. Choose between - [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `nd.array`. - return_dict (`bool`, *optional*, defaults to `True`): - Whether or not to return a [`~pipeline_utils.ImagePipelineOutput`] instead of a plain tuple. - - Returns: - [`~pipeline_utils.ImagePipelineOutput`] or `tuple`: [`~pipelines.utils.ImagePipelineOutput`] if - `return_dict` is True, otherwise a `tuple. When returning a tuple, the first element is a list with the - generated images. 
- """ - - if "torch_device" in kwargs: - device = kwargs.pop("torch_device") - warnings.warn( - "`torch_device` is deprecated as an input argument to `__call__` and will be removed in v0.3.0." - " Consider using `pipe.to(torch_device)` instead." - ) - - # Set device as before (to be removed in 0.3.0) - if device is None: - device = "cuda" if torch.cuda.is_available() else "cpu" - self.to(device) - - img_size = self.unet.config.sample_size - shape = (batch_size, 3, img_size, img_size) - - model = self.unet - - sample = torch.randn(*shape, generator=generator) * self.scheduler.config.sigma_max - sample = sample.to(self.device) - - self.scheduler.set_timesteps(num_inference_steps) - self.scheduler.set_sigmas(num_inference_steps) - - for i, t in enumerate(self.progress_bar(self.scheduler.timesteps)): - sigma_t = self.scheduler.sigmas[i] * torch.ones(shape[0], device=self.device) - - # correction step - for _ in range(self.scheduler.correct_steps): - model_output = self.unet(sample, sigma_t).sample - sample = self.scheduler.step_correct(model_output, sample, generator=generator).prev_sample - - # prediction step - model_output = model(sample, sigma_t).sample - output = self.scheduler.step_pred(model_output, t, sample, generator=generator) - - sample, sample_mean = output.prev_sample, output.prev_sample_mean - - sample = sample_mean.clamp(0, 1) - sample = sample.cpu().permute(0, 2, 3, 1).numpy() - if output_type == "pil": - sample = self.numpy_to_pil(sample) - - if not return_dict: - return (sample,) - - return ImagePipelineOutput(images=sample) diff --git a/local_diffusers/pipelines/stable_diffusion/__init__.py b/local_diffusers/pipelines/stable_diffusion/__init__.py deleted file mode 100644 index 5ffda93f1..000000000 --- a/local_diffusers/pipelines/stable_diffusion/__init__.py +++ /dev/null @@ -1,37 +0,0 @@ -from dataclasses import dataclass -from typing import List, Union - -import numpy as np - -import PIL -from PIL import Image - -from ...utils import BaseOutput, is_onnx_available, is_transformers_available - - -@dataclass -class StableDiffusionPipelineOutput(BaseOutput): - """ - Output class for Stable Diffusion pipelines. - - Args: - images (`List[PIL.Image.Image]` or `np.ndarray`) - List of denoised PIL images of length `batch_size` or numpy array of shape `(batch_size, height, width, - num_channels)`. PIL images or numpy array present the denoised images of the diffusion pipeline. - nsfw_content_detected (`List[bool]`) - List of flags denoting whether the corresponding generated image likely represents "not-safe-for-work" - (nsfw) content. - """ - - images: Union[List[PIL.Image.Image], np.ndarray] - nsfw_content_detected: List[bool] - - -if is_transformers_available(): - from .pipeline_stable_diffusion import StableDiffusionPipeline - from .pipeline_stable_diffusion_img2img import StableDiffusionImg2ImgPipeline - from .pipeline_stable_diffusion_inpaint import StableDiffusionInpaintPipeline - from .safety_checker import StableDiffusionSafetyChecker - -if is_transformers_available() and is_onnx_available(): - from .pipeline_stable_diffusion_onnx import StableDiffusionOnnxPipeline diff --git a/local_diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py b/local_diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py deleted file mode 100644 index 8e3199b44..000000000 --- a/local_diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +++ /dev/null @@ -1,397 +0,0 @@ -# Modification of the original file by O. Teytaud for facilitating genetic stable diffusion. 
- -import inspect -import os -import numpy as np -import random -import warnings -from typing import List, Optional, Union - -import torch - -from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer - -from ...models import AutoencoderKL, UNet2DConditionModel -from ...pipeline_utils import DiffusionPipeline -from ...schedulers import DDIMScheduler, LMSDiscreteScheduler, PNDMScheduler -from . import StableDiffusionPipelineOutput -from .safety_checker import StableDiffusionSafetyChecker - - -class StableDiffusionPipeline(DiffusionPipeline): - r""" - Pipeline for text-to-image generation using Stable Diffusion. - - This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the - library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.) - - Args: - vae ([`AutoencoderKL`]): - Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations. - text_encoder ([`CLIPTextModel`]): - Frozen text-encoder. Stable Diffusion uses the text portion of - [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel), specifically - the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant. - tokenizer (`CLIPTokenizer`): - Tokenizer of class - [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer). - unet ([`UNet2DConditionModel`]): Conditional U-Net architecture to denoise the encoded image latents. - scheduler ([`SchedulerMixin`]): - A scheduler to be used in combination with `unet` to denoise the encoded image latens. Can be one of - [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`]. - safety_checker ([`StableDiffusionSafetyChecker`]): - Classification module that estimates whether generated images could be considered offsensive or harmful. - Please, refer to the [model card](https://huggingface.co/CompVis/stable-diffusion-v1-4) for details. - feature_extractor ([`CLIPFeatureExtractor`]): - Model that extracts features from generated images to be used as inputs for the `safety_checker`. - """ - - def __init__( - self, - vae: AutoencoderKL, - text_encoder: CLIPTextModel, - tokenizer: CLIPTokenizer, - unet: UNet2DConditionModel, - scheduler: Union[DDIMScheduler, PNDMScheduler, LMSDiscreteScheduler], - safety_checker: StableDiffusionSafetyChecker, - feature_extractor: CLIPFeatureExtractor, - ): - super().__init__() - scheduler = scheduler.set_format("pt") - self.register_modules( - vae=vae, - text_encoder=text_encoder, - tokenizer=tokenizer, - unet=unet, - scheduler=scheduler, - safety_checker=safety_checker, - feature_extractor=feature_extractor, - ) - - def enable_attention_slicing(self, slice_size: Optional[Union[str, int]] = "auto"): - r""" - Enable sliced attention computation. - - When this option is enabled, the attention module will split the input tensor in slices, to compute attention - in several steps. This is useful to save some memory in exchange for a small speed decrease. - - Args: - slice_size (`str` or `int`, *optional*, defaults to `"auto"`): - When `"auto"`, halves the input to the attention heads, so attention will be computed in two steps. If - a number is provided, uses as many slices as `attention_head_dim // slice_size`. In this case, - `attention_head_dim` must be a multiple of `slice_size`. 
- """ - if slice_size == "auto": - # half the attention head size is usually a good trade-off between - # speed and memory - slice_size = self.unet.config.attention_head_dim // 2 - self.unet.set_attention_slice(slice_size) - - def disable_attention_slicing(self): - r""" - Disable sliced attention computation. If `enable_attention_slicing` was previously invoked, this method will go - back to computing attention in one step. - """ - # set slice_size = `None` to disable `attention slicing` - self.enable_attention_slicing(None) - -# def get_latent(self, image): -# return self.vae.encode(image) - - @torch.no_grad() - def __call__( - self, - prompt: Union[str, List[str]], - height: Optional[int] = 512, - width: Optional[int] = 512, - num_inference_steps: Optional[int] = 50, - guidance_scale: Optional[float] = 7.5, - eta: Optional[float] = 0.0, - generator: Optional[torch.Generator] = None, - latents: Optional[torch.FloatTensor] = None, - output_type: Optional[str] = "pil", - return_dict: bool = True, - **kwargs, - ): - r""" - Function invoked when calling the pipeline for generation. - - Args: - prompt (`str` or `List[str]`): - The prompt or prompts to guide the image generation. - height (`int`, *optional*, defaults to 512): - The height in pixels of the generated image. - width (`int`, *optional*, defaults to 512): - The width in pixels of the generated image. - num_inference_steps (`int`, *optional*, defaults to 50): - The number of denoising steps. More denoising steps usually lead to a higher quality image at the - expense of slower inference. - guidance_scale (`float`, *optional*, defaults to 7.5): - Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598). - `guidance_scale` is defined as `w` of equation 2. of [Imagen - Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale > - 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`, - usually at the expense of lower image quality. - eta (`float`, *optional*, defaults to 0.0): - Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to - [`schedulers.DDIMScheduler`], will be ignored for others. - generator (`torch.Generator`, *optional*): - A [torch generator](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make generation - deterministic. - latents (`torch.FloatTensor`, *optional*): - Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image - generation. Can be used to tweak the same generation with different prompts. If not provided, a latents - tensor will ge generated by sampling using the supplied random `generator`. - output_type (`str`, *optional*, defaults to `"pil"`): - The output format of the generate image. Choose between - [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `nd.array`. - return_dict (`bool`, *optional*, defaults to `True`): - Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a - plain tuple. - - Returns: - [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`: - [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] if `return_dict` is True, otherwise a `tuple. 
- When returning a tuple, the first element is a list with the generated images, and the second element is a - list of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work" - (nsfw) content, according to the `safety_checker`. - """ - - if "torch_device" in kwargs: - device = kwargs.pop("torch_device") - warnings.warn( - "`torch_device` is deprecated as an input argument to `__call__` and will be removed in v0.3.0." - " Consider using `pipe.to(torch_device)` instead." - ) - - # Set device as before (to be removed in 0.3.0) - if device is None: - device = "cuda" if torch.cuda.is_available() else "cpu" - self.to(device) - - if isinstance(prompt, str): - batch_size = 1 - elif isinstance(prompt, list): - batch_size = len(prompt) - else: - raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") - - if height % 8 != 0 or width % 8 != 0: - raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.") - - # get prompt text embeddings - text_input = self.tokenizer( - prompt, - padding="max_length", - max_length=self.tokenizer.model_max_length, - truncation=True, - return_tensors="pt", - ) - text_embeddings = self.text_encoder(text_input.input_ids.to(self.device))[0] - - # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) - # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1` - # corresponds to doing no classifier free guidance. - do_classifier_free_guidance = guidance_scale > 1.0 - # get unconditional embeddings for classifier free guidance - if do_classifier_free_guidance: - max_length = text_input.input_ids.shape[-1] - uncond_input = self.tokenizer( - [""] * batch_size, padding="max_length", max_length=max_length, return_tensors="pt" - ) - uncond_embeddings = self.text_encoder(uncond_input.input_ids.to(self.device))[0] - - # For classifier free guidance, we need to do two forward passes. - # Here we concatenate the unconditional and text embeddings into a single batch - # to avoid doing two forward passes - text_embeddings = torch.cat([uncond_embeddings, text_embeddings]) - - # get the initial random noise unless the user supplied it - - # Unlike in other pipelines, latents need to be generated in the target device - # for 1-to-1 results reproducibility with the CompVis implementation. - # However this currently doesn't work in `mps`. 
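# --- Editor's sketch (illustration only, not part of the original patch) -----
# The comment above is implemented just below: on "mps" the initial noise is
# drawn on the CPU (the comment notes that device-side generation does not work
# there yet) and only then moved to the device. A standalone equivalent, with
# the (1, 4, 64, 64) shape assumed for a single 512x512 image:
#
#     import torch
#     generator = torch.Generator().manual_seed(0)             # CPU generator
#     latents = torch.randn((1, 4, 64, 64), generator=generator, device="cpu")
#     latents = latents.to("mps" if torch.backends.mps.is_available() else "cpu")
# --- end sketch ---------------------------------------------------------------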
- latents_device = "cpu" if self.device.type == "mps" else self.device - latents_shape = (batch_size, self.unet.in_channels, height // 8, width // 8) - latents_intermediate_shape = (batch_size, self.unet.in_channels, height // 8, width // 8) - speedup = 1 - if latents is None: - latents = torch.randn( - latents_intermediate_shape, - generator=generator, - device=latents_device, - ) - if len(os.environ["forcedlatent"]) > 10: - stri = os.environ["forcedlatent"] - print(f"we get a forcing for the latent z: {stri[:20]}.") - if len(eval(stri)) == 1: - stri = str(eval(stri)[0]) - speedup = 1 - latents = np.array(list(eval(stri))).flatten() - #latents = latents + np.exp(0.1 * np.random.randn()) * np.random.rand(len(latents)) - #latents = np.sqrt(len(latents) / np.sum(latents ** 2)) * latents - #latents = np.sqrt(len(latents)) * latents / np.sqrt(np.sum(latents ** 2)) - print(f"As an array, this is {latents[:10]}") - print(f"immediately after loading latent ==> {sum(latents.flatten()**2) / len(latents.flatten())}") - latents = torch.from_numpy(latents.reshape((1,4,64,64))).float().to(latents_device) - os.environ["forcedlatent"] = "" - good = eval(os.environ["good"]) - bad = eval(os.environ["bad"]) - print(f"{len(good)} good and {len(bad)} bad") - i_believe_in_evolution = len(good) > 0 and len(bad) > 10 - print(f"I believe in evolution = {i_believe_in_evolution}") - if i_believe_in_evolution: - from sklearn import tree - from sklearn.neural_network import MLPClassifier - #from sklearn.neighbors import NearestCentroid - from sklearn.linear_model import LogisticRegression - #z = (np.random.randn(4*64*64)) - z = latents.cpu().numpy().flatten() - if os.environ.get("skl", "tree") == "tree": - clf = tree.DecisionTreeClassifier()#min_samples_split=0.1) - elif os.environ.get("skl", "tree") == "logit": - clf = LogisticRegression() - else: - clf = MLPClassifier(solver='lbfgs', alpha=1e-5, hidden_layer_sizes=(5, 2), random_state=1) - #clf = NearestCentroid() - - - - X=good + bad - Y = [1] * len(good) + [0] * len(bad) - clf = clf.fit(X,Y) - epsilon = 0.0001 # for astronauts - epsilon = 1.0 - - def loss(x): - return clf.predict_proba([x])[0][0] # for astronauts - #return clf.predict_proba([(1-epsilon)*z+epsilon*x])[0][0] # for astronauts - #return clf.predict_proba([z+epsilon*x])[0][0] - - - budget = int(os.environ.get("budget", "300")) - if i_believe_in_evolution and budget > 20: - import nevergrad as ng - #nevergrad_optimizer = ng.optimizers.RandomSearch(len(z), budget) - #nevergrad_optimizer = ng.optimizers.RandomSearch(len(z), budget) - optim_class = ng.optimizers.registry[os.environ.get("ngoptim", "DiscreteLenglerOnePlusOne")] - #nevergrad_optimizer = ng.optimizers.DiscreteLenglerOnePlusOne(len(z), budget) - nevergrad_optimizer = optim_class(len(z), budget) - #nevergrad_optimizer = ng.optimizers.DiscreteOnePlusOne(len(z), budget) -# for k in range(5): -# z1 = np.array(random.choice(good)) -# z2 = np.array(random.choice(good)) -# z3 = np.array(random.choice(good)) -# z4 = np.array(random.choice(good)) -# z5 = np.array(random.choice(good)) -# #z = 0.99 * z1 + 0.01 * (z2+z3+z4+z5)/4. -# z = 0.2 * (z1 + z2 + z3 + z4 + z5) -# mu = int(os.environ.get("mu", "5")) -# parents = [z1, z2, z3, z4, z5] -# weights = [np.exp(np.random.randn() - i * float(os.environ.get("decay", "1."))) for i in range(5)] -# z = weights[0] * z1 -# for u in range(mu): -# if u > 0: -# z += weights[u] * parents[u] -# z = (1. 
/ sum(weights[:mu])) * z -# z = np.sqrt(len(z)) * z / np.linalg.norm(z) -# -# #for u in range(len(z)): -# # z[u] = random.choice([z1[u],z2[u],z3[u],z4[u],z5[u]]) -# nevergrad_optimizer.suggest - if len(os.environ["forcedlatent"]) > 0: - print("we get a forcing for the latent z.") - z0 = eval(os.environ["forcedlatent"]) - #nevergrad_optimizer.suggest(eval(os.environ["forcedlatent"])) - else: - z0 = z - for i in range(budget): - x = nevergrad_optimizer.ask() - z = z0 + float(os.environ.get("epsilon", "0.001")) * x.value - z = np.sqrt(len(z)) * z / np.linalg.norm(z) - l = loss(z) - nevergrad_optimizer.tell(x, l) - if np.log2(i+1) == int(np.log2(i+1)): - print(f"iteration {i} --> {l}") - print("var/variable = ", sum(z**2)/len(z)) - #z = (1.-epsilon) * z + epsilon * x / np.sqrt(np.sum(x ** 2)) - if l < 0.0000001 and os.environ.get("earlystop", "False") in ["true", "True"]: - print(f"we find proba(bad)={l}") - break - x = nevergrad_optimizer.recommend().value - z = z0 + float(os.environ.get("epsilon", "0.001")) * x - z = np.sqrt(len(z)) * z / np.linalg.norm(z) - latents = torch.from_numpy(z.reshape(latents_intermediate_shape)).float() #.half() - else: - if latents.shape != latents_intermediate_shape: - raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {latents_intermediate_shape}") - print(f"latent ==> {sum(latents.flatten()**2) / len(latents.flatten())}") - print(f"latent ==> {torch.max(latents)}") - print(f"latent ==> {torch.min(latents)}") - os.environ["latent_sd"] = str(list(latents.flatten().cpu().numpy())) - for i in [2, 3]: - latents = torch.repeat_interleave(latents, repeats=latents_shape[i] // latents_intermediate_shape[i], dim=i) #/ np.sqrt(np.sqrt(latents_shape[i] // latents_intermediate_shape[i])) - latents = latents.float().to(self.device) - - # set timesteps - accepts_offset = "offset" in set(inspect.signature(self.scheduler.set_timesteps).parameters.keys()) - extra_set_kwargs = {} - if accepts_offset: - extra_set_kwargs["offset"] = 1 - - self.scheduler.set_timesteps(num_inference_steps // speedup, **extra_set_kwargs) - - # if we use LMSDiscreteScheduler, let's make sure latents are mulitplied by sigmas - if isinstance(self.scheduler, LMSDiscreteScheduler): - latents = latents * self.scheduler.sigmas[0] - - # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature - # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers. 
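# --- Editor's sketch (illustration only, not part of the original patch) -----
# Condensed version of the evolutionary step above: fit a classifier on the
# user-ranked latents, then let a nevergrad optimizer search, around the
# current latent z0, for a small perturbation that the classifier scores as
# "good". The renormalisation keeps every candidate on the sphere of radius
# sqrt(d), like a fresh Gaussian draw. Defaults mirror the environment
# variables used above (DiscreteLenglerOnePlusOne, budget=300, epsilon=0.001);
# the helper itself is hypothetical.

import numpy as np
import nevergrad as ng
from sklearn.tree import DecisionTreeClassifier

def evolve_latent(z0, good, bad, budget=300, epsilon=0.001):
    # good/bad: non-empty lists of flat latent vectors ranked by the user
    clf = DecisionTreeClassifier().fit(good + bad, [1] * len(good) + [0] * len(bad))
    proba_bad = lambda z: clf.predict_proba([z])[0][0]        # P(class 0) = P(bad)
    opt = ng.optimizers.DiscreteLenglerOnePlusOne(parametrization=len(z0), budget=budget)
    for _ in range(budget):
        cand = opt.ask()
        z = z0 + epsilon * cand.value
        z = np.sqrt(len(z)) * z / np.linalg.norm(z)           # back onto the sphere
        opt.tell(cand, proba_bad(z))
    z = z0 + epsilon * opt.recommend().value
    return np.sqrt(len(z)) * z / np.linalg.norm(z)
# --- end sketch ---------------------------------------------------------------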
- # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502 - # and should be between [0, 1] - accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys()) - extra_step_kwargs = {} - if accepts_eta: - extra_step_kwargs["eta"] = eta - - for i, t in enumerate(self.progress_bar(self.scheduler.timesteps)): - # expand the latents if we are doing classifier free guidance - latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents - if isinstance(self.scheduler, LMSDiscreteScheduler): - sigma = self.scheduler.sigmas[i] - # the model input needs to be scaled to match the continuous ODE formulation in K-LMS - latent_model_input = latent_model_input / ((sigma**2 + 1) ** 0.5) - - # predict the noise residual - noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=text_embeddings).sample - - # perform guidance - if do_classifier_free_guidance: - noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) - noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) - - # compute the previous noisy sample x_t -> x_t-1 - if isinstance(self.scheduler, LMSDiscreteScheduler): - latents = self.scheduler.step(noise_pred, i, latents, **extra_step_kwargs).prev_sample - else: - latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample - - # scale and decode the image latents with vae - #os.environ["latent_sd"] = str(list(latents.flatten().cpu().detach().numpy())) - latents = 1 / 0.18215 * latents - image = self.vae.decode(latents).sample - - image = (image / 2 + 0.5).clamp(0, 1) - image = image.cpu().permute(0, 2, 3, 1).numpy() - - # run safety checker - safety_cheker_input = self.feature_extractor(self.numpy_to_pil(image), return_tensors="pt").to(self.device) - image, has_nsfw_concept = self.safety_checker(images=image, clip_input=safety_cheker_input.pixel_values) - - if output_type == "pil": - image = self.numpy_to_pil(image) - - if not return_dict: - return (image, has_nsfw_concept) - - return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept) diff --git a/local_diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py b/local_diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py deleted file mode 100644 index 475ceef4f..000000000 --- a/local_diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py +++ /dev/null @@ -1,291 +0,0 @@ -import inspect -from typing import List, Optional, Union - -import numpy as np -import torch - -import PIL -from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer - -from ...models import AutoencoderKL, UNet2DConditionModel -from ...pipeline_utils import DiffusionPipeline -from ...schedulers import DDIMScheduler, LMSDiscreteScheduler, PNDMScheduler -from . import StableDiffusionPipelineOutput -from .safety_checker import StableDiffusionSafetyChecker - - -def preprocess(image): - w, h = image.size - w, h = map(lambda x: x - x % 32, (w, h)) # resize to integer multiple of 32 - image = image.resize((w, h), resample=PIL.Image.LANCZOS) - image = np.array(image).astype(np.float32) / 255.0 - image = image[None].transpose(0, 3, 1, 2) - image = torch.from_numpy(image) - return 2.0 * image - 1.0 - - -class StableDiffusionImg2ImgPipeline(DiffusionPipeline): - r""" - Pipeline for text-guided image to image generation using Stable Diffusion. - - This model inherits from [`DiffusionPipeline`]. 
Check the superclass documentation for the generic methods the - library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.) - - Args: - vae ([`AutoencoderKL`]): - Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations. - text_encoder ([`CLIPTextModel`]): - Frozen text-encoder. Stable Diffusion uses the text portion of - [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel), specifically - the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant. - tokenizer (`CLIPTokenizer`): - Tokenizer of class - [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer). - unet ([`UNet2DConditionModel`]): Conditional U-Net architecture to denoise the encoded image latents. - scheduler ([`SchedulerMixin`]): - A scheduler to be used in combination with `unet` to denoise the encoded image latens. Can be one of - [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`]. - safety_checker ([`StableDiffusionSafetyChecker`]): - Classification module that estimates whether generated images could be considered offsensive or harmful. - Please, refer to the [model card](https://huggingface.co/CompVis/stable-diffusion-v1-4) for details. - feature_extractor ([`CLIPFeatureExtractor`]): - Model that extracts features from generated images to be used as inputs for the `safety_checker`. - """ - - def __init__( - self, - vae: AutoencoderKL, - text_encoder: CLIPTextModel, - tokenizer: CLIPTokenizer, - unet: UNet2DConditionModel, - scheduler: Union[DDIMScheduler, PNDMScheduler, LMSDiscreteScheduler], - safety_checker: StableDiffusionSafetyChecker, - feature_extractor: CLIPFeatureExtractor, - ): - super().__init__() - scheduler = scheduler.set_format("pt") - self.register_modules( - vae=vae, - text_encoder=text_encoder, - tokenizer=tokenizer, - unet=unet, - scheduler=scheduler, - safety_checker=safety_checker, - feature_extractor=feature_extractor, - ) - - def enable_attention_slicing(self, slice_size: Optional[Union[str, int]] = "auto"): - r""" - Enable sliced attention computation. - - When this option is enabled, the attention module will split the input tensor in slices, to compute attention - in several steps. This is useful to save some memory in exchange for a small speed decrease. - - Args: - slice_size (`str` or `int`, *optional*, defaults to `"auto"`): - When `"auto"`, halves the input to the attention heads, so attention will be computed in two steps. If - a number is provided, uses as many slices as `attention_head_dim // slice_size`. In this case, - `attention_head_dim` must be a multiple of `slice_size`. - """ - if slice_size == "auto": - # half the attention head size is usually a good trade-off between - # speed and memory - slice_size = self.unet.config.attention_head_dim // 2 - self.unet.set_attention_slice(slice_size) - - def disable_attention_slicing(self): - r""" - Disable sliced attention computation. If `enable_attention_slicing` was previously invoked, this method will go - back to computing attention in one step. 
- """ - # set slice_size = `None` to disable `set_attention_slice` - self.enable_attention_slice(None) - - @torch.no_grad() - def __call__( - self, - prompt: Union[str, List[str]], - init_image: Union[torch.FloatTensor, PIL.Image.Image], - strength: float = 0.8, - num_inference_steps: Optional[int] = 50, - guidance_scale: Optional[float] = 7.5, - eta: Optional[float] = 0.0, - generator: Optional[torch.Generator] = None, - output_type: Optional[str] = "pil", - return_dict: bool = True, - ): - r""" - Function invoked when calling the pipeline for generation. - - Args: - prompt (`str` or `List[str]`): - The prompt or prompts to guide the image generation. - init_image (`torch.FloatTensor` or `PIL.Image.Image`): - `Image`, or tensor representing an image batch, that will be used as the starting point for the - process. - strength (`float`, *optional*, defaults to 0.8): - Conceptually, indicates how much to transform the reference `init_image`. Must be between 0 and 1. - `init_image` will be used as a starting point, adding more noise to it the larger the `strength`. The - number of denoising steps depends on the amount of noise initially added. When `strength` is 1, added - noise will be maximum and the denoising process will run for the full number of iterations specified in - `num_inference_steps`. A value of 1, therefore, essentially ignores `init_image`. - num_inference_steps (`int`, *optional*, defaults to 50): - The number of denoising steps. More denoising steps usually lead to a higher quality image at the - expense of slower inference. This parameter will be modulated by `strength`. - guidance_scale (`float`, *optional*, defaults to 7.5): - Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598). - `guidance_scale` is defined as `w` of equation 2. of [Imagen - Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale > - 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`, - usually at the expense of lower image quality. - eta (`float`, *optional*, defaults to 0.0): - Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to - [`schedulers.DDIMScheduler`], will be ignored for others. - generator (`torch.Generator`, *optional*): - A [torch generator](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make generation - deterministic. - output_type (`str`, *optional*, defaults to `"pil"`): - The output format of the generate image. Choose between - [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `nd.array`. - return_dict (`bool`, *optional*, defaults to `True`): - Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a - plain tuple. - - Returns: - [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`: - [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] if `return_dict` is True, otherwise a `tuple. - When returning a tuple, the first element is a list with the generated images, and the second element is a - list of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work" - (nsfw) content, according to the `safety_checker`. 
- """ - if isinstance(prompt, str): - batch_size = 1 - elif isinstance(prompt, list): - batch_size = len(prompt) - else: - raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") - - if strength < 0 or strength > 1: - raise ValueError(f"The value of strength should in [0.0, 1.0] but is {strength}") - - # set timesteps - accepts_offset = "offset" in set(inspect.signature(self.scheduler.set_timesteps).parameters.keys()) - extra_set_kwargs = {} - offset = 0 - if accepts_offset: - offset = 1 - extra_set_kwargs["offset"] = 1 - - self.scheduler.set_timesteps(num_inference_steps, **extra_set_kwargs) - - if not isinstance(init_image, torch.FloatTensor): - init_image = preprocess(init_image) - - # encode the init image into latents and scale the latents - init_latent_dist = self.vae.encode(init_image.to(self.device)).latent_dist - init_latents = init_latent_dist.sample(generator=generator) - init_latents = 0.18215 * init_latents - - # expand init_latents for batch_size - init_latents = torch.cat([init_latents] * batch_size) - - # get the original timestep using init_timestep - init_timestep = int(num_inference_steps * strength) + offset - init_timestep = min(init_timestep, num_inference_steps) - if isinstance(self.scheduler, LMSDiscreteScheduler): - timesteps = torch.tensor( - [num_inference_steps - init_timestep] * batch_size, dtype=torch.long, device=self.device - ) - else: - timesteps = self.scheduler.timesteps[-init_timestep] - timesteps = torch.tensor([timesteps] * batch_size, dtype=torch.long, device=self.device) - - # add noise to latents using the timesteps - noise = torch.randn(init_latents.shape, generator=generator, device=self.device) - init_latents = self.scheduler.add_noise(init_latents, noise, timesteps).to(self.device) - - # get prompt text embeddings - text_input = self.tokenizer( - prompt, - padding="max_length", - max_length=self.tokenizer.model_max_length, - truncation=True, - return_tensors="pt", - ) - text_embeddings = self.text_encoder(text_input.input_ids.to(self.device))[0] - - # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) - # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1` - # corresponds to doing no classifier free guidance. - do_classifier_free_guidance = guidance_scale > 1.0 - # get unconditional embeddings for classifier free guidance - if do_classifier_free_guidance: - max_length = text_input.input_ids.shape[-1] - uncond_input = self.tokenizer( - [""] * batch_size, padding="max_length", max_length=max_length, return_tensors="pt" - ) - uncond_embeddings = self.text_encoder(uncond_input.input_ids.to(self.device))[0] - - # For classifier free guidance, we need to do two forward passes. - # Here we concatenate the unconditional and text embeddings into a single batch - # to avoid doing two forward passes - text_embeddings = torch.cat([uncond_embeddings, text_embeddings]) - - # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature - # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers. 
- # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502 - # and should be between [0, 1] - accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys()) - extra_step_kwargs = {} - if accepts_eta: - extra_step_kwargs["eta"] = eta - - latents = init_latents - - t_start = max(num_inference_steps - init_timestep + offset, 0) - for i, t in enumerate(self.progress_bar(self.scheduler.timesteps[t_start:])): - t_index = t_start + i - - # expand the latents if we are doing classifier free guidance - latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents - - # if we use LMSDiscreteScheduler, let's make sure latents are mulitplied by sigmas - if isinstance(self.scheduler, LMSDiscreteScheduler): - sigma = self.scheduler.sigmas[t_index] - # the model input needs to be scaled to match the continuous ODE formulation in K-LMS - latent_model_input = latent_model_input / ((sigma**2 + 1) ** 0.5) - latent_model_input = latent_model_input.to(self.unet.dtype) - t = t.to(self.unet.dtype) - - # predict the noise residual - noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=text_embeddings).sample - - # perform guidance - if do_classifier_free_guidance: - noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) - noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) - - # compute the previous noisy sample x_t -> x_t-1 - if isinstance(self.scheduler, LMSDiscreteScheduler): - latents = self.scheduler.step(noise_pred, t_index, latents, **extra_step_kwargs).prev_sample - else: - latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample - - # scale and decode the image latents with vae - latents = 1 / 0.18215 * latents - image = self.vae.decode(latents.to(self.vae.dtype)).sample - - image = (image / 2 + 0.5).clamp(0, 1) - image = image.cpu().permute(0, 2, 3, 1).numpy() - - # run safety checker - safety_cheker_input = self.feature_extractor(self.numpy_to_pil(image), return_tensors="pt").to(self.device) - image, has_nsfw_concept = self.safety_checker(images=image, clip_input=safety_cheker_input.pixel_values) - - if output_type == "pil": - image = self.numpy_to_pil(image) - - if not return_dict: - return (image, has_nsfw_concept) - - return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept) diff --git a/local_diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py b/local_diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py deleted file mode 100644 index 05ea84ae0..000000000 --- a/local_diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py +++ /dev/null @@ -1,309 +0,0 @@ -import inspect -from typing import List, Optional, Union - -import numpy as np -import torch - -import PIL -from tqdm.auto import tqdm -from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer - -from ...models import AutoencoderKL, UNet2DConditionModel -from ...pipeline_utils import DiffusionPipeline -from ...schedulers import DDIMScheduler, PNDMScheduler -from ...utils import logging -from . 
import StableDiffusionPipelineOutput -from .safety_checker import StableDiffusionSafetyChecker - - -logger = logging.get_logger(__name__) - - -def preprocess_image(image): - w, h = image.size - w, h = map(lambda x: x - x % 32, (w, h)) # resize to integer multiple of 32 - image = image.resize((w, h), resample=PIL.Image.LANCZOS) - image = np.array(image).astype(np.float32) / 255.0 - image = image[None].transpose(0, 3, 1, 2) - image = torch.from_numpy(image) - return 2.0 * image - 1.0 - - -def preprocess_mask(mask): - mask = mask.convert("L") - w, h = mask.size - w, h = map(lambda x: x - x % 32, (w, h)) # resize to integer multiple of 32 - mask = mask.resize((w // 8, h // 8), resample=PIL.Image.NEAREST) - mask = np.array(mask).astype(np.float32) / 255.0 - mask = np.tile(mask, (4, 1, 1)) - mask = mask[None].transpose(0, 1, 2, 3) # what does this step do? - mask = 1 - mask # repaint white, keep black - mask = torch.from_numpy(mask) - return mask - - -class StableDiffusionInpaintPipeline(DiffusionPipeline): - r""" - Pipeline for text-guided image inpainting using Stable Diffusion. *This is an experimental feature*. - - This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the - library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.) - - Args: - vae ([`AutoencoderKL`]): - Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations. - text_encoder ([`CLIPTextModel`]): - Frozen text-encoder. Stable Diffusion uses the text portion of - [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel), specifically - the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant. - tokenizer (`CLIPTokenizer`): - Tokenizer of class - [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer). - unet ([`UNet2DConditionModel`]): Conditional U-Net architecture to denoise the encoded image latents. - scheduler ([`SchedulerMixin`]): - A scheduler to be used in combination with `unet` to denoise the encoded image latens. Can be one of - [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`]. - safety_checker ([`StableDiffusionSafetyChecker`]): - Classification module that estimates whether generated images could be considered offsensive or harmful. - Please, refer to the [model card](https://huggingface.co/CompVis/stable-diffusion-v1-4) for details. - feature_extractor ([`CLIPFeatureExtractor`]): - Model that extracts features from generated images to be used as inputs for the `safety_checker`. - """ - - def __init__( - self, - vae: AutoencoderKL, - text_encoder: CLIPTextModel, - tokenizer: CLIPTokenizer, - unet: UNet2DConditionModel, - scheduler: Union[DDIMScheduler, PNDMScheduler], - safety_checker: StableDiffusionSafetyChecker, - feature_extractor: CLIPFeatureExtractor, - ): - super().__init__() - scheduler = scheduler.set_format("pt") - logger.info("`StableDiffusionInpaintPipeline` is experimental and will very likely change in the future.") - self.register_modules( - vae=vae, - text_encoder=text_encoder, - tokenizer=tokenizer, - unet=unet, - scheduler=scheduler, - safety_checker=safety_checker, - feature_extractor=feature_extractor, - ) - - def enable_attention_slicing(self, slice_size: Optional[Union[str, int]] = "auto"): - r""" - Enable sliced attention computation. 
- - When this option is enabled, the attention module will split the input tensor in slices, to compute attention - in several steps. This is useful to save some memory in exchange for a small speed decrease. - - Args: - slice_size (`str` or `int`, *optional*, defaults to `"auto"`): - When `"auto"`, halves the input to the attention heads, so attention will be computed in two steps. If - a number is provided, uses as many slices as `attention_head_dim // slice_size`. In this case, - `attention_head_dim` must be a multiple of `slice_size`. - """ - if slice_size == "auto": - # half the attention head size is usually a good trade-off between - # speed and memory - slice_size = self.unet.config.attention_head_dim // 2 - self.unet.set_attention_slice(slice_size) - - def disable_attention_slicing(self): - r""" - Disable sliced attention computation. If `enable_attention_slicing` was previously invoked, this method will go - back to computing attention in one step. - """ - # set slice_size = `None` to disable `set_attention_slice` - self.enable_attention_slice(None) - - @torch.no_grad() - def __call__( - self, - prompt: Union[str, List[str]], - init_image: Union[torch.FloatTensor, PIL.Image.Image], - mask_image: Union[torch.FloatTensor, PIL.Image.Image], - strength: float = 0.8, - num_inference_steps: Optional[int] = 50, - guidance_scale: Optional[float] = 7.5, - eta: Optional[float] = 0.0, - generator: Optional[torch.Generator] = None, - output_type: Optional[str] = "pil", - return_dict: bool = True, - ): - r""" - Function invoked when calling the pipeline for generation. - - Args: - prompt (`str` or `List[str]`): - The prompt or prompts to guide the image generation. - init_image (`torch.FloatTensor` or `PIL.Image.Image`): - `Image`, or tensor representing an image batch, that will be used as the starting point for the - process. This is the image whose masked region will be inpainted. - mask_image (`torch.FloatTensor` or `PIL.Image.Image`): - `Image`, or tensor representing an image batch, to mask `init_image`. White pixels in the mask will be - replaced by noise and therefore repainted, while black pixels will be preserved. The mask image will be - converted to a single channel (luminance) before use. - strength (`float`, *optional*, defaults to 0.8): - Conceptually, indicates how much to inpaint the masked area. Must be between 0 and 1. When `strength` - is 1, the denoising process will be run on the masked area for the full number of iterations specified - in `num_inference_steps`. `init_image` will be used as a reference for the masked area, adding more - noise to that region the larger the `strength`. If `strength` is 0, no inpainting will occur. - num_inference_steps (`int`, *optional*, defaults to 50): - The reference number of denoising steps. More denoising steps usually lead to a higher quality image at - the expense of slower inference. This parameter will be modulated by `strength`, as explained above. - guidance_scale (`float`, *optional*, defaults to 7.5): - Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598). - `guidance_scale` is defined as `w` of equation 2. of [Imagen - Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale > - 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`, - usually at the expense of lower image quality. 
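Illustrative aside (not part of the original docstring): how `strength` translates into the number of denoising steps that actually run, using the arithmetic of `__call__` below. The values here are assumptions chosen only to make the numbers concrete:

# Worked example of the strength -> timestep arithmetic used by this pipeline.
num_inference_steps = 50
strength = 0.8
offset = 1  # used when the scheduler's set_timesteps accepts an `offset` kwarg

init_timestep = min(int(num_inference_steps * strength) + offset, num_inference_steps)  # 41
t_start = max(num_inference_steps - init_timestep + offset, 0)                          # 10

print(init_timestep, t_start)  # 41 10 -> 40 of the 50 scheduled steps are executed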
- eta (`float`, *optional*, defaults to 0.0): - Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to - [`schedulers.DDIMScheduler`], will be ignored for others. - generator (`torch.Generator`, *optional*): - A [torch generator](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make generation - deterministic. - output_type (`str`, *optional*, defaults to `"pil"`): - The output format of the generate image. Choose between - [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `nd.array`. - return_dict (`bool`, *optional*, defaults to `True`): - Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a - plain tuple. - - Returns: - [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`: - [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] if `return_dict` is True, otherwise a `tuple. - When returning a tuple, the first element is a list with the generated images, and the second element is a - list of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work" - (nsfw) content, according to the `safety_checker`. - """ - if isinstance(prompt, str): - batch_size = 1 - elif isinstance(prompt, list): - batch_size = len(prompt) - else: - raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") - - if strength < 0 or strength > 1: - raise ValueError(f"The value of strength should in [0.0, 1.0] but is {strength}") - - # set timesteps - accepts_offset = "offset" in set(inspect.signature(self.scheduler.set_timesteps).parameters.keys()) - extra_set_kwargs = {} - offset = 0 - if accepts_offset: - offset = 1 - extra_set_kwargs["offset"] = 1 - - self.scheduler.set_timesteps(num_inference_steps, **extra_set_kwargs) - - # preprocess image - init_image = preprocess_image(init_image).to(self.device) - - # encode the init image into latents and scale the latents - init_latent_dist = self.vae.encode(init_image.to(self.device)).latent_dist - init_latents = init_latent_dist.sample(generator=generator) - - init_latents = 0.18215 * init_latents - - # Expand init_latents for batch_size - init_latents = torch.cat([init_latents] * batch_size) - init_latents_orig = init_latents - - # preprocess mask - mask = preprocess_mask(mask_image).to(self.device) - mask = torch.cat([mask] * batch_size) - - # check sizes - if not mask.shape == init_latents.shape: - raise ValueError("The mask and init_image should be the same size!") - - # get the original timestep using init_timestep - init_timestep = int(num_inference_steps * strength) + offset - init_timestep = min(init_timestep, num_inference_steps) - timesteps = self.scheduler.timesteps[-init_timestep] - timesteps = torch.tensor([timesteps] * batch_size, dtype=torch.long, device=self.device) - - # add noise to latents using the timesteps - noise = torch.randn(init_latents.shape, generator=generator, device=self.device) - init_latents = self.scheduler.add_noise(init_latents, noise, timesteps) - - # get prompt text embeddings - text_input = self.tokenizer( - prompt, - padding="max_length", - max_length=self.tokenizer.model_max_length, - truncation=True, - return_tensors="pt", - ) - text_embeddings = self.text_encoder(text_input.input_ids.to(self.device))[0] - - # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) - # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . 
`guidance_scale = 1` - # corresponds to doing no classifier free guidance. - do_classifier_free_guidance = guidance_scale > 1.0 - # get unconditional embeddings for classifier free guidance - if do_classifier_free_guidance: - max_length = text_input.input_ids.shape[-1] - uncond_input = self.tokenizer( - [""] * batch_size, padding="max_length", max_length=max_length, return_tensors="pt" - ) - uncond_embeddings = self.text_encoder(uncond_input.input_ids.to(self.device))[0] - - # For classifier free guidance, we need to do two forward passes. - # Here we concatenate the unconditional and text embeddings into a single batch - # to avoid doing two forward passes - text_embeddings = torch.cat([uncond_embeddings, text_embeddings]) - - # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature - # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers. - # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502 - # and should be between [0, 1] - accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys()) - extra_step_kwargs = {} - if accepts_eta: - extra_step_kwargs["eta"] = eta - - latents = init_latents - t_start = max(num_inference_steps - init_timestep + offset, 0) - for i, t in tqdm(enumerate(self.scheduler.timesteps[t_start:])): - # expand the latents if we are doing classifier free guidance - latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents - - # predict the noise residual - noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=text_embeddings).sample - - # perform guidance - if do_classifier_free_guidance: - noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) - noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) - - # compute the previous noisy sample x_t -> x_t-1 - latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample - - # masking - init_latents_proper = self.scheduler.add_noise(init_latents_orig, noise, t) - latents = (init_latents_proper * mask) + (latents * (1 - mask)) - - # scale and decode the image latents with vae - latents = 1 / 0.18215 * latents - image = self.vae.decode(latents).sample - - image = (image / 2 + 0.5).clamp(0, 1) - image = image.cpu().permute(0, 2, 3, 1).numpy() - - # run safety checker - safety_cheker_input = self.feature_extractor(self.numpy_to_pil(image), return_tensors="pt").to(self.device) - image, has_nsfw_concept = self.safety_checker(images=image, clip_input=safety_cheker_input.pixel_values) - - if output_type == "pil": - image = self.numpy_to_pil(image) - - if not return_dict: - return (image, has_nsfw_concept) - - return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept) diff --git a/local_diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_onnx.py b/local_diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_onnx.py deleted file mode 100644 index 7ff3ff22f..000000000 --- a/local_diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_onnx.py +++ /dev/null @@ -1,165 +0,0 @@ -import inspect -from typing import List, Optional, Union - -import numpy as np - -from transformers import CLIPFeatureExtractor, CLIPTokenizer - -from ...onnx_utils import OnnxRuntimeModel -from ...pipeline_utils import DiffusionPipeline -from ...schedulers import DDIMScheduler, LMSDiscreteScheduler, PNDMScheduler -from . 
import StableDiffusionPipelineOutput - - -class StableDiffusionOnnxPipeline(DiffusionPipeline): - vae_decoder: OnnxRuntimeModel - text_encoder: OnnxRuntimeModel - tokenizer: CLIPTokenizer - unet: OnnxRuntimeModel - scheduler: Union[DDIMScheduler, PNDMScheduler, LMSDiscreteScheduler] - safety_checker: OnnxRuntimeModel - feature_extractor: CLIPFeatureExtractor - - def __init__( - self, - vae_decoder: OnnxRuntimeModel, - text_encoder: OnnxRuntimeModel, - tokenizer: CLIPTokenizer, - unet: OnnxRuntimeModel, - scheduler: Union[DDIMScheduler, PNDMScheduler, LMSDiscreteScheduler], - safety_checker: OnnxRuntimeModel, - feature_extractor: CLIPFeatureExtractor, - ): - super().__init__() - scheduler = scheduler.set_format("np") - self.register_modules( - vae_decoder=vae_decoder, - text_encoder=text_encoder, - tokenizer=tokenizer, - unet=unet, - scheduler=scheduler, - safety_checker=safety_checker, - feature_extractor=feature_extractor, - ) - - def __call__( - self, - prompt: Union[str, List[str]], - height: Optional[int] = 512, - width: Optional[int] = 512, - num_inference_steps: Optional[int] = 50, - guidance_scale: Optional[float] = 7.5, - eta: Optional[float] = 0.0, - latents: Optional[np.ndarray] = None, - output_type: Optional[str] = "pil", - return_dict: bool = True, - **kwargs, - ): - if isinstance(prompt, str): - batch_size = 1 - elif isinstance(prompt, list): - batch_size = len(prompt) - else: - raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") - - if height % 8 != 0 or width % 8 != 0: - raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.") - - # get prompt text embeddings - text_input = self.tokenizer( - prompt, - padding="max_length", - max_length=self.tokenizer.model_max_length, - truncation=True, - return_tensors="np", - ) - text_embeddings = self.text_encoder(input_ids=text_input.input_ids.astype(np.int32))[0] - - # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) - # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1` - # corresponds to doing no classifier free guidance. - do_classifier_free_guidance = guidance_scale > 1.0 - # get unconditional embeddings for classifier free guidance - if do_classifier_free_guidance: - max_length = text_input.input_ids.shape[-1] - uncond_input = self.tokenizer( - [""] * batch_size, padding="max_length", max_length=max_length, return_tensors="np" - ) - uncond_embeddings = self.text_encoder(input_ids=uncond_input.input_ids.astype(np.int32))[0] - - # For classifier free guidance, we need to do two forward passes. 
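Illustrative aside: the classifier-free guidance combine used throughout these pipelines reduces to a single linear extrapolation from the unconditional prediction towards the text-conditioned one. A tiny numpy sketch with made-up values standing in for the real noise predictions:

import numpy as np

# Toy stand-ins for the unconditional and text-conditioned noise predictions.
noise_pred_uncond = np.array([0.10, -0.20, 0.05])
noise_pred_text = np.array([0.30, -0.50, 0.00])
guidance_scale = 7.5

# Classifier-free guidance: u + w * (t - u).
noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
print(noise_pred)  # approximately [1.6, -2.45, -0.325]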
- # Here we concatenate the unconditional and text embeddings into a single batch - # to avoid doing two forward passes - text_embeddings = np.concatenate([uncond_embeddings, text_embeddings]) - - # get the initial random noise unless the user supplied it - latents_shape = (batch_size, 4, height // 8, width // 8) - if latents is None: - latents = np.random.randn(*latents_shape).astype(np.float32) - elif latents.shape != latents_shape: - raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {latents_shape}") - - # set timesteps - accepts_offset = "offset" in set(inspect.signature(self.scheduler.set_timesteps).parameters.keys()) - extra_set_kwargs = {} - if accepts_offset: - extra_set_kwargs["offset"] = 1 - - self.scheduler.set_timesteps(num_inference_steps, **extra_set_kwargs) - - # if we use LMSDiscreteScheduler, let's make sure latents are mulitplied by sigmas - if isinstance(self.scheduler, LMSDiscreteScheduler): - latents = latents * self.scheduler.sigmas[0] - - # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature - # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers. - # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502 - # and should be between [0, 1] - accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys()) - extra_step_kwargs = {} - if accepts_eta: - extra_step_kwargs["eta"] = eta - - for i, t in enumerate(self.progress_bar(self.scheduler.timesteps)): - # expand the latents if we are doing classifier free guidance - latent_model_input = np.concatenate([latents] * 2) if do_classifier_free_guidance else latents - if isinstance(self.scheduler, LMSDiscreteScheduler): - sigma = self.scheduler.sigmas[i] - # the model input needs to be scaled to match the continuous ODE formulation in K-LMS - latent_model_input = latent_model_input / ((sigma**2 + 1) ** 0.5) - - # predict the noise residual - noise_pred = self.unet( - sample=latent_model_input, timestep=np.array([t]), encoder_hidden_states=text_embeddings - ) - noise_pred = noise_pred[0] - - # perform guidance - if do_classifier_free_guidance: - noise_pred_uncond, noise_pred_text = np.split(noise_pred, 2) - noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) - - # compute the previous noisy sample x_t -> x_t-1 - if isinstance(self.scheduler, LMSDiscreteScheduler): - latents = self.scheduler.step(noise_pred, i, latents, **extra_step_kwargs).prev_sample - else: - latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample - - # scale and decode the image latents with vae - latents = 1 / 0.18215 * latents - image = self.vae_decoder(latent_sample=latents)[0] - - image = np.clip(image / 2 + 0.5, 0, 1) - image = image.transpose((0, 2, 3, 1)) - - # run safety checker - safety_checker_input = self.feature_extractor(self.numpy_to_pil(image), return_tensors="np") - image, has_nsfw_concept = self.safety_checker(clip_input=safety_checker_input.pixel_values, images=image) - - if output_type == "pil": - image = self.numpy_to_pil(image) - - if not return_dict: - return (image, has_nsfw_concept) - - return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept) diff --git a/local_diffusers/pipelines/stable_diffusion/safety_checker.py b/local_diffusers/pipelines/stable_diffusion/safety_checker.py deleted file mode 100644 index 3ebc05c91..000000000 --- a/local_diffusers/pipelines/stable_diffusion/safety_checker.py +++ 
/dev/null @@ -1,106 +0,0 @@ -import numpy as np -import torch -import torch.nn as nn - -from transformers import CLIPConfig, CLIPVisionModel, PreTrainedModel - -from ...utils import logging - - -logger = logging.get_logger(__name__) - - -def cosine_distance(image_embeds, text_embeds): - normalized_image_embeds = nn.functional.normalize(image_embeds) - normalized_text_embeds = nn.functional.normalize(text_embeds) - return torch.mm(normalized_image_embeds, normalized_text_embeds.t()) - - -class StableDiffusionSafetyChecker(PreTrainedModel): - config_class = CLIPConfig - - def __init__(self, config: CLIPConfig): - super().__init__(config) - - self.vision_model = CLIPVisionModel(config.vision_config) - self.visual_projection = nn.Linear(config.vision_config.hidden_size, config.projection_dim, bias=False) - - self.concept_embeds = nn.Parameter(torch.ones(17, config.projection_dim), requires_grad=False) - self.special_care_embeds = nn.Parameter(torch.ones(3, config.projection_dim), requires_grad=False) - - self.register_buffer("concept_embeds_weights", torch.ones(17)) - self.register_buffer("special_care_embeds_weights", torch.ones(3)) - - @torch.no_grad() - def forward(self, clip_input, images): - pooled_output = self.vision_model(clip_input)[1] # pooled_output - image_embeds = self.visual_projection(pooled_output) - - special_cos_dist = cosine_distance(image_embeds, self.special_care_embeds).cpu().numpy() - cos_dist = cosine_distance(image_embeds, self.concept_embeds).cpu().numpy() - - result = [] - batch_size = image_embeds.shape[0] - for i in range(batch_size): - result_img = {"special_scores": {}, "special_care": [], "concept_scores": {}, "bad_concepts": []} - - # increase this value to create a stronger `nfsw` filter - # at the cost of increasing the possibility of filtering benign images - adjustment = 0.0 - - for concet_idx in range(len(special_cos_dist[0])): - concept_cos = special_cos_dist[i][concet_idx] - concept_threshold = self.special_care_embeds_weights[concet_idx].item() - result_img["special_scores"][concet_idx] = round(concept_cos - concept_threshold + adjustment, 3) - if result_img["special_scores"][concet_idx] > 0: - result_img["special_care"].append({concet_idx, result_img["special_scores"][concet_idx]}) - adjustment = 0.01 - - for concet_idx in range(len(cos_dist[0])): - concept_cos = cos_dist[i][concet_idx] - concept_threshold = self.concept_embeds_weights[concet_idx].item() - result_img["concept_scores"][concet_idx] = round(concept_cos - concept_threshold + adjustment, 3) - if result_img["concept_scores"][concet_idx] > 0: - result_img["bad_concepts"].append(concet_idx) - - result.append(result_img) - - has_nsfw_concepts = [len(res["bad_concepts"]) > 0 for res in result] - - #for idx, has_nsfw_concept in enumerate(has_nsfw_concepts): - # if has_nsfw_concept: - # images[idx] = np.zeros(images[idx].shape) # black image -# -# if any(has_nsfw_concepts): -# logger.warning( -# "Potential NSFW content was detected in one or more images. A black image will be returned instead." -# " Try again with a different prompt and/or seed." 
-# ) - - return images, has_nsfw_concepts - - @torch.inference_mode() - def forward_onnx(self, clip_input: torch.FloatTensor, images: torch.FloatTensor): - pooled_output = self.vision_model(clip_input)[1] # pooled_output - image_embeds = self.visual_projection(pooled_output) - - special_cos_dist = cosine_distance(image_embeds, self.special_care_embeds) - cos_dist = cosine_distance(image_embeds, self.concept_embeds) - - # increase this value to create a stronger `nsfw` filter - # at the cost of increasing the possibility of filtering benign images - adjustment = 0.0 - - special_scores = special_cos_dist - self.special_care_embeds_weights + adjustment - # special_scores = special_scores.round(decimals=3) - special_care = torch.any(special_scores > 0, dim=1) - special_adjustment = special_care * 0.01 - special_adjustment = special_adjustment.unsqueeze(1).expand(-1, cos_dist.shape[1]) - - concept_scores = (cos_dist - self.concept_embeds_weights) + special_adjustment - # concept_scores = concept_scores.round(decimals=3) - has_nsfw_concepts = torch.any(concept_scores > 0, dim=1) - - images[has_nsfw_concepts] = 0.0 # black image - - return images, has_nsfw_concepts diff --git a/local_diffusers/pipelines/stochastic_karras_ve/__init__.py b/local_diffusers/pipelines/stochastic_karras_ve/__init__.py deleted file mode 100644 index db2582043..000000000 --- a/local_diffusers/pipelines/stochastic_karras_ve/__init__.py +++ /dev/null @@ -1,2 +0,0 @@ -# flake8: noqa -from .pipeline_stochastic_karras_ve import KarrasVePipeline diff --git a/local_diffusers/pipelines/stochastic_karras_ve/pipeline_stochastic_karras_ve.py b/local_diffusers/pipelines/stochastic_karras_ve/pipeline_stochastic_karras_ve.py deleted file mode 100644 index 15266544d..000000000 --- a/local_diffusers/pipelines/stochastic_karras_ve/pipeline_stochastic_karras_ve.py +++ /dev/null @@ -1,129 +0,0 @@ -#!/usr/bin/env python3 -import warnings -from typing import Optional, Tuple, Union - -import torch - -from ...models import UNet2DModel -from ...pipeline_utils import DiffusionPipeline, ImagePipelineOutput -from ...schedulers import KarrasVeScheduler - - -class KarrasVePipeline(DiffusionPipeline): - r""" - Stochastic sampling from Karras et al. [1] tailored to the Variance-Expanding (VE) models [2]. Use Algorithm 2 and - the VE column of Table 1 from [1] for reference. - - [1] Karras, Tero, et al. "Elucidating the Design Space of Diffusion-Based Generative Models." - https://arxiv.org/abs/2206.00364 [2] Song, Yang, et al. "Score-based generative modeling through stochastic - differential equations." https://arxiv.org/abs/2011.13456 - - Parameters: - unet ([`UNet2DModel`]): U-Net architecture to denoise the encoded image. - scheduler ([`KarrasVeScheduler`]): - Scheduler for the diffusion process to be used in combination with `unet` to denoise the encoded image. - """ - - # add type hints for linting - unet: UNet2DModel - scheduler: KarrasVeScheduler - - def __init__(self, unet: UNet2DModel, scheduler: KarrasVeScheduler): - super().__init__() - scheduler = scheduler.set_format("pt") - self.register_modules(unet=unet, scheduler=scheduler) - - @torch.no_grad() - def __call__( - self, - batch_size: int = 1, - num_inference_steps: int = 50, - generator: Optional[torch.Generator] = None, - output_type: Optional[str] = "pil", - return_dict: bool = True, - **kwargs, - ) -> Union[Tuple, ImagePipelineOutput]: - r""" - Args: - batch_size (`int`, *optional*, defaults to 1): - The number of images to generate. 
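Illustrative aside on the safety checker above (not part of the original file): each image embedding is compared against fixed concept embeddings via cosine similarity, and the image is flagged when any score exceeds its per-concept threshold, with a small extra adjustment once a "special care" concept fires. A toy numpy sketch with invented similarities and thresholds:

import numpy as np

# Toy stand-ins: one image against 3 "special care" concepts and 5 regular concepts.
special_cos = np.array([0.18, 0.02, 0.05])
special_thresh = np.array([0.20, 0.20, 0.20])
concept_cos = np.array([0.10, 0.25, 0.05, 0.30, 0.12])
concept_thresh = np.array([0.28, 0.24, 0.30, 0.26, 0.20])

adjustment = 0.0
special_scores = special_cos - special_thresh + adjustment
if (special_scores > 0).any():
    adjustment = 0.01            # lowers the bar for the regular concepts

concept_scores = concept_cos - concept_thresh + adjustment
has_nsfw = bool((concept_scores > 0).any())
print(special_scores, concept_scores, has_nsfw)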
- generator (`torch.Generator`, *optional*): - A [torch generator](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make generation - deterministic. - num_inference_steps (`int`, *optional*, defaults to 50): - The number of denoising steps. More denoising steps usually lead to a higher quality image at the - expense of slower inference. - output_type (`str`, *optional*, defaults to `"pil"`): - The output format of the generate image. Choose between - [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `nd.array`. - return_dict (`bool`, *optional*, defaults to `True`): - Whether or not to return a [`~pipeline_utils.ImagePipelineOutput`] instead of a plain tuple. - - Returns: - [`~pipeline_utils.ImagePipelineOutput`] or `tuple`: [`~pipelines.utils.ImagePipelineOutput`] if - `return_dict` is True, otherwise a `tuple. When returning a tuple, the first element is a list with the - generated images. - """ - if "torch_device" in kwargs: - device = kwargs.pop("torch_device") - warnings.warn( - "`torch_device` is deprecated as an input argument to `__call__` and will be removed in v0.3.0." - " Consider using `pipe.to(torch_device)` instead." - ) - - # Set device as before (to be removed in 0.3.0) - if device is None: - device = "cuda" if torch.cuda.is_available() else "cpu" - self.to(device) - - img_size = self.unet.config.sample_size - shape = (batch_size, 3, img_size, img_size) - - model = self.unet - - # sample x_0 ~ N(0, sigma_0^2 * I) - sample = torch.randn(*shape) * self.scheduler.config.sigma_max - sample = sample.to(self.device) - - self.scheduler.set_timesteps(num_inference_steps) - - for t in self.progress_bar(self.scheduler.timesteps): - # here sigma_t == t_i from the paper - sigma = self.scheduler.schedule[t] - sigma_prev = self.scheduler.schedule[t - 1] if t > 0 else 0 - - # 1. Select temporarily increased noise level sigma_hat - # 2. Add new noise to move from sample_i to sample_hat - sample_hat, sigma_hat = self.scheduler.add_noise_to_input(sample, sigma, generator=generator) - - # 3. Predict the noise residual given the noise magnitude `sigma_hat` - # The model inputs and output are adjusted by following eq. (213) in [1]. - model_output = (sigma_hat / 2) * model((sample_hat + 1) / 2, sigma_hat / 2).sample - - # 4. Evaluate dx/dt at sigma_hat - # 5. Take Euler step from sigma to sigma_prev - step_output = self.scheduler.step(model_output, sigma_hat, sigma_prev, sample_hat) - - if sigma_prev != 0: - # 6. Apply 2nd order correction - # The model inputs and output are adjusted by following eq. (213) in [1]. - model_output = (sigma_prev / 2) * model((step_output.prev_sample + 1) / 2, sigma_prev / 2).sample - step_output = self.scheduler.step_correct( - model_output, - sigma_hat, - sigma_prev, - sample_hat, - step_output.prev_sample, - step_output["derivative"], - ) - sample = step_output.prev_sample - - sample = (sample / 2 + 0.5).clamp(0, 1) - image = sample.cpu().permute(0, 2, 3, 1).numpy() - if output_type == "pil": - image = self.numpy_to_pil(sample) - - if not return_dict: - return (image,) - - return ImagePipelineOutput(images=image) diff --git a/local_diffusers/schedulers/__init__.py b/local_diffusers/schedulers/__init__.py deleted file mode 100644 index 20c25f351..000000000 --- a/local_diffusers/schedulers/__init__.py +++ /dev/null @@ -1,28 +0,0 @@ -# Copyright 2022 The HuggingFace Team. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. 
-# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -from ..utils import is_scipy_available -from .scheduling_ddim import DDIMScheduler -from .scheduling_ddpm import DDPMScheduler -from .scheduling_karras_ve import KarrasVeScheduler -from .scheduling_pndm import PNDMScheduler -from .scheduling_sde_ve import ScoreSdeVeScheduler -from .scheduling_sde_vp import ScoreSdeVpScheduler -from .scheduling_utils import SchedulerMixin - - -if is_scipy_available(): - from .scheduling_lms_discrete import LMSDiscreteScheduler -else: - from ..utils.dummy_scipy_objects import * # noqa F403 diff --git a/local_diffusers/schedulers/scheduling_ddim.py b/local_diffusers/schedulers/scheduling_ddim.py deleted file mode 100644 index 894d63bf2..000000000 --- a/local_diffusers/schedulers/scheduling_ddim.py +++ /dev/null @@ -1,261 +0,0 @@ -# Copyright 2022 Stanford University Team and The HuggingFace Team. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -# DISCLAIMER: This code is strongly influenced by https://github.com/pesser/pytorch_diffusion -# and https://github.com/hojonathanho/diffusion - -import math -from typing import Optional, Tuple, Union - -import numpy as np -import torch - -from ..configuration_utils import ConfigMixin, register_to_config -from .scheduling_utils import SchedulerMixin, SchedulerOutput - - -def betas_for_alpha_bar(num_diffusion_timesteps, max_beta=0.999): - """ - Create a beta schedule that discretizes the given alpha_t_bar function, which defines the cumulative product of - (1-beta) over time from t = [0,1]. - - Contains a function alpha_bar that takes an argument t and transforms it to the cumulative product of (1-beta) up - to that part of the diffusion process. - - - Args: - num_diffusion_timesteps (`int`): the number of betas to produce. - max_beta (`float`): the maximum beta to use; use values lower than 1 to - prevent singularities. - - Returns: - betas (`np.ndarray`): the betas used by the scheduler to step the model outputs - """ - - def alpha_bar(time_step): - return math.cos((time_step + 0.008) / 1.008 * math.pi / 2) ** 2 - - betas = [] - for i in range(num_diffusion_timesteps): - t1 = i / num_diffusion_timesteps - t2 = (i + 1) / num_diffusion_timesteps - betas.append(min(1 - alpha_bar(t2) / alpha_bar(t1), max_beta)) - return np.array(betas, dtype=np.float32) - - -class DDIMScheduler(SchedulerMixin, ConfigMixin): - """ - Denoising diffusion implicit models is a scheduler that extends the denoising procedure introduced in denoising - diffusion probabilistic models (DDPMs) with non-Markovian guidance. 
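Illustrative aside: `betas_for_alpha_bar` above builds the "squaredcos_cap_v2" (Glide cosine) schedule by differencing a cosine alpha-bar curve. A small self-contained sketch of the same computation for a tiny schedule:

import math
import numpy as np

def alpha_bar(t):
    # Cumulative product of (1 - beta) as a smooth function of normalized time.
    return math.cos((t + 0.008) / 1.008 * math.pi / 2) ** 2

num_steps = 8  # tiny schedule, for illustration only
betas = np.array(
    [min(1 - alpha_bar((i + 1) / num_steps) / alpha_bar(i / num_steps), 0.999) for i in range(num_steps)],
    dtype=np.float32,
)

alphas_cumprod = np.cumprod(1.0 - betas)
print(betas.round(4))
print(alphas_cumprod.round(4))  # decays towards 0 as the step index grows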
- - [`~ConfigMixin`] takes care of storing all config attributes that are passed in the scheduler's `__init__` - function, such as `num_train_timesteps`. They can be accessed via `scheduler.config.num_train_timesteps`. - [`~ConfigMixin`] also provides general loading and saving functionality via the [`~ConfigMixin.save_config`] and - [`~ConfigMixin.from_config`] functios. - - For more details, see the original paper: https://arxiv.org/abs/2010.02502 - - Args: - num_train_timesteps (`int`): number of diffusion steps used to train the model. - beta_start (`float`): the starting `beta` value of inference. - beta_end (`float`): the final `beta` value. - beta_schedule (`str`): - the beta schedule, a mapping from a beta range to a sequence of betas for stepping the model. Choose from - `linear`, `scaled_linear`, or `squaredcos_cap_v2`. - trained_betas (`np.ndarray`, optional): TODO - timestep_values (`np.ndarray`, optional): TODO - clip_sample (`bool`, default `True`): - option to clip predicted sample between -1 and 1 for numerical stability. - set_alpha_to_one (`bool`, default `True`): - if alpha for final step is 1 or the final alpha of the "non-previous" one. - tensor_format (`str`): whether the scheduler expects pytorch or numpy arrays. - - """ - - @register_to_config - def __init__( - self, - num_train_timesteps: int = 1000, - beta_start: float = 0.0001, - beta_end: float = 0.02, - beta_schedule: str = "linear", - trained_betas: Optional[np.ndarray] = None, - timestep_values: Optional[np.ndarray] = None, - clip_sample: bool = True, - set_alpha_to_one: bool = True, - tensor_format: str = "pt", - ): - if trained_betas is not None: - self.betas = np.asarray(trained_betas) - if beta_schedule == "linear": - self.betas = np.linspace(beta_start, beta_end, num_train_timesteps, dtype=np.float32) - elif beta_schedule == "scaled_linear": - # this schedule is very specific to the latent diffusion model. - self.betas = np.linspace(beta_start**0.5, beta_end**0.5, num_train_timesteps, dtype=np.float32) ** 2 - elif beta_schedule == "squaredcos_cap_v2": - # Glide cosine schedule - self.betas = betas_for_alpha_bar(num_train_timesteps) - else: - raise NotImplementedError(f"{beta_schedule} does is not implemented for {self.__class__}") - - self.alphas = 1.0 - self.betas - self.alphas_cumprod = np.cumprod(self.alphas, axis=0) - - # At every step in ddim, we are looking into the previous alphas_cumprod - # For the final step, there is no previous alphas_cumprod because we are already at 0 - # `set_alpha_to_one` decides whether we set this paratemer simply to one or - # whether we use the final alpha of the "non-previous" one. - self.final_alpha_cumprod = np.array(1.0) if set_alpha_to_one else self.alphas_cumprod[0] - - # setable values - self.num_inference_steps = None - self.timesteps = np.arange(0, num_train_timesteps)[::-1].copy() - - self.tensor_format = tensor_format - self.set_format(tensor_format=tensor_format) - - def _get_variance(self, timestep, prev_timestep): - alpha_prod_t = self.alphas_cumprod[timestep] - alpha_prod_t_prev = self.alphas_cumprod[prev_timestep] if prev_timestep >= 0 else self.final_alpha_cumprod - beta_prod_t = 1 - alpha_prod_t - beta_prod_t_prev = 1 - alpha_prod_t_prev - - variance = (beta_prod_t_prev / beta_prod_t) * (1 - alpha_prod_t / alpha_prod_t_prev) - - return variance - - def set_timesteps(self, num_inference_steps: int, offset: int = 0): - """ - Sets the discrete timesteps used for the diffusion chain. Supporting function to be run before inference. 
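Illustrative aside: `_get_variance` above evaluates the DDIM variance sigma_t^2 = (1 - alpha_bar_{t-1}) / (1 - alpha_bar_t) * (1 - alpha_bar_t / alpha_bar_{t-1}). A numeric sketch under an assumed linear beta schedule matching the defaults above:

import numpy as np

betas = np.linspace(0.0001, 0.02, 1000, dtype=np.float32)  # assumed linear schedule
alphas_cumprod = np.cumprod(1.0 - betas)

def ddim_variance(t, t_prev):
    a_t = alphas_cumprod[t]
    # t_prev < 0 corresponds to final_alpha_cumprod = 1.0 (set_alpha_to_one=True).
    a_prev = alphas_cumprod[t_prev] if t_prev >= 0 else 1.0
    return ((1 - a_prev) / (1 - a_t)) * (1 - a_t / a_prev)

# With 50 inference steps, consecutive timesteps are 20 apart.
print(ddim_variance(980, 960))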
- - Args: - num_inference_steps (`int`): - the number of diffusion steps used when generating samples with a pre-trained model. - offset (`int`): TODO - """ - self.num_inference_steps = num_inference_steps - self.timesteps = np.arange( - 0, self.config.num_train_timesteps, self.config.num_train_timesteps // self.num_inference_steps - )[::-1].copy() - self.timesteps += offset - self.set_format(tensor_format=self.tensor_format) - - def step( - self, - model_output: Union[torch.FloatTensor, np.ndarray], - timestep: int, - sample: Union[torch.FloatTensor, np.ndarray], - eta: float = 0.0, - use_clipped_model_output: bool = False, - generator=None, - return_dict: bool = True, - ) -> Union[SchedulerOutput, Tuple]: - """ - Predict the sample at the previous timestep by reversing the SDE. Core function to propagate the diffusion - process from the learned model outputs (most often the predicted noise). - - Args: - model_output (`torch.FloatTensor` or `np.ndarray`): direct output from learned diffusion model. - timestep (`int`): current discrete timestep in the diffusion chain. - sample (`torch.FloatTensor` or `np.ndarray`): - current instance of sample being created by diffusion process. - eta (`float`): weight of noise for added noise in diffusion step. - use_clipped_model_output (`bool`): TODO - generator: random number generator. - return_dict (`bool`): option for returning tuple rather than SchedulerOutput class - - Returns: - [`~schedulers.scheduling_utils.SchedulerOutput`] or `tuple`: - [`~schedulers.scheduling_utils.SchedulerOutput`] if `return_dict` is True, otherwise a `tuple`. When - returning a tuple, the first element is the sample tensor. - - """ - if self.num_inference_steps is None: - raise ValueError( - "Number of inference steps is 'None', you need to run 'set_timesteps' after creating the scheduler" - ) - - # See formulas (12) and (16) of DDIM paper https://arxiv.org/pdf/2010.02502.pdf - # Ideally, read DDIM paper in-detail understanding - - # Notation ( -> - # - pred_noise_t -> e_theta(x_t, t) - # - pred_original_sample -> f_theta(x_t, t) or x_0 - # - std_dev_t -> sigma_t - # - eta -> η - # - pred_sample_direction -> "direction pointingc to x_t" - # - pred_prev_sample -> "x_t-1" - - # 1. get previous step value (=t-1) - prev_timestep = timestep - self.config.num_train_timesteps // self.num_inference_steps - - # 2. compute alphas, betas - alpha_prod_t = self.alphas_cumprod[timestep] - alpha_prod_t_prev = self.alphas_cumprod[prev_timestep] if prev_timestep >= 0 else self.final_alpha_cumprod - beta_prod_t = 1 - alpha_prod_t - - # 3. compute predicted original sample from predicted noise also called - # "predicted x_0" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf - pred_original_sample = (sample - beta_prod_t ** (0.5) * model_output) / alpha_prod_t ** (0.5) - - # 4. Clip "predicted x_0" - if self.config.clip_sample: - pred_original_sample = self.clip(pred_original_sample, -1, 1) - - # 5. compute variance: "sigma_t(η)" -> see formula (16) - # σ_t = sqrt((1 − α_t−1)/(1 − α_t)) * sqrt(1 − α_t/α_t−1) - variance = self._get_variance(timestep, prev_timestep) - std_dev_t = eta * variance ** (0.5) - - if use_clipped_model_output: - # the model_output is always re-derived from the clipped x_0 in Glide - model_output = (sample - alpha_prod_t ** (0.5) * pred_original_sample) / beta_prod_t ** (0.5) - - # 6. 
compute "direction pointing to x_t" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf - pred_sample_direction = (1 - alpha_prod_t_prev - std_dev_t**2) ** (0.5) * model_output - - # 7. compute x_t without "random noise" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf - prev_sample = alpha_prod_t_prev ** (0.5) * pred_original_sample + pred_sample_direction - - if eta > 0: - device = model_output.device if torch.is_tensor(model_output) else "cpu" - noise = torch.randn(model_output.shape, generator=generator).to(device) - variance = self._get_variance(timestep, prev_timestep) ** (0.5) * eta * noise - - if not torch.is_tensor(model_output): - variance = variance.numpy() - - prev_sample = prev_sample + variance - - if not return_dict: - return (prev_sample,) - - return SchedulerOutput(prev_sample=prev_sample) - - def add_noise( - self, - original_samples: Union[torch.FloatTensor, np.ndarray], - noise: Union[torch.FloatTensor, np.ndarray], - timesteps: Union[torch.IntTensor, np.ndarray], - ) -> Union[torch.FloatTensor, np.ndarray]: - sqrt_alpha_prod = self.alphas_cumprod[timesteps] ** 0.5 - sqrt_alpha_prod = self.match_shape(sqrt_alpha_prod, original_samples) - sqrt_one_minus_alpha_prod = (1 - self.alphas_cumprod[timesteps]) ** 0.5 - sqrt_one_minus_alpha_prod = self.match_shape(sqrt_one_minus_alpha_prod, original_samples) - - noisy_samples = sqrt_alpha_prod * original_samples + sqrt_one_minus_alpha_prod * noise - return noisy_samples - - def __len__(self): - return self.config.num_train_timesteps diff --git a/local_diffusers/schedulers/scheduling_ddpm.py b/local_diffusers/schedulers/scheduling_ddpm.py deleted file mode 100644 index 4fbfb9038..000000000 --- a/local_diffusers/schedulers/scheduling_ddpm.py +++ /dev/null @@ -1,264 +0,0 @@ -# Copyright 2022 UC Berkely Team and The HuggingFace Team. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -# DISCLAIMER: This file is strongly influenced by https://github.com/ermongroup/ddim - -import math -from typing import Optional, Tuple, Union - -import numpy as np -import torch - -from ..configuration_utils import ConfigMixin, register_to_config -from .scheduling_utils import SchedulerMixin, SchedulerOutput - - -def betas_for_alpha_bar(num_diffusion_timesteps, max_beta=0.999): - """ - Create a beta schedule that discretizes the given alpha_t_bar function, which defines the cumulative product of - (1-beta) over time from t = [0,1]. - - Contains a function alpha_bar that takes an argument t and transforms it to the cumulative product of (1-beta) up - to that part of the diffusion process. - - - Args: - num_diffusion_timesteps (`int`): the number of betas to produce. - max_beta (`float`): the maximum beta to use; use values lower than 1 to - prevent singularities. 
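Illustrative aside: `add_noise` above is the closed-form forward process q(x_t | x_0), and the DDPM scheduler below reuses the identical formula. A minimal numpy sketch with an assumed linear schedule and random stand-ins for the latents:

import numpy as np

betas = np.linspace(0.0001, 0.02, 1000, dtype=np.float32)  # assumed linear schedule
alphas_cumprod = np.cumprod(1.0 - betas)

rng = np.random.default_rng(0)
x0 = rng.standard_normal((4, 64, 64)).astype(np.float32)   # stand-in for clean latents
noise = rng.standard_normal(x0.shape).astype(np.float32)

t = 500
sqrt_ab = alphas_cumprod[t] ** 0.5
sqrt_one_minus_ab = (1 - alphas_cumprod[t]) ** 0.5

# x_t = sqrt(alpha_bar_t) * x_0 + sqrt(1 - alpha_bar_t) * eps
x_t = sqrt_ab * x0 + sqrt_one_minus_ab * noise
print(sqrt_ab, sqrt_one_minus_ab)  # relative weight of signal vs. noise at t=500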
- - Returns: - betas (`np.ndarray`): the betas used by the scheduler to step the model outputs - """ - - def alpha_bar(time_step): - return math.cos((time_step + 0.008) / 1.008 * math.pi / 2) ** 2 - - betas = [] - for i in range(num_diffusion_timesteps): - t1 = i / num_diffusion_timesteps - t2 = (i + 1) / num_diffusion_timesteps - betas.append(min(1 - alpha_bar(t2) / alpha_bar(t1), max_beta)) - return np.array(betas, dtype=np.float32) - - -class DDPMScheduler(SchedulerMixin, ConfigMixin): - """ - Denoising diffusion probabilistic models (DDPMs) explores the connections between denoising score matching and - Langevin dynamics sampling. - - [`~ConfigMixin`] takes care of storing all config attributes that are passed in the scheduler's `__init__` - function, such as `num_train_timesteps`. They can be accessed via `scheduler.config.num_train_timesteps`. - [`~ConfigMixin`] also provides general loading and saving functionality via the [`~ConfigMixin.save_config`] and - [`~ConfigMixin.from_config`] functios. - - For more details, see the original paper: https://arxiv.org/abs/2006.11239 - - Args: - num_train_timesteps (`int`): number of diffusion steps used to train the model. - beta_start (`float`): the starting `beta` value of inference. - beta_end (`float`): the final `beta` value. - beta_schedule (`str`): - the beta schedule, a mapping from a beta range to a sequence of betas for stepping the model. Choose from - `linear`, `scaled_linear`, or `squaredcos_cap_v2`. - trained_betas (`np.ndarray`, optional): TODO - variance_type (`str`): - options to clip the variance used when adding noise to the denoised sample. Choose from `fixed_small`, - `fixed_small_log`, `fixed_large`, `fixed_large_log`, `learned` or `learned_range`. - clip_sample (`bool`, default `True`): - option to clip predicted sample between -1 and 1 for numerical stability. - tensor_format (`str`): whether the scheduler expects pytorch or numpy arrays. - - """ - - @register_to_config - def __init__( - self, - num_train_timesteps: int = 1000, - beta_start: float = 0.0001, - beta_end: float = 0.02, - beta_schedule: str = "linear", - trained_betas: Optional[np.ndarray] = None, - variance_type: str = "fixed_small", - clip_sample: bool = True, - tensor_format: str = "pt", - ): - - if trained_betas is not None: - self.betas = np.asarray(trained_betas) - elif beta_schedule == "linear": - self.betas = np.linspace(beta_start, beta_end, num_train_timesteps, dtype=np.float32) - elif beta_schedule == "scaled_linear": - # this schedule is very specific to the latent diffusion model. - self.betas = np.linspace(beta_start**0.5, beta_end**0.5, num_train_timesteps, dtype=np.float32) ** 2 - elif beta_schedule == "squaredcos_cap_v2": - # Glide cosine schedule - self.betas = betas_for_alpha_bar(num_train_timesteps) - else: - raise NotImplementedError(f"{beta_schedule} does is not implemented for {self.__class__}") - - self.alphas = 1.0 - self.betas - self.alphas_cumprod = np.cumprod(self.alphas, axis=0) - self.one = np.array(1.0) - - # setable values - self.num_inference_steps = None - self.timesteps = np.arange(0, num_train_timesteps)[::-1].copy() - - self.tensor_format = tensor_format - self.set_format(tensor_format=tensor_format) - - self.variance_type = variance_type - - def set_timesteps(self, num_inference_steps: int): - """ - Sets the discrete timesteps used for the diffusion chain. Supporting function to be run before inference. 
- - Args: - num_inference_steps (`int`): - the number of diffusion steps used when generating samples with a pre-trained model. - """ - num_inference_steps = min(self.config.num_train_timesteps, num_inference_steps) - self.num_inference_steps = num_inference_steps - self.timesteps = np.arange( - 0, self.config.num_train_timesteps, self.config.num_train_timesteps // self.num_inference_steps - )[::-1].copy() - self.set_format(tensor_format=self.tensor_format) - - def _get_variance(self, t, predicted_variance=None, variance_type=None): - alpha_prod_t = self.alphas_cumprod[t] - alpha_prod_t_prev = self.alphas_cumprod[t - 1] if t > 0 else self.one - - # For t > 0, compute predicted variance βt (see formula (6) and (7) from https://arxiv.org/pdf/2006.11239.pdf) - # and sample from it to get previous sample - # x_{t-1} ~ N(pred_prev_sample, variance) == add variance to pred_sample - variance = (1 - alpha_prod_t_prev) / (1 - alpha_prod_t) * self.betas[t] - - if variance_type is None: - variance_type = self.config.variance_type - - # hacks - were probs added for training stability - if variance_type == "fixed_small": - variance = self.clip(variance, min_value=1e-20) - # for rl-diffuser https://arxiv.org/abs/2205.09991 - elif variance_type == "fixed_small_log": - variance = self.log(self.clip(variance, min_value=1e-20)) - elif variance_type == "fixed_large": - variance = self.betas[t] - elif variance_type == "fixed_large_log": - # Glide max_log - variance = self.log(self.betas[t]) - elif variance_type == "learned": - return predicted_variance - elif variance_type == "learned_range": - min_log = variance - max_log = self.betas[t] - frac = (predicted_variance + 1) / 2 - variance = frac * max_log + (1 - frac) * min_log - - return variance - - def step( - self, - model_output: Union[torch.FloatTensor, np.ndarray], - timestep: int, - sample: Union[torch.FloatTensor, np.ndarray], - predict_epsilon=True, - generator=None, - return_dict: bool = True, - ) -> Union[SchedulerOutput, Tuple]: - """ - Predict the sample at the previous timestep by reversing the SDE. Core function to propagate the diffusion - process from the learned model outputs (most often the predicted noise). - - Args: - model_output (`torch.FloatTensor` or `np.ndarray`): direct output from learned diffusion model. - timestep (`int`): current discrete timestep in the diffusion chain. - sample (`torch.FloatTensor` or `np.ndarray`): - current instance of sample being created by diffusion process. - eta (`float`): weight of noise for added noise in diffusion step. - predict_epsilon (`bool`): - optional flag to use when model predicts the samples directly instead of the noise, epsilon. - generator: random number generator. - return_dict (`bool`): option for returning tuple rather than SchedulerOutput class - - Returns: - [`~schedulers.scheduling_utils.SchedulerOutput`] or `tuple`: - [`~schedulers.scheduling_utils.SchedulerOutput`] if `return_dict` is True, otherwise a `tuple`. When - returning a tuple, the first element is the sample tensor. - - """ - t = timestep - - if model_output.shape[1] == sample.shape[1] * 2 and self.variance_type in ["learned", "learned_range"]: - model_output, predicted_variance = torch.split(model_output, sample.shape[1], dim=1) - else: - predicted_variance = None - - # 1. compute alphas, betas - alpha_prod_t = self.alphas_cumprod[t] - alpha_prod_t_prev = self.alphas_cumprod[t - 1] if t > 0 else self.one - beta_prod_t = 1 - alpha_prod_t - beta_prod_t_prev = 1 - alpha_prod_t_prev - - # 2. 
compute predicted original sample from predicted noise also called - # "predicted x_0" of formula (15) from https://arxiv.org/pdf/2006.11239.pdf - if predict_epsilon: - pred_original_sample = (sample - beta_prod_t ** (0.5) * model_output) / alpha_prod_t ** (0.5) - else: - pred_original_sample = model_output - - # 3. Clip "predicted x_0" - if self.config.clip_sample: - pred_original_sample = self.clip(pred_original_sample, -1, 1) - - # 4. Compute coefficients for pred_original_sample x_0 and current sample x_t - # See formula (7) from https://arxiv.org/pdf/2006.11239.pdf - pred_original_sample_coeff = (alpha_prod_t_prev ** (0.5) * self.betas[t]) / beta_prod_t - current_sample_coeff = self.alphas[t] ** (0.5) * beta_prod_t_prev / beta_prod_t - - # 5. Compute predicted previous sample µ_t - # See formula (7) from https://arxiv.org/pdf/2006.11239.pdf - pred_prev_sample = pred_original_sample_coeff * pred_original_sample + current_sample_coeff * sample - - # 6. Add noise - variance = 0 - if t > 0: - noise = self.randn_like(model_output, generator=generator) - variance = (self._get_variance(t, predicted_variance=predicted_variance) ** 0.5) * noise - - pred_prev_sample = pred_prev_sample + variance - - if not return_dict: - return (pred_prev_sample,) - - return SchedulerOutput(prev_sample=pred_prev_sample) - - def add_noise( - self, - original_samples: Union[torch.FloatTensor, np.ndarray], - noise: Union[torch.FloatTensor, np.ndarray], - timesteps: Union[torch.IntTensor, np.ndarray], - ) -> Union[torch.FloatTensor, np.ndarray]: - - sqrt_alpha_prod = self.alphas_cumprod[timesteps] ** 0.5 - sqrt_alpha_prod = self.match_shape(sqrt_alpha_prod, original_samples) - sqrt_one_minus_alpha_prod = (1 - self.alphas_cumprod[timesteps]) ** 0.5 - sqrt_one_minus_alpha_prod = self.match_shape(sqrt_one_minus_alpha_prod, original_samples) - - noisy_samples = sqrt_alpha_prod * original_samples + sqrt_one_minus_alpha_prod * noise - return noisy_samples - - def __len__(self): - return self.config.num_train_timesteps diff --git a/local_diffusers/schedulers/scheduling_karras_ve.py b/local_diffusers/schedulers/scheduling_karras_ve.py deleted file mode 100644 index 3a2370cfc..000000000 --- a/local_diffusers/schedulers/scheduling_karras_ve.py +++ /dev/null @@ -1,208 +0,0 @@ -# Copyright 2022 NVIDIA and The HuggingFace Team. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - - -from dataclasses import dataclass -from typing import Optional, Tuple, Union - -import numpy as np -import torch - -from ..configuration_utils import ConfigMixin, register_to_config -from ..utils import BaseOutput -from .scheduling_utils import SchedulerMixin - - -@dataclass -class KarrasVeOutput(BaseOutput): - """ - Output class for the scheduler's step function output. - - Args: - prev_sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` for images): - Computed sample (x_{t-1}) of previous timestep. `prev_sample` should be used as next model input in the - denoising loop. 
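Illustrative aside: steps 4 and 5 of the DDPM `step` above assemble the posterior mean of Eq. (7) from two coefficients, one on the predicted x_0 and one on the current x_t. A numeric sketch under the same assumed linear schedule as before:

import numpy as np

betas = np.linspace(0.0001, 0.02, 1000, dtype=np.float32)  # assumed defaults
alphas = 1.0 - betas
alphas_cumprod = np.cumprod(alphas)

t = 500
a_t, a_prev = alphas_cumprod[t], alphas_cumprod[t - 1]
beta_prod_t, beta_prod_prev = 1 - a_t, 1 - a_prev

# Coefficients of x_0_hat and x_t in the posterior mean (Eq. (7), DDPM paper):
# mu_t = coef_x0 * x_0_hat + coef_xt * x_t
coef_x0 = (a_prev ** 0.5 * betas[t]) / beta_prod_t
coef_xt = (alphas[t] ** 0.5) * beta_prod_prev / beta_prod_t
print(coef_x0, coef_xt)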
- derivative (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` for images): - Derivate of predicted original image sample (x_0). - """ - - prev_sample: torch.FloatTensor - derivative: torch.FloatTensor - - -class KarrasVeScheduler(SchedulerMixin, ConfigMixin): - """ - Stochastic sampling from Karras et al. [1] tailored to the Variance-Expanding (VE) models [2]. Use Algorithm 2 and - the VE column of Table 1 from [1] for reference. - - [1] Karras, Tero, et al. "Elucidating the Design Space of Diffusion-Based Generative Models." - https://arxiv.org/abs/2206.00364 [2] Song, Yang, et al. "Score-based generative modeling through stochastic - differential equations." https://arxiv.org/abs/2011.13456 - - [`~ConfigMixin`] takes care of storing all config attributes that are passed in the scheduler's `__init__` - function, such as `num_train_timesteps`. They can be accessed via `scheduler.config.num_train_timesteps`. - [`~ConfigMixin`] also provides general loading and saving functionality via the [`~ConfigMixin.save_config`] and - [`~ConfigMixin.from_config`] functios. - - For more details on the parameters, see the original paper's Appendix E.: "Elucidating the Design Space of - Diffusion-Based Generative Models." https://arxiv.org/abs/2206.00364. The grid search values used to find the - optimal {s_noise, s_churn, s_min, s_max} for a specific model are described in Table 5 of the paper. - - Args: - sigma_min (`float`): minimum noise magnitude - sigma_max (`float`): maximum noise magnitude - s_noise (`float`): the amount of additional noise to counteract loss of detail during sampling. - A reasonable range is [1.000, 1.011]. - s_churn (`float`): the parameter controlling the overall amount of stochasticity. - A reasonable range is [0, 100]. - s_min (`float`): the start value of the sigma range where we add noise (enable stochasticity). - A reasonable range is [0, 10]. - s_max (`float`): the end value of the sigma range where we add noise. - A reasonable range is [0.2, 80]. - tensor_format (`str`): whether the scheduler expects pytorch or numpy arrays. - - """ - - @register_to_config - def __init__( - self, - sigma_min: float = 0.02, - sigma_max: float = 100, - s_noise: float = 1.007, - s_churn: float = 80, - s_min: float = 0.05, - s_max: float = 50, - tensor_format: str = "pt", - ): - # setable values - self.num_inference_steps = None - self.timesteps = None - self.schedule = None # sigma(t_i) - - self.tensor_format = tensor_format - self.set_format(tensor_format=tensor_format) - - def set_timesteps(self, num_inference_steps: int): - """ - Sets the continuous timesteps used for the diffusion chain. Supporting function to be run before inference. - - Args: - num_inference_steps (`int`): - the number of diffusion steps used when generating samples with a pre-trained model. 
- - """ - self.num_inference_steps = num_inference_steps - self.timesteps = np.arange(0, self.num_inference_steps)[::-1].copy() - self.schedule = [ - (self.sigma_max * (self.sigma_min**2 / self.sigma_max**2) ** (i / (num_inference_steps - 1))) - for i in self.timesteps - ] - self.schedule = np.array(self.schedule, dtype=np.float32) - - self.set_format(tensor_format=self.tensor_format) - - def add_noise_to_input( - self, sample: Union[torch.FloatTensor, np.ndarray], sigma: float, generator: Optional[torch.Generator] = None - ) -> Tuple[Union[torch.FloatTensor, np.ndarray], float]: - """ - Explicit Langevin-like "churn" step of adding noise to the sample according to a factor gamma_i ≥ 0 to reach a - higher noise level sigma_hat = sigma_i + gamma_i*sigma_i. - - TODO Args: - """ - if self.s_min <= sigma <= self.s_max: - gamma = min(self.s_churn / self.num_inference_steps, 2**0.5 - 1) - else: - gamma = 0 - - # sample eps ~ N(0, S_noise^2 * I) - eps = self.s_noise * torch.randn(sample.shape, generator=generator).to(sample.device) - sigma_hat = sigma + gamma * sigma - sample_hat = sample + ((sigma_hat**2 - sigma**2) ** 0.5 * eps) - - return sample_hat, sigma_hat - - def step( - self, - model_output: Union[torch.FloatTensor, np.ndarray], - sigma_hat: float, - sigma_prev: float, - sample_hat: Union[torch.FloatTensor, np.ndarray], - return_dict: bool = True, - ) -> Union[KarrasVeOutput, Tuple]: - """ - Predict the sample at the previous timestep by reversing the SDE. Core function to propagate the diffusion - process from the learned model outputs (most often the predicted noise). - - Args: - model_output (`torch.FloatTensor` or `np.ndarray`): direct output from learned diffusion model. - sigma_hat (`float`): TODO - sigma_prev (`float`): TODO - sample_hat (`torch.FloatTensor` or `np.ndarray`): TODO - return_dict (`bool`): option for returning tuple rather than SchedulerOutput class - - KarrasVeOutput: updated sample in the diffusion chain and derivative (TODO double check). - Returns: - [`~schedulers.scheduling_karras_ve.KarrasVeOutput`] or `tuple`: - [`~schedulers.scheduling_karras_ve.KarrasVeOutput`] if `return_dict` is True, otherwise a `tuple`. When - returning a tuple, the first element is the sample tensor. - - """ - - pred_original_sample = sample_hat + sigma_hat * model_output - derivative = (sample_hat - pred_original_sample) / sigma_hat - sample_prev = sample_hat + (sigma_prev - sigma_hat) * derivative - - if not return_dict: - return (sample_prev, derivative) - - return KarrasVeOutput(prev_sample=sample_prev, derivative=derivative) - - def step_correct( - self, - model_output: Union[torch.FloatTensor, np.ndarray], - sigma_hat: float, - sigma_prev: float, - sample_hat: Union[torch.FloatTensor, np.ndarray], - sample_prev: Union[torch.FloatTensor, np.ndarray], - derivative: Union[torch.FloatTensor, np.ndarray], - return_dict: bool = True, - ) -> Union[KarrasVeOutput, Tuple]: - """ - Correct the predicted sample based on the output model_output of the network. TODO complete description - - Args: - model_output (`torch.FloatTensor` or `np.ndarray`): direct output from learned diffusion model. - sigma_hat (`float`): TODO - sigma_prev (`float`): TODO - sample_hat (`torch.FloatTensor` or `np.ndarray`): TODO - sample_prev (`torch.FloatTensor` or `np.ndarray`): TODO - derivative (`torch.FloatTensor` or `np.ndarray`): TODO - return_dict (`bool`): option for returning tuple rather than SchedulerOutput class - - Returns: - prev_sample (TODO): updated sample in the diffusion chain. 
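Illustrative aside: `set_timesteps` and `add_noise_to_input` above implement the geometric sigma schedule and the "churn" step of Karras et al. A small numeric sketch using the scheduler defaults listed above and an assumed tiny step count:

import numpy as np

sigma_min, sigma_max = 0.02, 100.0        # scheduler defaults above
s_churn, s_min, s_max = 80.0, 0.05, 50.0
num_inference_steps = 10                  # tiny, for illustration only

# Noise levels in the order the sampler visits them: geometric decay from sigma_max.
steps = np.arange(num_inference_steps)
sigmas = sigma_max * (sigma_min**2 / sigma_max**2) ** (steps / (num_inference_steps - 1))
print(sigmas.round(4))

# "Churn": temporarily raise the noise level before the Euler step.
sigma = sigmas[1]
gamma = min(s_churn / num_inference_steps, 2**0.5 - 1) if s_min <= sigma <= s_max else 0.0
sigma_hat = sigma + gamma * sigma
print(sigma, gamma, sigma_hat)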
derivative (TODO): TODO - - """ - pred_original_sample = sample_prev + sigma_prev * model_output - derivative_corr = (sample_prev - pred_original_sample) / sigma_prev - sample_prev = sample_hat + (sigma_prev - sigma_hat) * (0.5 * derivative + 0.5 * derivative_corr) - - if not return_dict: - return (sample_prev, derivative) - - return KarrasVeOutput(prev_sample=sample_prev, derivative=derivative) - - def add_noise(self, original_samples, noise, timesteps): - raise NotImplementedError() diff --git a/local_diffusers/schedulers/scheduling_lms_discrete.py b/local_diffusers/schedulers/scheduling_lms_discrete.py deleted file mode 100644 index 1381587fe..000000000 --- a/local_diffusers/schedulers/scheduling_lms_discrete.py +++ /dev/null @@ -1,193 +0,0 @@ -# Copyright 2022 Katherine Crowson and The HuggingFace Team. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -from typing import Optional, Tuple, Union - -import numpy as np -import torch - -from scipy import integrate - -from ..configuration_utils import ConfigMixin, register_to_config -from .scheduling_utils import SchedulerMixin, SchedulerOutput - - -class LMSDiscreteScheduler(SchedulerMixin, ConfigMixin): - """ - Linear Multistep Scheduler for discrete beta schedules. Based on the original k-diffusion implementation by - Katherine Crowson: - https://github.com/crowsonkb/k-diffusion/blob/481677d114f6ea445aa009cf5bd7a9cdee909e47/k_diffusion/sampling.py#L181 - - [`~ConfigMixin`] takes care of storing all config attributes that are passed in the scheduler's `__init__` - function, such as `num_train_timesteps`. They can be accessed via `scheduler.config.num_train_timesteps`. - [`~ConfigMixin`] also provides general loading and saving functionality via the [`~ConfigMixin.save_config`] and - [`~ConfigMixin.from_config`] functios. - - Args: - num_train_timesteps (`int`): number of diffusion steps used to train the model. - beta_start (`float`): the starting `beta` value of inference. - beta_end (`float`): the final `beta` value. - beta_schedule (`str`): - the beta schedule, a mapping from a beta range to a sequence of betas for stepping the model. Choose from - `linear` or `scaled_linear`. - trained_betas (`np.ndarray`, optional): TODO - options to clip the variance used when adding noise to the denoised sample. Choose from `fixed_small`, - `fixed_small_log`, `fixed_large`, `fixed_large_log`, `learned` or `learned_range`. - timestep_values (`np.ndarry`, optional): TODO - tensor_format (`str`): whether the scheduler expects pytorch or numpy arrays. 
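The pieces of the KarrasVeScheduler shown above (geometric sigma schedule, churn, Euler step, second-order correction) compose into the stochastic sampler of Algorithm 2 in Karras et al. (2022). The following is only an illustrative, self-contained sketch of that loop, not the pipeline's actual code: `denoiser(x, sigma)` is a hypothetical callable returning the predicted clean sample x_0, and the defaults mirror the scheduler's `__init__` arguments above.

import torch

def karras_ve_sample(denoiser, shape, num_steps=50, sigma_min=0.02, sigma_max=100.0,
                     s_churn=80.0, s_min=0.05, s_max=50.0, s_noise=1.007, generator=None):
    # Geometric sigma schedule from sigma_max down to sigma_min, as in set_timesteps().
    i = torch.arange(num_steps, dtype=torch.float32)
    sigmas = sigma_max * (sigma_min**2 / sigma_max**2) ** (i / (num_steps - 1))

    x = torch.randn(shape, generator=generator) * sigma_max
    for n in range(num_steps - 1):
        sigma, sigma_prev = sigmas[n].item(), sigmas[n + 1].item()

        # "Churn": temporarily raise the noise level (add_noise_to_input).
        gamma = min(s_churn / num_steps, 2**0.5 - 1) if s_min <= sigma <= s_max else 0.0
        sigma_hat = sigma + gamma * sigma
        eps = s_noise * torch.randn(shape, generator=generator)
        x_hat = x + (sigma_hat**2 - sigma**2) ** 0.5 * eps

        # Euler step toward sigma_prev (step).
        d = (x_hat - denoiser(x_hat, sigma_hat)) / sigma_hat
        x_prev = x_hat + (sigma_prev - sigma_hat) * d

        # Second-order correction (step_correct); guarded in case the final sigma is zero.
        if sigma_prev > 0:
            d_corr = (x_prev - denoiser(x_prev, sigma_prev)) / sigma_prev
            x_prev = x_hat + (sigma_prev - sigma_hat) * 0.5 * (d + d_corr)
        x = x_prev
    return x

As a smoke test, karras_ve_sample(lambda x, s: torch.zeros_like(x), (1, 3, 32, 32)) runs the loop end to end with a dummy denoiser.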
- - """ - - @register_to_config - def __init__( - self, - num_train_timesteps: int = 1000, - beta_start: float = 0.0001, - beta_end: float = 0.02, - beta_schedule: str = "linear", - trained_betas: Optional[np.ndarray] = None, - timestep_values: Optional[np.ndarray] = None, - tensor_format: str = "pt", - ): - if trained_betas is not None: - self.betas = np.asarray(trained_betas) - if beta_schedule == "linear": - self.betas = np.linspace(beta_start, beta_end, num_train_timesteps, dtype=np.float32) - elif beta_schedule == "scaled_linear": - # this schedule is very specific to the latent diffusion model. - self.betas = np.linspace(beta_start**0.5, beta_end**0.5, num_train_timesteps, dtype=np.float32) ** 2 - else: - raise NotImplementedError(f"{beta_schedule} does is not implemented for {self.__class__}") - - self.alphas = 1.0 - self.betas - self.alphas_cumprod = np.cumprod(self.alphas, axis=0) - - self.sigmas = ((1 - self.alphas_cumprod) / self.alphas_cumprod) ** 0.5 - - # setable values - self.num_inference_steps = None - self.timesteps = np.arange(0, num_train_timesteps)[::-1].copy() - self.derivatives = [] - - self.tensor_format = tensor_format - self.set_format(tensor_format=tensor_format) - - def get_lms_coefficient(self, order, t, current_order): - """ - Compute a linear multistep coefficient. - - Args: - order (TODO): - t (TODO): - current_order (TODO): - """ - - def lms_derivative(tau): - prod = 1.0 - for k in range(order): - if current_order == k: - continue - prod *= (tau - self.sigmas[t - k]) / (self.sigmas[t - current_order] - self.sigmas[t - k]) - return prod - - integrated_coeff = integrate.quad(lms_derivative, self.sigmas[t], self.sigmas[t + 1], epsrel=1e-4)[0] - - return integrated_coeff - - def set_timesteps(self, num_inference_steps: int): - """ - Sets the timesteps used for the diffusion chain. Supporting function to be run before inference. - - Args: - num_inference_steps (`int`): - the number of diffusion steps used when generating samples with a pre-trained model. - """ - self.num_inference_steps = num_inference_steps - self.timesteps = np.linspace(self.num_train_timesteps - 1, 0, num_inference_steps, dtype=float) - - low_idx = np.floor(self.timesteps).astype(int) - high_idx = np.ceil(self.timesteps).astype(int) - frac = np.mod(self.timesteps, 1.0) - sigmas = np.array(((1 - self.alphas_cumprod) / self.alphas_cumprod) ** 0.5) - sigmas = (1 - frac) * sigmas[low_idx] + frac * sigmas[high_idx] - self.sigmas = np.concatenate([sigmas, [0.0]]) - - self.derivatives = [] - - self.set_format(tensor_format=self.tensor_format) - - def step( - self, - model_output: Union[torch.FloatTensor, np.ndarray], - timestep: int, - sample: Union[torch.FloatTensor, np.ndarray], - order: int = 4, - return_dict: bool = True, - ) -> Union[SchedulerOutput, Tuple]: - """ - Predict the sample at the previous timestep by reversing the SDE. Core function to propagate the diffusion - process from the learned model outputs (most often the predicted noise). - - Args: - model_output (`torch.FloatTensor` or `np.ndarray`): direct output from learned diffusion model. - timestep (`int`): current discrete timestep in the diffusion chain. - sample (`torch.FloatTensor` or `np.ndarray`): - current instance of sample being created by diffusion process. - order: coefficient for multi-step inference. 
- return_dict (`bool`): option for returning tuple rather than SchedulerOutput class - - Returns: - [`~schedulers.scheduling_utils.SchedulerOutput`] or `tuple`: - [`~schedulers.scheduling_utils.SchedulerOutput`] if `return_dict` is True, otherwise a `tuple`. When - returning a tuple, the first element is the sample tensor. - - """ - sigma = self.sigmas[timestep] - - # 1. compute predicted original sample (x_0) from sigma-scaled predicted noise - pred_original_sample = sample - sigma * model_output - - # 2. Convert to an ODE derivative - derivative = (sample - pred_original_sample) / sigma - self.derivatives.append(derivative) - if len(self.derivatives) > order: - self.derivatives.pop(0) - - # 3. Compute linear multistep coefficients - order = min(timestep + 1, order) - lms_coeffs = [self.get_lms_coefficient(order, timestep, curr_order) for curr_order in range(order)] - - # 4. Compute previous sample based on the derivatives path - prev_sample = sample + sum( - coeff * derivative for coeff, derivative in zip(lms_coeffs, reversed(self.derivatives)) - ) - - if not return_dict: - return (prev_sample,) - - return SchedulerOutput(prev_sample=prev_sample) - - def add_noise( - self, - original_samples: Union[torch.FloatTensor, np.ndarray], - noise: Union[torch.FloatTensor, np.ndarray], - timesteps: Union[torch.IntTensor, np.ndarray], - ) -> Union[torch.FloatTensor, np.ndarray]: - sigmas = self.match_shape(self.sigmas[timesteps], noise) - noisy_samples = original_samples + noise * sigmas - - return noisy_samples - - def __len__(self): - return self.config.num_train_timesteps diff --git a/local_diffusers/schedulers/scheduling_pndm.py b/local_diffusers/schedulers/scheduling_pndm.py deleted file mode 100644 index b43d88bba..000000000 --- a/local_diffusers/schedulers/scheduling_pndm.py +++ /dev/null @@ -1,378 +0,0 @@ -# Copyright 2022 Zhejiang University Team and The HuggingFace Team. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -# DISCLAIMER: This file is strongly influenced by https://github.com/ermongroup/ddim - -import math -from typing import Optional, Tuple, Union - -import numpy as np -import torch - -from ..configuration_utils import ConfigMixin, register_to_config -from .scheduling_utils import SchedulerMixin, SchedulerOutput - - -def betas_for_alpha_bar(num_diffusion_timesteps, max_beta=0.999): - """ - Create a beta schedule that discretizes the given alpha_t_bar function, which defines the cumulative product of - (1-beta) over time from t = [0,1]. - - Contains a function alpha_bar that takes an argument t and transforms it to the cumulative product of (1-beta) up - to that part of the diffusion process. - - - Args: - num_diffusion_timesteps (`int`): the number of betas to produce. - max_beta (`float`): the maximum beta to use; use values lower than 1 to - prevent singularities. 
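The least obvious part of the LMSDiscreteScheduler above is `get_lms_coefficient`, which integrates a Lagrange basis polynomial of the recent sigmas over the next sigma interval; the resulting weights are applied to the stored derivatives, newest first. A standalone sketch of that computation, using scipy and a made-up sigma sequence purely for illustration:

import numpy as np
from scipy import integrate

def lms_coefficients(sigmas, t, order):
    # One weight per stored derivative: integrate the Lagrange basis polynomial
    # through sigmas[t], sigmas[t-1], ... over the interval [sigmas[t], sigmas[t+1]].
    coeffs = []
    for current_order in range(order):
        def lms_derivative(tau, current_order=current_order):
            prod = 1.0
            for k in range(order):
                if k == current_order:
                    continue
                prod *= (tau - sigmas[t - k]) / (sigmas[t - current_order] - sigmas[t - k])
            return prod
        coeffs.append(integrate.quad(lms_derivative, sigmas[t], sigmas[t + 1], epsrel=1e-4)[0])
    return coeffs

sigmas = np.linspace(10.0, 0.1, 11)            # toy decreasing sigma sequence
print(lms_coefficients(sigmas, t=5, order=4))  # four weights, newest derivative first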
- - Returns: - betas (`np.ndarray`): the betas used by the scheduler to step the model outputs - """ - - def alpha_bar(time_step): - return math.cos((time_step + 0.008) / 1.008 * math.pi / 2) ** 2 - - betas = [] - for i in range(num_diffusion_timesteps): - t1 = i / num_diffusion_timesteps - t2 = (i + 1) / num_diffusion_timesteps - betas.append(min(1 - alpha_bar(t2) / alpha_bar(t1), max_beta)) - return np.array(betas, dtype=np.float32) - - -class PNDMScheduler(SchedulerMixin, ConfigMixin): - """ - Pseudo numerical methods for diffusion models (PNDM) proposes using more advanced ODE integration techniques, - namely Runge-Kutta method and a linear multi-step method. - - [`~ConfigMixin`] takes care of storing all config attributes that are passed in the scheduler's `__init__` - function, such as `num_train_timesteps`. They can be accessed via `scheduler.config.num_train_timesteps`. - [`~ConfigMixin`] also provides general loading and saving functionality via the [`~ConfigMixin.save_config`] and - [`~ConfigMixin.from_config`] functios. - - For more details, see the original paper: https://arxiv.org/abs/2202.09778 - - Args: - num_train_timesteps (`int`): number of diffusion steps used to train the model. - beta_start (`float`): the starting `beta` value of inference. - beta_end (`float`): the final `beta` value. - beta_schedule (`str`): - the beta schedule, a mapping from a beta range to a sequence of betas for stepping the model. Choose from - `linear`, `scaled_linear`, or `squaredcos_cap_v2`. - trained_betas (`np.ndarray`, optional): TODO - tensor_format (`str`): whether the scheduler expects pytorch or numpy arrays - skip_prk_steps (`bool`): - allows the scheduler to skip the Runge-Kutta steps that are defined in the original paper as being required - before plms steps; defaults to `False`. - - """ - - @register_to_config - def __init__( - self, - num_train_timesteps: int = 1000, - beta_start: float = 0.0001, - beta_end: float = 0.02, - beta_schedule: str = "linear", - trained_betas: Optional[np.ndarray] = None, - tensor_format: str = "pt", - skip_prk_steps: bool = False, - ): - if trained_betas is not None: - self.betas = np.asarray(trained_betas) - if beta_schedule == "linear": - self.betas = np.linspace(beta_start, beta_end, num_train_timesteps, dtype=np.float32) - elif beta_schedule == "scaled_linear": - # this schedule is very specific to the latent diffusion model. - self.betas = np.linspace(beta_start**0.5, beta_end**0.5, num_train_timesteps, dtype=np.float32) ** 2 - elif beta_schedule == "squaredcos_cap_v2": - # Glide cosine schedule - self.betas = betas_for_alpha_bar(num_train_timesteps) - else: - raise NotImplementedError(f"{beta_schedule} does is not implemented for {self.__class__}") - - self.alphas = 1.0 - self.betas - self.alphas_cumprod = np.cumprod(self.alphas, axis=0) - - self.one = np.array(1.0) - - # For now we only support F-PNDM, i.e. the runge-kutta method - # For more information on the algorithm please take a look at the paper: https://arxiv.org/pdf/2202.09778.pdf - # mainly at formula (9), (12), (13) and the Algorithm 2. 
- self.pndm_order = 4 - - # running values - self.cur_model_output = 0 - self.counter = 0 - self.cur_sample = None - self.ets = [] - - # setable values - self.num_inference_steps = None - self._timesteps = np.arange(0, num_train_timesteps)[::-1].copy() - self._offset = 0 - self.prk_timesteps = None - self.plms_timesteps = None - self.timesteps = None - - self.tensor_format = tensor_format - self.set_format(tensor_format=tensor_format) - - def set_timesteps(self, num_inference_steps: int, offset: int = 0) -> torch.FloatTensor: - """ - Sets the discrete timesteps used for the diffusion chain. Supporting function to be run before inference. - - Args: - num_inference_steps (`int`): - the number of diffusion steps used when generating samples with a pre-trained model. - offset (`int`): TODO - """ - self.num_inference_steps = num_inference_steps - self._timesteps = list( - range(0, self.config.num_train_timesteps, self.config.num_train_timesteps // num_inference_steps) - ) - self._offset = offset - self._timesteps = np.array([t + self._offset for t in self._timesteps]) - - if self.config.skip_prk_steps: - # for some models like stable diffusion the prk steps can/should be skipped to - # produce better results. When using PNDM with `self.config.skip_prk_steps` the implementation - # is based on crowsonkb's PLMS sampler implementation: https://github.com/CompVis/latent-diffusion/pull/51 - self.prk_timesteps = np.array([]) - self.plms_timesteps = np.concatenate([self._timesteps[:-1], self._timesteps[-2:-1], self._timesteps[-1:]])[ - ::-1 - ].copy() - else: - prk_timesteps = np.array(self._timesteps[-self.pndm_order :]).repeat(2) + np.tile( - np.array([0, self.config.num_train_timesteps // num_inference_steps // 2]), self.pndm_order - ) - self.prk_timesteps = (prk_timesteps[:-1].repeat(2)[1:-1])[::-1].copy() - self.plms_timesteps = self._timesteps[:-3][ - ::-1 - ].copy() # we copy to avoid having negative strides which are not supported by torch.from_numpy - - self.timesteps = np.concatenate([self.prk_timesteps, self.plms_timesteps]).astype(np.int64) - - self.ets = [] - self.counter = 0 - self.set_format(tensor_format=self.tensor_format) - - def step( - self, - model_output: Union[torch.FloatTensor, np.ndarray], - timestep: int, - sample: Union[torch.FloatTensor, np.ndarray], - return_dict: bool = True, - ) -> Union[SchedulerOutput, Tuple]: - """ - Predict the sample at the previous timestep by reversing the SDE. Core function to propagate the diffusion - process from the learned model outputs (most often the predicted noise). - - This function calls `step_prk()` or `step_plms()` depending on the internal variable `counter`. - - Args: - model_output (`torch.FloatTensor` or `np.ndarray`): direct output from learned diffusion model. - timestep (`int`): current discrete timestep in the diffusion chain. - sample (`torch.FloatTensor` or `np.ndarray`): - current instance of sample being created by diffusion process. - return_dict (`bool`): option for returning tuple rather than SchedulerOutput class - - Returns: - [`~schedulers.scheduling_utils.SchedulerOutput`] or `tuple`: - [`~schedulers.scheduling_utils.SchedulerOutput`] if `return_dict` is True, otherwise a `tuple`. When - returning a tuple, the first element is the sample tensor. 
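The timestep bookkeeping in `set_timesteps` above is easier to see outside the class. The sketch below reproduces the same array arithmetic purely for illustration (a hypothetical helper, not part of the library): with `skip_prk_steps=True` the schedule is pure PLMS (the setting used for stable diffusion), otherwise a block of doubled Runge-Kutta timesteps precedes the PLMS ones.

import numpy as np

def pndm_timesteps(num_train_timesteps=1000, num_inference_steps=10,
                   skip_prk_steps=False, pndm_order=4, offset=0):
    step = num_train_timesteps // num_inference_steps
    base = np.arange(0, num_train_timesteps, step) + offset

    if skip_prk_steps:
        prk = np.array([], dtype=np.int64)
        # repeat the second-to-last timestep, then reverse (PLMS-only schedule)
        plms = np.concatenate([base[:-1], base[-2:-1], base[-1:]])[::-1]
    else:
        # double the last `pndm_order` timesteps and interleave half-steps for the RK warm-up
        prk = np.array(base[-pndm_order:]).repeat(2) + np.tile(np.array([0, step // 2]), pndm_order)
        prk = (prk[:-1].repeat(2)[1:-1])[::-1]
        plms = base[:-3][::-1]
    return np.concatenate([prk, plms]).astype(np.int64)

print(pndm_timesteps(skip_prk_steps=True))   # pure PLMS
print(pndm_timesteps(skip_prk_steps=False))  # RK warm-up timesteps followed by PLMS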
- - """ - if self.counter < len(self.prk_timesteps) and not self.config.skip_prk_steps: - return self.step_prk(model_output=model_output, timestep=timestep, sample=sample, return_dict=return_dict) - else: - return self.step_plms(model_output=model_output, timestep=timestep, sample=sample, return_dict=return_dict) - - def step_prk( - self, - model_output: Union[torch.FloatTensor, np.ndarray], - timestep: int, - sample: Union[torch.FloatTensor, np.ndarray], - return_dict: bool = True, - ) -> Union[SchedulerOutput, Tuple]: - """ - Step function propagating the sample with the Runge-Kutta method. RK takes 4 forward passes to approximate the - solution to the differential equation. - - Args: - model_output (`torch.FloatTensor` or `np.ndarray`): direct output from learned diffusion model. - timestep (`int`): current discrete timestep in the diffusion chain. - sample (`torch.FloatTensor` or `np.ndarray`): - current instance of sample being created by diffusion process. - return_dict (`bool`): option for returning tuple rather than SchedulerOutput class - - Returns: - [`~scheduling_utils.SchedulerOutput`] or `tuple`: [`~scheduling_utils.SchedulerOutput`] if `return_dict` is - True, otherwise a `tuple`. When returning a tuple, the first element is the sample tensor. - - """ - if self.num_inference_steps is None: - raise ValueError( - "Number of inference steps is 'None', you need to run 'set_timesteps' after creating the scheduler" - ) - - diff_to_prev = 0 if self.counter % 2 else self.config.num_train_timesteps // self.num_inference_steps // 2 - prev_timestep = max(timestep - diff_to_prev, self.prk_timesteps[-1]) - timestep = self.prk_timesteps[self.counter // 4 * 4] - - if self.counter % 4 == 0: - self.cur_model_output += 1 / 6 * model_output - self.ets.append(model_output) - self.cur_sample = sample - elif (self.counter - 1) % 4 == 0: - self.cur_model_output += 1 / 3 * model_output - elif (self.counter - 2) % 4 == 0: - self.cur_model_output += 1 / 3 * model_output - elif (self.counter - 3) % 4 == 0: - model_output = self.cur_model_output + 1 / 6 * model_output - self.cur_model_output = 0 - - # cur_sample should not be `None` - cur_sample = self.cur_sample if self.cur_sample is not None else sample - - prev_sample = self._get_prev_sample(cur_sample, timestep, prev_timestep, model_output) - self.counter += 1 - - if not return_dict: - return (prev_sample,) - - return SchedulerOutput(prev_sample=prev_sample) - - def step_plms( - self, - model_output: Union[torch.FloatTensor, np.ndarray], - timestep: int, - sample: Union[torch.FloatTensor, np.ndarray], - return_dict: bool = True, - ) -> Union[SchedulerOutput, Tuple]: - """ - Step function propagating the sample with the linear multi-step method. This has one forward pass with multiple - times to approximate the solution. - - Args: - model_output (`torch.FloatTensor` or `np.ndarray`): direct output from learned diffusion model. - timestep (`int`): current discrete timestep in the diffusion chain. - sample (`torch.FloatTensor` or `np.ndarray`): - current instance of sample being created by diffusion process. - return_dict (`bool`): option for returning tuple rather than SchedulerOutput class - - Returns: - [`~scheduling_utils.SchedulerOutput`] or `tuple`: [`~scheduling_utils.SchedulerOutput`] if `return_dict` is - True, otherwise a `tuple`. When returning a tuple, the first element is the sample tensor. 
- - """ - if self.num_inference_steps is None: - raise ValueError( - "Number of inference steps is 'None', you need to run 'set_timesteps' after creating the scheduler" - ) - - if not self.config.skip_prk_steps and len(self.ets) < 3: - raise ValueError( - f"{self.__class__} can only be run AFTER scheduler has been run " - "in 'prk' mode for at least 12 iterations " - "See: https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/pipeline_pndm.py " - "for more information." - ) - - prev_timestep = max(timestep - self.config.num_train_timesteps // self.num_inference_steps, 0) - - if self.counter != 1: - self.ets.append(model_output) - else: - prev_timestep = timestep - timestep = timestep + self.config.num_train_timesteps // self.num_inference_steps - - if len(self.ets) == 1 and self.counter == 0: - model_output = model_output - self.cur_sample = sample - elif len(self.ets) == 1 and self.counter == 1: - model_output = (model_output + self.ets[-1]) / 2 - sample = self.cur_sample - self.cur_sample = None - elif len(self.ets) == 2: - model_output = (3 * self.ets[-1] - self.ets[-2]) / 2 - elif len(self.ets) == 3: - model_output = (23 * self.ets[-1] - 16 * self.ets[-2] + 5 * self.ets[-3]) / 12 - else: - model_output = (1 / 24) * (55 * self.ets[-1] - 59 * self.ets[-2] + 37 * self.ets[-3] - 9 * self.ets[-4]) - - prev_sample = self._get_prev_sample(sample, timestep, prev_timestep, model_output) - self.counter += 1 - - if not return_dict: - return (prev_sample,) - - return SchedulerOutput(prev_sample=prev_sample) - - def _get_prev_sample(self, sample, timestep, timestep_prev, model_output): - # See formula (9) of PNDM paper https://arxiv.org/pdf/2202.09778.pdf - # this function computes x_(t−δ) using the formula of (9) - # Note that x_t needs to be added to both sides of the equation - - # Notation ( -> - # alpha_prod_t -> α_t - # alpha_prod_t_prev -> α_(t−δ) - # beta_prod_t -> (1 - α_t) - # beta_prod_t_prev -> (1 - α_(t−δ)) - # sample -> x_t - # model_output -> e_θ(x_t, t) - # prev_sample -> x_(t−δ) - alpha_prod_t = self.alphas_cumprod[timestep + 1 - self._offset] - alpha_prod_t_prev = self.alphas_cumprod[timestep_prev + 1 - self._offset] - beta_prod_t = 1 - alpha_prod_t - beta_prod_t_prev = 1 - alpha_prod_t_prev - - # corresponds to (α_(t−δ) - α_t) divided by - # denominator of x_t in formula (9) and plus 1 - # Note: (α_(t−δ) - α_t) / (sqrt(α_t) * (sqrt(α_(t−δ)) + sqr(α_t))) = - # sqrt(α_(t−δ)) / sqrt(α_t)) - sample_coeff = (alpha_prod_t_prev / alpha_prod_t) ** (0.5) - - # corresponds to denominator of e_θ(x_t, t) in formula (9) - model_output_denom_coeff = alpha_prod_t * beta_prod_t_prev ** (0.5) + ( - alpha_prod_t * beta_prod_t * alpha_prod_t_prev - ) ** (0.5) - - # full formula (9) - prev_sample = ( - sample_coeff * sample - (alpha_prod_t_prev - alpha_prod_t) * model_output / model_output_denom_coeff - ) - - return prev_sample - - def add_noise( - self, - original_samples: Union[torch.FloatTensor, np.ndarray], - noise: Union[torch.FloatTensor, np.ndarray], - timesteps: Union[torch.IntTensor, np.ndarray], - ) -> torch.Tensor: - # mps requires indices to be in the same device, so we use cpu as is the default with cuda - timesteps = timesteps.to(self.alphas_cumprod.device) - sqrt_alpha_prod = self.alphas_cumprod[timesteps] ** 0.5 - sqrt_alpha_prod = self.match_shape(sqrt_alpha_prod, original_samples) - sqrt_one_minus_alpha_prod = (1 - self.alphas_cumprod[timesteps]) ** 0.5 - sqrt_one_minus_alpha_prod = self.match_shape(sqrt_one_minus_alpha_prod, original_samples) - - noisy_samples 
= sqrt_alpha_prod * original_samples + sqrt_one_minus_alpha_prod * noise - return noisy_samples - - def __len__(self): - return self.config.num_train_timesteps diff --git a/local_diffusers/schedulers/scheduling_sde_ve.py b/local_diffusers/schedulers/scheduling_sde_ve.py deleted file mode 100644 index e187f0796..000000000 --- a/local_diffusers/schedulers/scheduling_sde_ve.py +++ /dev/null @@ -1,283 +0,0 @@ -# Copyright 2022 Google Brain and The HuggingFace Team. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -# DISCLAIMER: This file is strongly influenced by https://github.com/yang-song/score_sde_pytorch - -import warnings -from dataclasses import dataclass -from typing import Optional, Tuple, Union - -import numpy as np -import torch - -from ..configuration_utils import ConfigMixin, register_to_config -from ..utils import BaseOutput -from .scheduling_utils import SchedulerMixin, SchedulerOutput - - -@dataclass -class SdeVeOutput(BaseOutput): - """ - Output class for the ScoreSdeVeScheduler's step function output. - - Args: - prev_sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` for images): - Computed sample (x_{t-1}) of previous timestep. `prev_sample` should be used as next model input in the - denoising loop. - prev_sample_mean (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` for images): - Mean averaged `prev_sample`. Same as `prev_sample`, only mean-averaged over previous timesteps. - """ - - prev_sample: torch.FloatTensor - prev_sample_mean: torch.FloatTensor - - -class ScoreSdeVeScheduler(SchedulerMixin, ConfigMixin): - """ - The variance exploding stochastic differential equation (SDE) scheduler. - - For more information, see the original paper: https://arxiv.org/abs/2011.13456 - - [`~ConfigMixin`] takes care of storing all config attributes that are passed in the scheduler's `__init__` - function, such as `num_train_timesteps`. They can be accessed via `scheduler.config.num_train_timesteps`. - [`~ConfigMixin`] also provides general loading and saving functionality via the [`~ConfigMixin.save_config`] and - [`~ConfigMixin.from_config`] functios. - - Args: - snr (`float`): - coefficient weighting the step from the model_output sample (from the network) to the random noise. - sigma_min (`float`): - initial noise scale for sigma sequence in sampling procedure. The minimum sigma should mirror the - distribution of the data. - sigma_max (`float`): maximum value used for the range of continuous timesteps passed into the model. - sampling_eps (`float`): the end value of sampling, where timesteps decrease progessively from 1 to - epsilon. - correct_steps (`int`): number of correction steps performed on a produced sample. - tensor_format (`str`): "np" or "pt" for the expected format of samples passed to the Scheduler. 
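At its core, each PLMS update above is formula (9) of the PNDM paper applied to an Adams-Bashforth style blend of the last four noise predictions. A compact illustration with scalars standing in for image tensors (`alphas_cumprod` is a hypothetical array here, and the scheduler's offset and indexing details are omitted):

import numpy as np

def pndm_prev_sample(sample, t, t_prev, eps, alphas_cumprod):
    # Formula (9) of https://arxiv.org/abs/2202.09778: move x_t to x_{t_prev}
    # given a (possibly multistep-blended) noise prediction eps.
    alpha_t, alpha_prev = alphas_cumprod[t], alphas_cumprod[t_prev]
    beta_t, beta_prev = 1 - alpha_t, 1 - alpha_prev
    sample_coeff = (alpha_prev / alpha_t) ** 0.5
    denom = alpha_t * beta_prev ** 0.5 + (alpha_t * beta_t * alpha_prev) ** 0.5
    return sample_coeff * sample - (alpha_prev - alpha_t) * eps / denom

def plms_blend(ets):
    # 4th-order linear multistep blend of stored noise predictions (newest last).
    return (55 * ets[-1] - 59 * ets[-2] + 37 * ets[-3] - 9 * ets[-4]) / 24

betas = np.linspace(1e-4, 0.02, 1000)
alphas_cumprod = np.cumprod(1.0 - betas)
x_prev = pndm_prev_sample(1.0, t=500, t_prev=480, eps=plms_blend([0.10, 0.12, 0.11, 0.09]),
                          alphas_cumprod=alphas_cumprod)
print(x_prev)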
- """ - - @register_to_config - def __init__( - self, - num_train_timesteps: int = 2000, - snr: float = 0.15, - sigma_min: float = 0.01, - sigma_max: float = 1348.0, - sampling_eps: float = 1e-5, - correct_steps: int = 1, - tensor_format: str = "pt", - ): - # setable values - self.timesteps = None - - self.set_sigmas(num_train_timesteps, sigma_min, sigma_max, sampling_eps) - - self.tensor_format = tensor_format - self.set_format(tensor_format=tensor_format) - - def set_timesteps(self, num_inference_steps: int, sampling_eps: float = None): - """ - Sets the continuous timesteps used for the diffusion chain. Supporting function to be run before inference. - - Args: - num_inference_steps (`int`): - the number of diffusion steps used when generating samples with a pre-trained model. - sampling_eps (`float`, optional): final timestep value (overrides value given at Scheduler instantiation). - - """ - sampling_eps = sampling_eps if sampling_eps is not None else self.config.sampling_eps - tensor_format = getattr(self, "tensor_format", "pt") - if tensor_format == "np": - self.timesteps = np.linspace(1, sampling_eps, num_inference_steps) - elif tensor_format == "pt": - self.timesteps = torch.linspace(1, sampling_eps, num_inference_steps) - else: - raise ValueError(f"`self.tensor_format`: {self.tensor_format} is not valid.") - - def set_sigmas( - self, num_inference_steps: int, sigma_min: float = None, sigma_max: float = None, sampling_eps: float = None - ): - """ - Sets the noise scales used for the diffusion chain. Supporting function to be run before inference. - - The sigmas control the weight of the `drift` and `diffusion` components of sample update. - - Args: - num_inference_steps (`int`): - the number of diffusion steps used when generating samples with a pre-trained model. - sigma_min (`float`, optional): - initial noise scale value (overrides value given at Scheduler instantiation). - sigma_max (`float`, optional): final noise scale value (overrides value given at Scheduler instantiation). - sampling_eps (`float`, optional): final timestep value (overrides value given at Scheduler instantiation). 
- - """ - sigma_min = sigma_min if sigma_min is not None else self.config.sigma_min - sigma_max = sigma_max if sigma_max is not None else self.config.sigma_max - sampling_eps = sampling_eps if sampling_eps is not None else self.config.sampling_eps - if self.timesteps is None: - self.set_timesteps(num_inference_steps, sampling_eps) - - tensor_format = getattr(self, "tensor_format", "pt") - if tensor_format == "np": - self.discrete_sigmas = np.exp(np.linspace(np.log(sigma_min), np.log(sigma_max), num_inference_steps)) - self.sigmas = np.array([sigma_min * (sigma_max / sigma_min) ** t for t in self.timesteps]) - elif tensor_format == "pt": - self.discrete_sigmas = torch.exp(torch.linspace(np.log(sigma_min), np.log(sigma_max), num_inference_steps)) - self.sigmas = torch.tensor([sigma_min * (sigma_max / sigma_min) ** t for t in self.timesteps]) - else: - raise ValueError(f"`self.tensor_format`: {self.tensor_format} is not valid.") - - def get_adjacent_sigma(self, timesteps, t): - tensor_format = getattr(self, "tensor_format", "pt") - if tensor_format == "np": - return np.where(timesteps == 0, np.zeros_like(t), self.discrete_sigmas[timesteps - 1]) - elif tensor_format == "pt": - return torch.where( - timesteps == 0, - torch.zeros_like(t.to(timesteps.device)), - self.discrete_sigmas[timesteps - 1].to(timesteps.device), - ) - - raise ValueError(f"`self.tensor_format`: {self.tensor_format} is not valid.") - - def set_seed(self, seed): - warnings.warn( - "The method `set_seed` is deprecated and will be removed in version `0.4.0`. Please consider passing a" - " generator instead.", - DeprecationWarning, - ) - tensor_format = getattr(self, "tensor_format", "pt") - if tensor_format == "np": - np.random.seed(seed) - elif tensor_format == "pt": - torch.manual_seed(seed) - else: - raise ValueError(f"`self.tensor_format`: {self.tensor_format} is not valid.") - - def step_pred( - self, - model_output: Union[torch.FloatTensor, np.ndarray], - timestep: int, - sample: Union[torch.FloatTensor, np.ndarray], - generator: Optional[torch.Generator] = None, - return_dict: bool = True, - **kwargs, - ) -> Union[SdeVeOutput, Tuple]: - """ - Predict the sample at the previous timestep by reversing the SDE. Core function to propagate the diffusion - process from the learned model outputs (most often the predicted noise). - - Args: - model_output (`torch.FloatTensor` or `np.ndarray`): direct output from learned diffusion model. - timestep (`int`): current discrete timestep in the diffusion chain. - sample (`torch.FloatTensor` or `np.ndarray`): - current instance of sample being created by diffusion process. - generator: random number generator. - return_dict (`bool`): option for returning tuple rather than SchedulerOutput class - - Returns: - [`~schedulers.scheduling_sde_ve.SdeVeOutput`] or `tuple`: [`~schedulers.scheduling_sde_ve.SdeVeOutput`] if - `return_dict` is True, otherwise a `tuple`. When returning a tuple, the first element is the sample tensor. 
- - """ - if "seed" in kwargs and kwargs["seed"] is not None: - self.set_seed(kwargs["seed"]) - - if self.timesteps is None: - raise ValueError( - "`self.timesteps` is not set, you need to run 'set_timesteps' after creating the scheduler" - ) - - timestep = timestep * torch.ones( - sample.shape[0], device=sample.device - ) # torch.repeat_interleave(timestep, sample.shape[0]) - timesteps = (timestep * (len(self.timesteps) - 1)).long() - - # mps requires indices to be in the same device, so we use cpu as is the default with cuda - timesteps = timesteps.to(self.discrete_sigmas.device) - - sigma = self.discrete_sigmas[timesteps].to(sample.device) - adjacent_sigma = self.get_adjacent_sigma(timesteps, timestep).to(sample.device) - drift = self.zeros_like(sample) - diffusion = (sigma**2 - adjacent_sigma**2) ** 0.5 - - # equation 6 in the paper: the model_output modeled by the network is grad_x log pt(x) - # also equation 47 shows the analog from SDE models to ancestral sampling methods - drift = drift - diffusion[:, None, None, None] ** 2 * model_output - - # equation 6: sample noise for the diffusion term of - noise = self.randn_like(sample, generator=generator) - prev_sample_mean = sample - drift # subtract because `dt` is a small negative timestep - # TODO is the variable diffusion the correct scaling term for the noise? - prev_sample = prev_sample_mean + diffusion[:, None, None, None] * noise # add impact of diffusion field g - - if not return_dict: - return (prev_sample, prev_sample_mean) - - return SdeVeOutput(prev_sample=prev_sample, prev_sample_mean=prev_sample_mean) - - def step_correct( - self, - model_output: Union[torch.FloatTensor, np.ndarray], - sample: Union[torch.FloatTensor, np.ndarray], - generator: Optional[torch.Generator] = None, - return_dict: bool = True, - **kwargs, - ) -> Union[SchedulerOutput, Tuple]: - """ - Correct the predicted sample based on the output model_output of the network. This is often run repeatedly - after making the prediction for the previous timestep. - - Args: - model_output (`torch.FloatTensor` or `np.ndarray`): direct output from learned diffusion model. - sample (`torch.FloatTensor` or `np.ndarray`): - current instance of sample being created by diffusion process. - generator: random number generator. - return_dict (`bool`): option for returning tuple rather than SchedulerOutput class - - Returns: - [`~schedulers.scheduling_sde_ve.SdeVeOutput`] or `tuple`: [`~schedulers.scheduling_sde_ve.SdeVeOutput`] if - `return_dict` is True, otherwise a `tuple`. When returning a tuple, the first element is the sample tensor. - - """ - if "seed" in kwargs and kwargs["seed"] is not None: - self.set_seed(kwargs["seed"]) - - if self.timesteps is None: - raise ValueError( - "`self.timesteps` is not set, you need to run 'set_timesteps' after creating the scheduler" - ) - - # For small batch sizes, the paper "suggest replacing norm(z) with sqrt(d), where d is the dim. 
of z" - # sample noise for correction - noise = self.randn_like(sample, generator=generator) - - # compute step size from the model_output, the noise, and the snr - grad_norm = self.norm(model_output) - noise_norm = self.norm(noise) - step_size = (self.config.snr * noise_norm / grad_norm) ** 2 * 2 - step_size = step_size * torch.ones(sample.shape[0]).to(sample.device) - # self.repeat_scalar(step_size, sample.shape[0]) - - # compute corrected sample: model_output term and noise term - prev_sample_mean = sample + step_size[:, None, None, None] * model_output - prev_sample = prev_sample_mean + ((step_size * 2) ** 0.5)[:, None, None, None] * noise - - if not return_dict: - return (prev_sample,) - - return SchedulerOutput(prev_sample=prev_sample) - - def __len__(self): - return self.config.num_train_timesteps diff --git a/local_diffusers/schedulers/scheduling_sde_vp.py b/local_diffusers/schedulers/scheduling_sde_vp.py deleted file mode 100644 index 66e6ec661..000000000 --- a/local_diffusers/schedulers/scheduling_sde_vp.py +++ /dev/null @@ -1,81 +0,0 @@ -# Copyright 2022 Google Brain and The HuggingFace Team. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -# DISCLAIMER: This file is strongly influenced by https://github.com/yang-song/score_sde_pytorch - -# TODO(Patrick, Anton, Suraj) - make scheduler framework indepedent and clean-up a bit - -import numpy as np -import torch - -from ..configuration_utils import ConfigMixin, register_to_config -from .scheduling_utils import SchedulerMixin - - -class ScoreSdeVpScheduler(SchedulerMixin, ConfigMixin): - """ - The variance preserving stochastic differential equation (SDE) scheduler. - - [`~ConfigMixin`] takes care of storing all config attributes that are passed in the scheduler's `__init__` - function, such as `num_train_timesteps`. They can be accessed via `scheduler.config.num_train_timesteps`. - [`~ConfigMixin`] also provides general loading and saving functionality via the [`~ConfigMixin.save_config`] and - [`~ConfigMixin.from_config`] functios. 
- - For more information, see the original paper: https://arxiv.org/abs/2011.13456 - - UNDER CONSTRUCTION - - """ - - @register_to_config - def __init__(self, num_train_timesteps=2000, beta_min=0.1, beta_max=20, sampling_eps=1e-3, tensor_format="np"): - - self.sigmas = None - self.discrete_sigmas = None - self.timesteps = None - - def set_timesteps(self, num_inference_steps): - self.timesteps = torch.linspace(1, self.config.sampling_eps, num_inference_steps) - - def step_pred(self, score, x, t): - if self.timesteps is None: - raise ValueError( - "`self.timesteps` is not set, you need to run 'set_timesteps' after creating the scheduler" - ) - - # TODO(Patrick) better comments + non-PyTorch - # postprocess model score - log_mean_coeff = ( - -0.25 * t**2 * (self.config.beta_max - self.config.beta_min) - 0.5 * t * self.config.beta_min - ) - std = torch.sqrt(1.0 - torch.exp(2.0 * log_mean_coeff)) - score = -score / std[:, None, None, None] - - # compute - dt = -1.0 / len(self.timesteps) - - beta_t = self.config.beta_min + t * (self.config.beta_max - self.config.beta_min) - drift = -0.5 * beta_t[:, None, None, None] * x - diffusion = torch.sqrt(beta_t) - drift = drift - diffusion[:, None, None, None] ** 2 * score - x_mean = x + drift * dt - - # add noise - noise = torch.randn_like(x) - x = x_mean + diffusion[:, None, None, None] * np.sqrt(-dt) * noise - - return x, x_mean - - def __len__(self): - return self.config.num_train_timesteps diff --git a/local_diffusers/schedulers/scheduling_utils.py b/local_diffusers/schedulers/scheduling_utils.py deleted file mode 100644 index f2bcd73ac..000000000 --- a/local_diffusers/schedulers/scheduling_utils.py +++ /dev/null @@ -1,125 +0,0 @@ -# Copyright 2022 The HuggingFace Team. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -from dataclasses import dataclass -from typing import Union - -import numpy as np -import torch - -from ..utils import BaseOutput - - -SCHEDULER_CONFIG_NAME = "scheduler_config.json" - - -@dataclass -class SchedulerOutput(BaseOutput): - """ - Base class for the scheduler's step function output. - - Args: - prev_sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` for images): - Computed sample (x_{t-1}) of previous timestep. `prev_sample` should be used as next model input in the - denoising loop. - """ - - prev_sample: torch.FloatTensor - - -class SchedulerMixin: - """ - Mixin containing common functions for the schedulers. 
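For comparison with the VE scheduler, the variance-preserving `ScoreSdeVpScheduler.step_pred` above rescales the raw network output by the marginal standard deviation of the VP SDE and then takes one reverse-time Euler-Maruyama step. A hedged, self-contained sketch of that update (again with a hypothetical `score_model`):

import torch

def sde_vp_step(score_model, x, t, num_timesteps, beta_min=0.1, beta_max=20.0):
    # marginal std of the VP SDE at time t, used to turn the raw output into a score
    log_mean_coeff = -0.25 * t**2 * (beta_max - beta_min) - 0.5 * t * beta_min
    std = torch.sqrt(1.0 - torch.exp(2.0 * log_mean_coeff))
    score = -score_model(x, t) / std[:, None, None, None]

    dt = -1.0 / num_timesteps
    beta_t = beta_min + t * (beta_max - beta_min)
    drift = -0.5 * beta_t[:, None, None, None] * x - beta_t[:, None, None, None] * score
    x_mean = x + drift * dt
    noise = torch.randn_like(x)
    return x_mean + torch.sqrt(beta_t)[:, None, None, None] * (-dt) ** 0.5 * noise, x_mean

x = torch.randn(2, 3, 8, 8)
t = torch.full((2,), 0.5)
x, x_mean = sde_vp_step(lambda x, t: -x, x, t, num_timesteps=1000)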
- """ - - config_name = SCHEDULER_CONFIG_NAME - ignore_for_config = ["tensor_format"] - - def set_format(self, tensor_format="pt"): - self.tensor_format = tensor_format - if tensor_format == "pt": - for key, value in vars(self).items(): - if isinstance(value, np.ndarray): - setattr(self, key, torch.from_numpy(value)) - - return self - - def clip(self, tensor, min_value=None, max_value=None): - tensor_format = getattr(self, "tensor_format", "pt") - - if tensor_format == "np": - return np.clip(tensor, min_value, max_value) - elif tensor_format == "pt": - return torch.clamp(tensor, min_value, max_value) - - raise ValueError(f"`self.tensor_format`: {self.tensor_format} is not valid.") - - def log(self, tensor): - tensor_format = getattr(self, "tensor_format", "pt") - - if tensor_format == "np": - return np.log(tensor) - elif tensor_format == "pt": - return torch.log(tensor) - - raise ValueError(f"`self.tensor_format`: {self.tensor_format} is not valid.") - - def match_shape(self, values: Union[np.ndarray, torch.Tensor], broadcast_array: Union[np.ndarray, torch.Tensor]): - """ - Turns a 1-D array into an array or tensor with len(broadcast_array.shape) dims. - - Args: - values: an array or tensor of values to extract. - broadcast_array: an array with a larger shape of K dimensions with the batch - dimension equal to the length of timesteps. - Returns: - a tensor of shape [batch_size, 1, ...] where the shape has K dims. - """ - - tensor_format = getattr(self, "tensor_format", "pt") - values = values.flatten() - - while len(values.shape) < len(broadcast_array.shape): - values = values[..., None] - if tensor_format == "pt": - values = values.to(broadcast_array.device) - - return values - - def norm(self, tensor): - tensor_format = getattr(self, "tensor_format", "pt") - if tensor_format == "np": - return np.linalg.norm(tensor) - elif tensor_format == "pt": - return torch.norm(tensor.reshape(tensor.shape[0], -1), dim=-1).mean() - - raise ValueError(f"`self.tensor_format`: {self.tensor_format} is not valid.") - - def randn_like(self, tensor, generator=None): - tensor_format = getattr(self, "tensor_format", "pt") - if tensor_format == "np": - return np.random.randn(*np.shape(tensor)) - elif tensor_format == "pt": - # return torch.randn_like(tensor) - return torch.randn(tensor.shape, layout=tensor.layout, generator=generator).to(tensor.device) - - raise ValueError(f"`self.tensor_format`: {self.tensor_format} is not valid.") - - def zeros_like(self, tensor): - tensor_format = getattr(self, "tensor_format", "pt") - if tensor_format == "np": - return np.zeros_like(tensor) - elif tensor_format == "pt": - return torch.zeros_like(tensor) - - raise ValueError(f"`self.tensor_format`: {self.tensor_format} is not valid.") diff --git a/local_diffusers/testing_utils.py b/local_diffusers/testing_utils.py deleted file mode 100644 index ff8b6aa9b..000000000 --- a/local_diffusers/testing_utils.py +++ /dev/null @@ -1,61 +0,0 @@ -import os -import random -import unittest -from distutils.util import strtobool - -import torch - -from packaging import version - - -global_rng = random.Random() -torch_device = "cuda" if torch.cuda.is_available() else "cpu" -is_torch_higher_equal_than_1_12 = version.parse(version.parse(torch.__version__).base_version) >= version.parse("1.12") - -if is_torch_higher_equal_than_1_12: - torch_device = "mps" if torch.backends.mps.is_available() else torch_device - - -def parse_flag_from_env(key, default=False): - try: - value = os.environ[key] - except KeyError: - # KEY isn't set, default to 
`default`. - _value = default - else: - # KEY is set, convert it to True or False. - try: - _value = strtobool(value) - except ValueError: - # More values are supported, but let's keep the message simple. - raise ValueError(f"If set, {key} must be yes or no.") - return _value - - -_run_slow_tests = parse_flag_from_env("RUN_SLOW", default=False) - - -def floats_tensor(shape, scale=1.0, rng=None, name=None): - """Creates a random float32 tensor""" - if rng is None: - rng = global_rng - - total_dims = 1 - for dim in shape: - total_dims *= dim - - values = [] - for _ in range(total_dims): - values.append(rng.random() * scale) - - return torch.tensor(data=values, dtype=torch.float).view(shape).contiguous() - - -def slow(test_case): - """ - Decorator marking a test as slow. - - Slow tests are skipped by default. Set the RUN_SLOW environment variable to a truthy value to run them. - - """ - return unittest.skipUnless(_run_slow_tests, "test is slow")(test_case) diff --git a/local_diffusers/training_utils.py b/local_diffusers/training_utils.py deleted file mode 100644 index fa1694161..000000000 --- a/local_diffusers/training_utils.py +++ /dev/null @@ -1,125 +0,0 @@ -import copy -import os -import random - -import numpy as np -import torch - - -def enable_full_determinism(seed: int): - """ - Helper function for reproducible behavior during distributed training. See - - https://pytorch.org/docs/stable/notes/randomness.html for pytorch - """ - # set seed first - set_seed(seed) - - # Enable PyTorch deterministic mode. This potentially requires either the environment - # variable 'CUDA_LAUNCH_BLOCKING' or 'CUBLAS_WORKSPACE_CONFIG' to be set, - # depending on the CUDA version, so we set them both here - os.environ["CUDA_LAUNCH_BLOCKING"] = "1" - os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":16:8" - torch.use_deterministic_algorithms(True) - - # Enable CUDNN deterministic mode - torch.backends.cudnn.deterministic = True - torch.backends.cudnn.benchmark = False - - -def set_seed(seed: int): - """ - Args: - Helper function for reproducible behavior to set the seed in `random`, `numpy`, `torch`. - seed (`int`): The seed to set. - """ - random.seed(seed) - np.random.seed(seed) - torch.manual_seed(seed) - torch.cuda.manual_seed_all(seed) - # ^^ safe to call this function even if cuda is not available - - -class EMAModel: - """ - Exponential Moving Average of models weights - """ - - def __init__( - self, - model, - update_after_step=0, - inv_gamma=1.0, - power=2 / 3, - min_value=0.0, - max_value=0.9999, - device=None, - ): - """ - @crowsonkb's notes on EMA Warmup: - If gamma=1 and power=1, implements a simple average. gamma=1, power=2/3 are good values for models you plan - to train for a million or more steps (reaches decay factor 0.999 at 31.6K steps, 0.9999 at 1M steps), - gamma=1, power=3/4 for models you plan to train for less (reaches decay factor 0.999 at 10K steps, 0.9999 - at 215.4k steps). - Args: - inv_gamma (float): Inverse multiplicative factor of EMA warmup. Default: 1. - power (float): Exponential factor of EMA warmup. Default: 2/3. - min_value (float): The minimum EMA decay rate. Default: 0. 
- """ - - self.averaged_model = copy.deepcopy(model).eval() - self.averaged_model.requires_grad_(False) - - self.update_after_step = update_after_step - self.inv_gamma = inv_gamma - self.power = power - self.min_value = min_value - self.max_value = max_value - - if device is not None: - self.averaged_model = self.averaged_model.to(device=device) - - self.decay = 0.0 - self.optimization_step = 0 - - def get_decay(self, optimization_step): - """ - Compute the decay factor for the exponential moving average. - """ - step = max(0, optimization_step - self.update_after_step - 1) - value = 1 - (1 + step / self.inv_gamma) ** -self.power - - if step <= 0: - return 0.0 - - return max(self.min_value, min(value, self.max_value)) - - @torch.no_grad() - def step(self, new_model): - ema_state_dict = {} - ema_params = self.averaged_model.state_dict() - - self.decay = self.get_decay(self.optimization_step) - - for key, param in new_model.named_parameters(): - if isinstance(param, dict): - continue - try: - ema_param = ema_params[key] - except KeyError: - ema_param = param.float().clone() if param.ndim == 1 else copy.deepcopy(param) - ema_params[key] = ema_param - - if not param.requires_grad: - ema_params[key].copy_(param.to(dtype=ema_param.dtype).data) - ema_param = ema_params[key] - else: - ema_param.mul_(self.decay) - ema_param.add_(param.data.to(dtype=ema_param.dtype), alpha=1 - self.decay) - - ema_state_dict[key] = ema_param - - for key, param in new_model.named_buffers(): - ema_state_dict[key] = param - - self.averaged_model.load_state_dict(ema_state_dict, strict=False) - self.optimization_step += 1 diff --git a/local_diffusers/utils/__init__.py b/local_diffusers/utils/__init__.py deleted file mode 100644 index c00a28e10..000000000 --- a/local_diffusers/utils/__init__.py +++ /dev/null @@ -1,53 +0,0 @@ -# Copyright 2022 The HuggingFace Inc. team. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
- - -import os - -from .import_utils import ( - ENV_VARS_TRUE_AND_AUTO_VALUES, - ENV_VARS_TRUE_VALUES, - USE_JAX, - USE_TF, - USE_TORCH, - DummyObject, - is_flax_available, - is_inflect_available, - is_modelcards_available, - is_onnx_available, - is_scipy_available, - is_tf_available, - is_torch_available, - is_transformers_available, - is_unidecode_available, - requires_backends, -) -from .logging import get_logger -from .outputs import BaseOutput - - -logger = get_logger(__name__) - - -hf_cache_home = os.path.expanduser( - os.getenv("HF_HOME", os.path.join(os.getenv("XDG_CACHE_HOME", "~/.cache"), "huggingface")) -) -default_cache_path = os.path.join(hf_cache_home, "diffusers") - - -CONFIG_NAME = "config.json" -HUGGINGFACE_CO_RESOLVE_ENDPOINT = "https://huggingface.co" -DIFFUSERS_CACHE = default_cache_path -DIFFUSERS_DYNAMIC_MODULE_NAME = "diffusers_modules" -HF_MODULES_CACHE = os.getenv("HF_MODULES_CACHE", os.path.join(hf_cache_home, "modules")) diff --git a/local_diffusers/utils/dummy_scipy_objects.py b/local_diffusers/utils/dummy_scipy_objects.py deleted file mode 100644 index 3706c5754..000000000 --- a/local_diffusers/utils/dummy_scipy_objects.py +++ /dev/null @@ -1,11 +0,0 @@ -# This file is autogenerated by the command `make fix-copies`, do not edit. -# flake8: noqa - -from ..utils import DummyObject, requires_backends - - -class LMSDiscreteScheduler(metaclass=DummyObject): - _backends = ["scipy"] - - def __init__(self, *args, **kwargs): - requires_backends(self, ["scipy"]) diff --git a/local_diffusers/utils/dummy_transformers_and_inflect_and_unidecode_objects.py b/local_diffusers/utils/dummy_transformers_and_inflect_and_unidecode_objects.py deleted file mode 100644 index 8c2aec218..000000000 --- a/local_diffusers/utils/dummy_transformers_and_inflect_and_unidecode_objects.py +++ /dev/null @@ -1,10 +0,0 @@ -# This file is autogenerated by the command `make fix-copies`, do not edit. -# flake8: noqa -from ..utils import DummyObject, requires_backends - - -class GradTTSPipeline(metaclass=DummyObject): - _backends = ["transformers", "inflect", "unidecode"] - - def __init__(self, *args, **kwargs): - requires_backends(self, ["transformers", "inflect", "unidecode"]) diff --git a/local_diffusers/utils/dummy_transformers_and_onnx_objects.py b/local_diffusers/utils/dummy_transformers_and_onnx_objects.py deleted file mode 100644 index 2e34b5ce0..000000000 --- a/local_diffusers/utils/dummy_transformers_and_onnx_objects.py +++ /dev/null @@ -1,11 +0,0 @@ -# This file is autogenerated by the command `make fix-copies`, do not edit. -# flake8: noqa - -from ..utils import DummyObject, requires_backends - - -class StableDiffusionOnnxPipeline(metaclass=DummyObject): - _backends = ["transformers", "onnx"] - - def __init__(self, *args, **kwargs): - requires_backends(self, ["transformers", "onnx"]) diff --git a/local_diffusers/utils/dummy_transformers_objects.py b/local_diffusers/utils/dummy_transformers_objects.py deleted file mode 100644 index e05eb814d..000000000 --- a/local_diffusers/utils/dummy_transformers_objects.py +++ /dev/null @@ -1,32 +0,0 @@ -# This file is autogenerated by the command `make fix-copies`, do not edit. 
-# flake8: noqa - -from ..utils import DummyObject, requires_backends - - -class LDMTextToImagePipeline(metaclass=DummyObject): - _backends = ["transformers"] - - def __init__(self, *args, **kwargs): - requires_backends(self, ["transformers"]) - - -class StableDiffusionImg2ImgPipeline(metaclass=DummyObject): - _backends = ["transformers"] - - def __init__(self, *args, **kwargs): - requires_backends(self, ["transformers"]) - - -class StableDiffusionInpaintPipeline(metaclass=DummyObject): - _backends = ["transformers"] - - def __init__(self, *args, **kwargs): - requires_backends(self, ["transformers"]) - - -class StableDiffusionPipeline(metaclass=DummyObject): - _backends = ["transformers"] - - def __init__(self, *args, **kwargs): - requires_backends(self, ["transformers"]) diff --git a/local_diffusers/utils/import_utils.py b/local_diffusers/utils/import_utils.py deleted file mode 100644 index 1f5e95ada..000000000 --- a/local_diffusers/utils/import_utils.py +++ /dev/null @@ -1,274 +0,0 @@ -# Copyright 2022 The HuggingFace Team. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -""" -Import utilities: Utilities related to imports and our lazy inits. -""" -import importlib.util -import os -import sys -from collections import OrderedDict - -from packaging import version - -from . import logging - - -# The package importlib_metadata is in a different place, depending on the python version. 
-if sys.version_info < (3, 8): - import importlib_metadata -else: - import importlib.metadata as importlib_metadata - - -logger = logging.get_logger(__name__) # pylint: disable=invalid-name - -ENV_VARS_TRUE_VALUES = {"1", "ON", "YES", "TRUE"} -ENV_VARS_TRUE_AND_AUTO_VALUES = ENV_VARS_TRUE_VALUES.union({"AUTO"}) - -USE_TF = os.environ.get("USE_TF", "AUTO").upper() -USE_TORCH = os.environ.get("USE_TORCH", "AUTO").upper() -USE_JAX = os.environ.get("USE_FLAX", "AUTO").upper() - -_torch_version = "N/A" -if USE_TORCH in ENV_VARS_TRUE_AND_AUTO_VALUES and USE_TF not in ENV_VARS_TRUE_VALUES: - _torch_available = importlib.util.find_spec("torch") is not None - if _torch_available: - try: - _torch_version = importlib_metadata.version("torch") - logger.info(f"PyTorch version {_torch_version} available.") - except importlib_metadata.PackageNotFoundError: - _torch_available = False -else: - logger.info("Disabling PyTorch because USE_TF is set") - _torch_available = False - - -_tf_version = "N/A" -if USE_TF in ENV_VARS_TRUE_AND_AUTO_VALUES and USE_TORCH not in ENV_VARS_TRUE_VALUES: - _tf_available = importlib.util.find_spec("tensorflow") is not None - if _tf_available: - candidates = ( - "tensorflow", - "tensorflow-cpu", - "tensorflow-gpu", - "tf-nightly", - "tf-nightly-cpu", - "tf-nightly-gpu", - "intel-tensorflow", - "intel-tensorflow-avx512", - "tensorflow-rocm", - "tensorflow-macos", - "tensorflow-aarch64", - ) - _tf_version = None - # For the metadata, we have to look for both tensorflow and tensorflow-cpu - for pkg in candidates: - try: - _tf_version = importlib_metadata.version(pkg) - break - except importlib_metadata.PackageNotFoundError: - pass - _tf_available = _tf_version is not None - if _tf_available: - if version.parse(_tf_version) < version.parse("2"): - logger.info(f"TensorFlow found but with version {_tf_version}. 
Diffusers requires version 2 minimum.") - _tf_available = False - else: - logger.info(f"TensorFlow version {_tf_version} available.") -else: - logger.info("Disabling Tensorflow because USE_TORCH is set") - _tf_available = False - - -if USE_JAX in ENV_VARS_TRUE_AND_AUTO_VALUES: - _flax_available = importlib.util.find_spec("jax") is not None and importlib.util.find_spec("flax") is not None - if _flax_available: - try: - _jax_version = importlib_metadata.version("jax") - _flax_version = importlib_metadata.version("flax") - logger.info(f"JAX version {_jax_version}, Flax version {_flax_version} available.") - except importlib_metadata.PackageNotFoundError: - _flax_available = False -else: - _flax_available = False - - -_transformers_available = importlib.util.find_spec("transformers") is not None -try: - _transformers_version = importlib_metadata.version("transformers") - logger.debug(f"Successfully imported transformers version {_transformers_version}") -except importlib_metadata.PackageNotFoundError: - _transformers_available = False - - -_inflect_available = importlib.util.find_spec("inflect") is not None -try: - _inflect_version = importlib_metadata.version("inflect") - logger.debug(f"Successfully imported inflect version {_inflect_version}") -except importlib_metadata.PackageNotFoundError: - _inflect_available = False - - -_unidecode_available = importlib.util.find_spec("unidecode") is not None -try: - _unidecode_version = importlib_metadata.version("unidecode") - logger.debug(f"Successfully imported unidecode version {_unidecode_version}") -except importlib_metadata.PackageNotFoundError: - _unidecode_available = False - - -_modelcards_available = importlib.util.find_spec("modelcards") is not None -try: - _modelcards_version = importlib_metadata.version("modelcards") - logger.debug(f"Successfully imported modelcards version {_modelcards_version}") -except importlib_metadata.PackageNotFoundError: - _modelcards_available = False - - -_onnx_available = importlib.util.find_spec("onnxruntime") is not None -try: - _onnxruntime_version = importlib_metadata.version("onnxruntime") - logger.debug(f"Successfully imported onnxruntime version {_onnxruntime_version}") -except importlib_metadata.PackageNotFoundError: - _onnx_available = False - - -_scipy_available = importlib.util.find_spec("scipy") is not None -try: - _scipy_version = importlib_metadata.version("scipy") - logger.debug(f"Successfully imported transformers version {_scipy_version}") -except importlib_metadata.PackageNotFoundError: - _scipy_available = False - - -def is_torch_available(): - return _torch_available - - -def is_tf_available(): - return _tf_available - - -def is_flax_available(): - return _flax_available - - -def is_transformers_available(): - return _transformers_available - - -def is_inflect_available(): - return _inflect_available - - -def is_unidecode_available(): - return _unidecode_available - - -def is_modelcards_available(): - return _modelcards_available - - -def is_onnx_available(): - return _onnx_available - - -def is_scipy_available(): - return _scipy_available - - -# docstyle-ignore -FLAX_IMPORT_ERROR = """ -{0} requires the FLAX library but it was not found in your environment. Checkout the instructions on the -installation page: https://github.com/google/flax and follow the ones that match your environment. -""" - -# docstyle-ignore -INFLECT_IMPORT_ERROR = """ -{0} requires the inflect library but it was not found in your environment. 
You can install it with pip: `pip install -inflect` -""" - -# docstyle-ignore -PYTORCH_IMPORT_ERROR = """ -{0} requires the PyTorch library but it was not found in your environment. Checkout the instructions on the -installation page: https://pytorch.org/get-started/locally/ and follow the ones that match your environment. -""" - -# docstyle-ignore -ONNX_IMPORT_ERROR = """ -{0} requires the onnxruntime library but it was not found in your environment. You can install it with pip: `pip -install onnxruntime` -""" - -# docstyle-ignore -SCIPY_IMPORT_ERROR = """ -{0} requires the scipy library but it was not found in your environment. You can install it with pip: `pip install -scipy` -""" - -# docstyle-ignore -TENSORFLOW_IMPORT_ERROR = """ -{0} requires the TensorFlow library but it was not found in your environment. Checkout the instructions on the -installation page: https://www.tensorflow.org/install and follow the ones that match your environment. -""" - -# docstyle-ignore -TRANSFORMERS_IMPORT_ERROR = """ -{0} requires the transformers library but it was not found in your environment. You can install it with pip: `pip -install transformers` -""" - -# docstyle-ignore -UNIDECODE_IMPORT_ERROR = """ -{0} requires the unidecode library but it was not found in your environment. You can install it with pip: `pip install -Unidecode` -""" - - -BACKENDS_MAPPING = OrderedDict( - [ - ("flax", (is_flax_available, FLAX_IMPORT_ERROR)), - ("inflect", (is_inflect_available, INFLECT_IMPORT_ERROR)), - ("onnx", (is_onnx_available, ONNX_IMPORT_ERROR)), - ("scipy", (is_scipy_available, SCIPY_IMPORT_ERROR)), - ("tf", (is_tf_available, TENSORFLOW_IMPORT_ERROR)), - ("torch", (is_torch_available, PYTORCH_IMPORT_ERROR)), - ("transformers", (is_transformers_available, TRANSFORMERS_IMPORT_ERROR)), - ("unidecode", (is_unidecode_available, UNIDECODE_IMPORT_ERROR)), - ] -) - - -def requires_backends(obj, backends): - if not isinstance(backends, (list, tuple)): - backends = [backends] - - name = obj.__name__ if hasattr(obj, "__name__") else obj.__class__.__name__ - checks = (BACKENDS_MAPPING[backend] for backend in backends) - failed = [msg.format(name) for available, msg in checks if not available()] - if failed: - raise ImportError("".join(failed)) - - -class DummyObject(type): - """ - Metaclass for the dummy objects. Any class inheriting from it will return the ImportError generated by - `requires_backend` each time a user tries to access any method of that class. - """ - - def __getattr__(cls, key): - if key.startswith("_"): - return super().__getattr__(cls, key) - requires_backends(cls, cls._backends) diff --git a/local_diffusers/utils/logging.py b/local_diffusers/utils/logging.py deleted file mode 100644 index 1f2d0227b..000000000 --- a/local_diffusers/utils/logging.py +++ /dev/null @@ -1,344 +0,0 @@ -# coding=utf-8 -# Copyright 2020 Optuna, Hugging Face -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
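The import utilities being removed in this hunk implement the soft-dependency pattern: each optional backend is probed once with importlib, and requires_backends together with the DummyObject metaclass turns any later use of a missing backend into a readable ImportError. A minimal standalone sketch of the same pattern (the ScipyOnlyScheduler name and the choice of scipy as the guarded backend are illustrative, not part of this patch; in the real library the dummy class is only exported when the backend is absent):

import importlib.util
from collections import OrderedDict

_scipy_available = importlib.util.find_spec("scipy") is not None

SCIPY_IMPORT_ERROR = "{0} requires the scipy library: `pip install scipy`"

BACKENDS = OrderedDict([("scipy", (lambda: _scipy_available, SCIPY_IMPORT_ERROR))])

def requires_backends(obj, backends):
    # Collect the error message of every requested backend that is missing.
    name = obj.__name__ if hasattr(obj, "__name__") else obj.__class__.__name__
    failed = [msg.format(name) for available, msg in (BACKENDS[b] for b in backends) if not available()]
    if failed:
        raise ImportError("".join(failed))

class DummyObject(type):
    # Metaclass: looking up any public attribute on the class re-checks the backends.
    def __getattr__(cls, key):
        requires_backends(cls, cls._backends)

class ScipyOnlyScheduler(metaclass=DummyObject):
    _backends = ["scipy"]

# ScipyOnlyScheduler.from_config  ->  raises ImportError when scipy is not installed.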
-""" Logging utilities.""" - -import logging -import os -import sys -import threading -from logging import CRITICAL # NOQA -from logging import DEBUG # NOQA -from logging import ERROR # NOQA -from logging import FATAL # NOQA -from logging import INFO # NOQA -from logging import NOTSET # NOQA -from logging import WARN # NOQA -from logging import WARNING # NOQA -from typing import Optional - -from tqdm import auto as tqdm_lib - - -_lock = threading.Lock() -_default_handler: Optional[logging.Handler] = None - -log_levels = { - "debug": logging.DEBUG, - "info": logging.INFO, - "warning": logging.WARNING, - "error": logging.ERROR, - "critical": logging.CRITICAL, -} - -_default_log_level = logging.WARNING - -_tqdm_active = True - - -def _get_default_logging_level(): - """ - If DIFFUSERS_VERBOSITY env var is set to one of the valid choices return that as the new default level. If it is - not - fall back to `_default_log_level` - """ - env_level_str = os.getenv("DIFFUSERS_VERBOSITY", None) - if env_level_str: - if env_level_str in log_levels: - return log_levels[env_level_str] - else: - logging.getLogger().warning( - f"Unknown option DIFFUSERS_VERBOSITY={env_level_str}, " - f"has to be one of: { ', '.join(log_levels.keys()) }" - ) - return _default_log_level - - -def _get_library_name() -> str: - - return __name__.split(".")[0] - - -def _get_library_root_logger() -> logging.Logger: - - return logging.getLogger(_get_library_name()) - - -def _configure_library_root_logger() -> None: - - global _default_handler - - with _lock: - if _default_handler: - # This library has already configured the library root logger. - return - _default_handler = logging.StreamHandler() # Set sys.stderr as stream. - _default_handler.flush = sys.stderr.flush - - # Apply our default configuration to the library root logger. - library_root_logger = _get_library_root_logger() - library_root_logger.addHandler(_default_handler) - library_root_logger.setLevel(_get_default_logging_level()) - library_root_logger.propagate = False - - -def _reset_library_root_logger() -> None: - - global _default_handler - - with _lock: - if not _default_handler: - return - - library_root_logger = _get_library_root_logger() - library_root_logger.removeHandler(_default_handler) - library_root_logger.setLevel(logging.NOTSET) - _default_handler = None - - -def get_log_levels_dict(): - return log_levels - - -def get_logger(name: Optional[str] = None) -> logging.Logger: - """ - Return a logger with the specified name. - - This function is not supposed to be directly accessed unless you are writing a custom diffusers module. - """ - - if name is None: - name = _get_library_name() - - _configure_library_root_logger() - return logging.getLogger(name) - - -def get_verbosity() -> int: - """ - Return the current level for the 🤗 Diffusers' root logger as an int. - - Returns: - `int`: The logging level. - - - - 🤗 Diffusers has following logging levels: - - - 50: `diffusers.logging.CRITICAL` or `diffusers.logging.FATAL` - - 40: `diffusers.logging.ERROR` - - 30: `diffusers.logging.WARNING` or `diffusers.logging.WARN` - - 20: `diffusers.logging.INFO` - - 10: `diffusers.logging.DEBUG` - - """ - - _configure_library_root_logger() - return _get_library_root_logger().getEffectiveLevel() - - -def set_verbosity(verbosity: int) -> None: - """ - Set the verbosity level for the 🤗 Diffusers' root logger. 
- - Args: - verbosity (`int`): - Logging level, e.g., one of: - - - `diffusers.logging.CRITICAL` or `diffusers.logging.FATAL` - - `diffusers.logging.ERROR` - - `diffusers.logging.WARNING` or `diffusers.logging.WARN` - - `diffusers.logging.INFO` - - `diffusers.logging.DEBUG` - """ - - _configure_library_root_logger() - _get_library_root_logger().setLevel(verbosity) - - -def set_verbosity_info(): - """Set the verbosity to the `INFO` level.""" - return set_verbosity(INFO) - - -def set_verbosity_warning(): - """Set the verbosity to the `WARNING` level.""" - return set_verbosity(WARNING) - - -def set_verbosity_debug(): - """Set the verbosity to the `DEBUG` level.""" - return set_verbosity(DEBUG) - - -def set_verbosity_error(): - """Set the verbosity to the `ERROR` level.""" - return set_verbosity(ERROR) - - -def disable_default_handler() -> None: - """Disable the default handler of the HuggingFace Diffusers' root logger.""" - - _configure_library_root_logger() - - assert _default_handler is not None - _get_library_root_logger().removeHandler(_default_handler) - - -def enable_default_handler() -> None: - """Enable the default handler of the HuggingFace Diffusers' root logger.""" - - _configure_library_root_logger() - - assert _default_handler is not None - _get_library_root_logger().addHandler(_default_handler) - - -def add_handler(handler: logging.Handler) -> None: - """adds a handler to the HuggingFace Diffusers' root logger.""" - - _configure_library_root_logger() - - assert handler is not None - _get_library_root_logger().addHandler(handler) - - -def remove_handler(handler: logging.Handler) -> None: - """removes given handler from the HuggingFace Diffusers' root logger.""" - - _configure_library_root_logger() - - assert handler is not None and handler not in _get_library_root_logger().handlers - _get_library_root_logger().removeHandler(handler) - - -def disable_propagation() -> None: - """ - Disable propagation of the library log outputs. Note that log propagation is disabled by default. - """ - - _configure_library_root_logger() - _get_library_root_logger().propagate = False - - -def enable_propagation() -> None: - """ - Enable propagation of the library log outputs. Please disable the HuggingFace Diffusers' default handler to prevent - double logging if the root logger has been configured. - """ - - _configure_library_root_logger() - _get_library_root_logger().propagate = True - - -def enable_explicit_format() -> None: - """ - Enable explicit formatting for every HuggingFace Diffusers' logger. The explicit formatter is as follows: - ``` - [LEVELNAME|FILENAME|LINE NUMBER] TIME >> MESSAGE - ``` - All handlers currently bound to the root logger are affected by this method. - """ - handlers = _get_library_root_logger().handlers - - for handler in handlers: - formatter = logging.Formatter("[%(levelname)s|%(filename)s:%(lineno)s] %(asctime)s >> %(message)s") - handler.setFormatter(formatter) - - -def reset_format() -> None: - """ - Resets the formatting for HuggingFace Diffusers' loggers. - - All handlers currently bound to the root logger are affected by this method. 
- """ - handlers = _get_library_root_logger().handlers - - for handler in handlers: - handler.setFormatter(None) - - -def warning_advice(self, *args, **kwargs): - """ - This method is identical to `logger.warninging()`, but if env var DIFFUSERS_NO_ADVISORY_WARNINGS=1 is set, this - warning will not be printed - """ - no_advisory_warnings = os.getenv("DIFFUSERS_NO_ADVISORY_WARNINGS", False) - if no_advisory_warnings: - return - self.warning(*args, **kwargs) - - -logging.Logger.warning_advice = warning_advice - - -class EmptyTqdm: - """Dummy tqdm which doesn't do anything.""" - - def __init__(self, *args, **kwargs): # pylint: disable=unused-argument - self._iterator = args[0] if args else None - - def __iter__(self): - return iter(self._iterator) - - def __getattr__(self, _): - """Return empty function.""" - - def empty_fn(*args, **kwargs): # pylint: disable=unused-argument - return - - return empty_fn - - def __enter__(self): - return self - - def __exit__(self, type_, value, traceback): - return - - -class _tqdm_cls: - def __call__(self, *args, **kwargs): - if _tqdm_active: - return tqdm_lib.tqdm(*args, **kwargs) - else: - return EmptyTqdm(*args, **kwargs) - - def set_lock(self, *args, **kwargs): - self._lock = None - if _tqdm_active: - return tqdm_lib.tqdm.set_lock(*args, **kwargs) - - def get_lock(self): - if _tqdm_active: - return tqdm_lib.tqdm.get_lock() - - -tqdm = _tqdm_cls() - - -def is_progress_bar_enabled() -> bool: - """Return a boolean indicating whether tqdm progress bars are enabled.""" - global _tqdm_active - return bool(_tqdm_active) - - -def enable_progress_bar(): - """Enable tqdm progress bar.""" - global _tqdm_active - _tqdm_active = True - - -def disable_progress_bar(): - """Disable tqdm progress bar.""" - global _tqdm_active - _tqdm_active = False diff --git a/local_diffusers/utils/outputs.py b/local_diffusers/utils/outputs.py deleted file mode 100644 index b02f62d02..000000000 --- a/local_diffusers/utils/outputs.py +++ /dev/null @@ -1,109 +0,0 @@ -# Copyright 2022 The HuggingFace Team. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -""" -Generic utilities -""" - -import warnings -from collections import OrderedDict -from dataclasses import fields -from typing import Any, Tuple - -import numpy as np - -from .import_utils import is_torch_available - - -def is_tensor(x): - """ - Tests if `x` is a `torch.Tensor` or `np.ndarray`. - """ - if is_torch_available(): - import torch - - if isinstance(x, torch.Tensor): - return True - - return isinstance(x, np.ndarray) - - -class BaseOutput(OrderedDict): - """ - Base class for all model outputs as dataclass. Has a `__getitem__` that allows indexing by integer or slice (like a - tuple) or strings (like a dictionary) that will ignore the `None` attributes. Otherwise behaves like a regular - python dictionary. - - - - You can't unpack a `BaseOutput` directly. Use the [`~utils.BaseOutput.to_tuple`] method to convert it to a tuple - before. 
- - - """ - - def __post_init__(self): - class_fields = fields(self) - - # Safety and consistency checks - if not len(class_fields): - raise ValueError(f"{self.__class__.__name__} has no fields.") - - for field in class_fields: - v = getattr(self, field.name) - if v is not None: - self[field.name] = v - - def __delitem__(self, *args, **kwargs): - raise Exception(f"You cannot use ``__delitem__`` on a {self.__class__.__name__} instance.") - - def setdefault(self, *args, **kwargs): - raise Exception(f"You cannot use ``setdefault`` on a {self.__class__.__name__} instance.") - - def pop(self, *args, **kwargs): - raise Exception(f"You cannot use ``pop`` on a {self.__class__.__name__} instance.") - - def update(self, *args, **kwargs): - raise Exception(f"You cannot use ``update`` on a {self.__class__.__name__} instance.") - - def __getitem__(self, k): - if isinstance(k, str): - inner_dict = {k: v for (k, v) in self.items()} - if self.__class__.__name__ in ["StableDiffusionPipelineOutput", "ImagePipelineOutput"] and k == "sample": - warnings.warn( - "The keyword 'samples' is deprecated and will be removed in version 0.4.0. Please use `.images` or" - " `'images'` instead.", - DeprecationWarning, - ) - return inner_dict["images"] - return inner_dict[k] - else: - return self.to_tuple()[k] - - def __setattr__(self, name, value): - if name in self.keys() and value is not None: - # Don't call self.__setitem__ to avoid recursion errors - super().__setitem__(name, value) - super().__setattr__(name, value) - - def __setitem__(self, key, value): - # Will raise a KeyException if needed - super().__setitem__(key, value) - # Don't call self.__setattr__ to avoid recursion errors - super().__setattr__(key, value) - - def to_tuple(self) -> Tuple[Any]: - """ - Convert self to a tuple containing all the attributes/keys that are not `None`. - """ - return tuple(self[k] for k in self.keys()) diff --git a/minisd.py b/minisd.py index 5010a160c..6fba53395 100644 --- a/minisd.py +++ b/minisd.py @@ -7,7 +7,7 @@ from PIL import Image from einops import rearrange, repeat from torch import autocast -from local_diffusers import StableDiffusionPipeline +from diffusers import StableDiffusionPipeline import webbrowser from deep_translator import GoogleTranslator from langdetect import detect @@ -34,7 +34,7 @@ #os.environ["enforcedlatent"] = "" os.environ["good"] = "[]" os.environ["bad"] = "[]" -num_iterations = 50 +num_iterations = 200 gs = 7.5 @@ -121,6 +121,8 @@ prompt = "Yann LeCun rides a dragon which spits fire on a cherry on a cake." prompt = "An armored Mark Zuckerberg fighting off a monster with bloody tentacles in the jungle with a light saber." prompt = "Cute woman, portrait, photo, red hair, green eyes, smiling." +prompt = "Photo of Tarzan as a lawyer with a tie and an octopus on his head." +prompt = "An armored bloody Yann Lecun has a lightsabar and fights a red tentacular monster." 
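The BaseOutput container removed just above is the object that the pipeline calls in minisd.py return: it can be indexed by field name, by position after to_tuple(), or (for the Stable Diffusion outputs) by the deprecated "sample" key, which is redirected to the images field with a DeprecationWarning. A small usage sketch, assuming the pipe and prompt objects already defined in minisd.py; the exact field names depend on the diffusers version, so treat it as illustrative:

out = pipe(prompt, guidance_scale=7.5)

img = out.images[0]           # attribute access on the dataclass field
img = out["images"][0]        # dict-style access by key
img = out.to_tuple()[0][0]    # positional access after conversion to a tuple
# out["sample"][0] still works for StableDiffusionPipelineOutput in this era,
# but it warns and simply forwards to the `images` field.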
print(f"The prompt is {prompt}") @@ -205,11 +207,11 @@ def stop_all(list_of_files, list_of_latent, last_list_of_latent): pretty_print("Should we create animations ?") answer = input(" [y]es or [n]o or [j]ust the selection on the last panel ?") if "y" in answer or "Y" in answer or "j" in answer or "J" in answer: + assert len(list_of_files) == len(list_of_latent) if "j" in answer or "J" in answer: list_of_latent = last_list_of_latent pretty_print("Let us create animations!") - assert len(list_of_files) == len(list_of_latent) - for c in [0.5, 0.25, 0.125, 0.0625, 0.05, 0.04,0.03125]: + for c in sorted([0.5, 0.25, 0.125, 0.0625, 0.05, 0.04,0.03125]): for idx in range(len(list_of_files)): images = [] l = list_of_latent[idx].reshape(1,4,64,64) @@ -379,7 +381,7 @@ def load_img(path): if i > 0: #epsilon = 0.3 / 1.1**i #basic_new_fl = np.sqrt(len(new_fl) / np.sum(basic_new_fl**2)) * basic_new_fl - epsilon = (i-1)/(llambda-1) #1.0 / 2**(2 + (llambda - i) / 6) + epsilon = (0.5 * (i-1)/(llambda-1))**3 #1.0 / 2**(2 + (llambda - i) / 6) new_fl = (1. - epsilon) * basic_new_fl + epsilon * np.random.randn(1*4*64*64) else: new_fl = basic_new_fl @@ -423,7 +425,7 @@ def load_img(path): scrn.blit(text0, ((X*3/4)/2 - X/32, Y/2+Y/8)) # Button for early stopping - text2 = font.render(to_native(f'{len(all_selected)} chosen images! '), True, green, blue) + text2 = font.render(to_native(f'Total: {len(all_selected)} chosen images! '), True, green, blue) text2 = pygame.transform.rotate(text2, 90) scrn.blit(text2, (X*3/4+X/16 - X/32, Y/3)) text2 = font.render(to_native('Click for stopping,'), True, green, blue) @@ -614,7 +616,7 @@ def load_img(path): all_selected += [selected_filename] all_selected_latent += [latent[index]] final_selection += [latent[index]] - text2 = font.render(to_native(f'{len(all_selected)} chosen images! '), True, green, blue) + text2 = font.render(to_native(f'==> {len(all_selected)} chosen images! '), True, green, blue) text2 = pygame.transform.rotate(text2, 90) scrn.blit(text2, (X*3/4+X/16 - X/32, Y/3)) if index not in five_best and len(five_best) < 5: @@ -696,7 +698,7 @@ def load_img(path): #if a % 2 == 0: # forcedlatent -= np.random.rand() * sauron basic_new_fl = np.sqrt(len(forcedlatent) / np.sum(forcedlatent**2)) * forcedlatent - epsilon = 0.3 * (((a - len(good)) / (llambda - len(good) - 1)) ** 6) + epsilon = 0.1 * (((a + .5 - len(good)) / (llambda - len(good) - 1)) ** 6) forcedlatent = (1. - epsilon) * basic_new_fl.flatten() + epsilon * np.random.randn(4*64*64) forcedlatent = np.sqrt(len(forcedlatent) / np.sum(forcedlatent**2)) * forcedlatent forcedlatents += [forcedlatent] From aeb88ff1c333b5e760ac2bb2a5b13b9ac542cbed Mon Sep 17 00:00:00 2001 From: Olivier Teytaud Date: Tue, 4 Oct 2022 07:33:50 +0200 Subject: [PATCH 62/76] fix --- minisd.py | 190 ++++++++++++++++++++++++++++++++++++++---------------- 1 file changed, 136 insertions(+), 54 deletions(-) diff --git a/minisd.py b/minisd.py index 6fba53395..e3ca9695f 100644 --- a/minisd.py +++ b/minisd.py @@ -1,5 +1,6 @@ import random import os +import time import torch import numpy as np import shutil @@ -34,8 +35,9 @@ #os.environ["enforcedlatent"] = "" os.environ["good"] = "[]" os.environ["bad"] = "[]" -num_iterations = 200 +num_iterations = 50 gs = 7.5 +voronoi_in_images = True @@ -55,6 +57,7 @@ all_selected_latent = [] final_selection = [] forcedlatents = [] +forcedgs = [] @@ -123,6 +126,14 @@ prompt = "Cute woman, portrait, photo, red hair, green eyes, smiling." prompt = "Photo of Tarzan as a lawyer with a tie and an octopus on his head." 
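The epsilon hunks in the patch above keep tuning the same mutation rule: candidate i reuses a selected latent, blends in Gaussian noise with an index-dependent weight, and is renormalized so the latent keeps unit variance per coordinate (norm sqrt(N) for N = 4*64*64). A compact sketch of that rule, with the scale and exponent left as parameters because the series keeps adjusting them; the overall magnitude applied afterwards also changes between commits (factors such as 6.0 or 4.6 appear later), so the normalization is the stable part:

import numpy as np

N = 4 * 64 * 64  # Stable Diffusion latent size at 512x512

def mutate(base_latent, i, llambda, scale=0.5, power=3):
    """Candidate 0 keeps the parent latent; later candidates blend in more noise."""
    base = np.asarray(base_latent, dtype=np.float64).flatten()
    base = np.sqrt(N / np.sum(base ** 2)) * base              # renormalize the parent
    if i == 0:
        new = base
    else:
        epsilon = (scale * (i - 1) / (llambda - 1)) ** power  # small for the first clones
        new = (1.0 - epsilon) * base + epsilon * np.random.randn(N)
    return np.sqrt(N / np.sum(new ** 2)) * new                # back onto the sqrt(N) sphere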
prompt = "An armored bloody Yann Lecun has a lightsabar and fights a red tentacular monster." +prompt = "Photo of a giant armored insect attacking a building. The building is broken. There are flames." +prompt = "Photo of Meg Myers, on the left, in Egyptian dress, fights Cthulhu (on the right) with a light saber. They stare at each other." +prompt = "Photo of a cute red panda." +prompt = "Photo of a cute smiling white-haired woman with pink eyes." +prompt = "A muscular Jesus with and assault rifle, a cap and and a light saber." +prompt = "A portrait of a cute smiling woman." +prompt = "A woman with black skin, red hair, egyptian dress, yellow eyes." +prompt = "Photo of a young cute black woman." print(f"The prompt is {prompt}") @@ -176,7 +187,9 @@ def singleeg(path_to_image): output_filename = path_to_image + ".SR.png" sr_image.save(output_filename) return output_filename + def singleeg2(path_to_image): + time.sleep(0.5*np.random.rand()) image = Image.open(path_to_image).convert('RGB') sr_device = device #('mps') #torch.device('cuda' if torch.cuda.is_available() else 'cpu') print(f"Type before SR = {type(image)}") @@ -251,7 +264,7 @@ def stop_all(list_of_files, list_of_latent, last_list_of_latent): # images += [image_name] print(to_native(f"Base images created for perturbation={c} and file {list_of_files[idx]}")) #images = Parallel(n_jobs=8)(delayed(process)(i) for i in range(10)) - images = Parallel(n_jobs=16)(delayed(singleeg2)(image) for image in images) + images = Parallel(n_jobs=10)(delayed(singleeg2)(image) for image in images) frames = [Image.open(image) for image in images] frame_one = frames[0] @@ -311,11 +324,18 @@ def load_img(path): image = torch.from_numpy(image) return 2.*image - 1. +model = pipe.vae + +def img_to_latent(path): + init_image = 1.8 * load_img(path).to(device) + init_image = repeat(init_image, '1 ... -> b ...', b=1) + forced_latent = model.encode(init_image.to(device)).latent_dist.sample() + new_fl = forced_latent.cpu().detach().numpy().flatten() + new_fl = np.sqrt(len(new_fl)) * new_fl / np.sqrt(np.sum(new_fl ** 2)) + return new_fl + if len(image_name) > 0: pretty_print("Importing an image !") - import torchvision - #forced_latent = pipe.get_latent(torchvision.io.read_image(image_name).float()) - model = pipe.vae try: init_image = load_img(image_name).to(device) except: @@ -362,7 +382,8 @@ def load_img(path): new_base_init_image[0,0,:,:] /= divider new_base_init_image[0,2,:,:] /= divider - c = np.exp(np.random.randn() - 2) + c = np.exp(np.random.randn() - 5) + f = np.exp(-3. * np.random.rand()) init_image_shape = base_init_image.cpu().numpy().shape if i > 0 and not latent_found: init_image = new_base_init_image + torch.from_numpy(c * np.random.randn(np.prod(init_image_shape))).reshape(init_image_shape).float().to(device) @@ -373,20 +394,23 @@ def load_img(path): new_fl = np.asarray(eval(latent_str)) assert len(new_fl) > 1 else: - forced_latent = 6. * model.encode(init_image.to(device)).latent_dist.sample() + forced_latent = 1. * model.encode(init_image.to(device)).latent_dist.sample() new_fl = forced_latent.cpu().detach().numpy().flatten() basic_new_fl = new_fl #np.sqrt(len(new_fl) / sum(new_fl ** 2)) * new_fl #new_fl = forced_latent + (1. 
/ 1.1**(llambda-i)) * torch.from_numpy(np.random.randn(1*4*64*64).reshape(1,4,64,64)).float().to(device) #forcedlatents += [new_fl.cpu().detach().numpy()] if i > 0: #epsilon = 0.3 / 1.1**i - #basic_new_fl = np.sqrt(len(new_fl) / np.sum(basic_new_fl**2)) * basic_new_fl - epsilon = (0.5 * (i-1)/(llambda-1))**3 #1.0 / 2**(2 + (llambda - i) / 6) + basic_new_fl = f * np.sqrt(len(new_fl) / np.sum(basic_new_fl**2)) * basic_new_fl + epsilon = .7 * ((i-1)/(llambda-1)) #1.0 / 2**(2 + (llambda - i) / 6) + print(f"{i} -- {i % 7} {c} {f} {epsilon}") + # 1 -- 1 0.050020045300292804 0.0790648688521246 0.0 new_fl = (1. - epsilon) * basic_new_fl + epsilon * np.random.randn(1*4*64*64) else: new_fl = basic_new_fl - #new_fl = np.sqrt(len(new_fl)) * new_fl / np.sqrt(np.sum(new_fl ** 2)) + new_fl = 6. * np.sqrt(len(new_fl)) * new_fl / np.sqrt(np.sum(new_fl ** 2)) forcedlatents += [new_fl] #np.clip(new_fl, -3., 3.)] #np.sqrt(len(new_fl) / sum(new_fl ** 2)) * new_fl] + forcedgs += [7.5] #np.random.choice([7.5, 15.0, 30.0, 60.0])] TODO #forcedlatents += [np.sqrt(len(new_fl) / sum(new_fl ** 2)) * new_fl] #print(f"{i} --> {forcedlatents[i][:10]}") @@ -446,7 +470,12 @@ def load_img(path): # os.environ["enforcedlatent"] = "" #with autocast("cuda"): # image = pipe(english_prompt, guidance_scale=gs, num_inference_steps=num_iterations)["sample"][0] + previous_gs = gs + if k < len(forcedgs): + gs = forcedgs[k] image = latent_to_image(np.asarray(latent_forcing)) #eval(os.environ["forcedlatent"]))) + gs = previous_gs + images += [image] filename = f"SD_{prompt.replace(' ','_')}_image_{sentinel}_{iteration:05d}_{k:05d}.png" image.save(filename) @@ -465,10 +494,13 @@ def load_img(path): with open(filename + ".latent.txt", 'w') as f: f.write(f"{str_latent}") # In case of early stopping. + first_event = True for i in pygame.event.get(): if i.type == pygame.MOUSEBUTTONUP: - noise.say("Ok I stop") - noise.runAndWait() + if first_event: + noise.say("Ok I stop") + noise.runAndWait() + first_event = False pos = pygame.mouse.get_pos() index = 3 * (pos[0] // 300) + (pos[1] // 300) if pos[0] > X and pos[1] > Y /3 and pos[1] < 2*Y/3: @@ -480,7 +512,7 @@ def load_img(path): break early_stop = [(1,1)] satus = False - + forcedgs = [] # Stop the forcing from disk! #os.environ["enforcedlatent"] = "" # importing required library @@ -657,51 +689,101 @@ def load_img(path): os.environ["mu"] = str(len(indices)) forcedlatents = [] bad += [list(latent[u].flatten()) for u in range(len(onlyfiles)) if u not in [i[0] for i in indices]] - sauron = 0 * latent[0] - for u in [u for u in range(len(onlyfiles)) if u not in [i[0] for i in indices]]: - sauron += latent[u] - sauron = (1 / len([u for u in range(len(onlyfiles)) if u not in [i[0] for i in indices]])) * sauron + #sauron = 0 * latent[0] + #for u in [u for u in range(len(onlyfiles)) if u not in [i[0] for i in indices]]: + # sauron += latent[u] + #sauron = (1 / len([u for u in range(len(onlyfiles)) if u not in [i[0] for i in indices]])) * sauron if len(bad) > 300: bad = bad[(len(bad) - 300):] print(to_native(f"{len(indices)} indices are selected.")) #print(f"indices = {indices}") - for a in range(llambda): - forcedlatent = np.zeros((4, 64, 64)) - os.environ["good"] = str(good) - os.environ["bad"] = str(bad) - coefficients = np.zeros(len(indices)) - for i in range(len(indices)): - coefficients[i] = np.exp(2. * np.random.randn()) - for i in range(64): - x = i / 63. - for j in range(64): - y = j / 63 - mindistances = 10000000000. 
- for u in range(len(indices)): - #print(a, i, x, j, y, u) - #print(indices[u][1]) - #print(indices[u][2]) - #print(f" {coefficients[u]}* np.linalg.norm({np.array((x, y))}-{np.array((indices[u][1], indices[u][2]))}") - distance = coefficients[u] * np.linalg.norm( np.array((x, y)) - np.array((indices[u][1], indices[u][2])) ) - if distance < mindistances: - mindistances = distance - uu = indices[u][0] - for k in range(4): - assert k < len(forcedlatent), k - assert i < len(forcedlatent[k]), i - assert j < len(forcedlatent[k][i]), j - assert uu < len(latent) - assert k < len(latent[uu]), k - assert i < len(latent[uu][k]), i - assert j < len(latent[uu][k][i]), j - forcedlatent[k][i][j] = float(latent[uu][k][i][j]) - #if a % 2 == 0: - # forcedlatent -= np.random.rand() * sauron - basic_new_fl = np.sqrt(len(forcedlatent) / np.sum(forcedlatent**2)) * forcedlatent - epsilon = 0.1 * (((a + .5 - len(good)) / (llambda - len(good) - 1)) ** 6) - forcedlatent = (1. - epsilon) * basic_new_fl.flatten() + epsilon * np.random.randn(4*64*64) - forcedlatent = np.sqrt(len(forcedlatent) / np.sum(forcedlatent**2)) * forcedlatent - forcedlatents += [forcedlatent] + os.environ["good"] = str(good) + os.environ["bad"] = str(bad) + coefficients = np.zeros(len(indices)) + if voronoi_in_images: + numpy_images = [np.array(image) for image in images] + image = np.array(numpy_images[0]) + for a in range(llambda): + print(f"Voronoi in the image space! {a} / {llambda}") + for i in range(len(indices)): + coefficients[i] = np.exp(2. * np.random.randn()) + # Creating a forcedlatent. + for i in range(512): + x = i / 511. + for j in range(512): + y = j / 511 + mindistances = 10000000000. + for u in range(len(indices)): + distance = coefficients[u] * np.linalg.norm( np.array((x, y)) - np.array((indices[u][2], indices[u][1])) ) + if distance < mindistances: + mindistances = distance + uu = indices[u][0] + image[i][j][:] = numpy_images[uu][i][j][:] + # Conversion before using img2latent + pil_image = Image.fromarray(image) + voronoi_name = f"voronoi{a}_iteration{iteration}.png" + pil_image.save(voronoi_name) + #timage = np.array([image]).astype(np.float32) / 255.0 + #timage = timage.transpose(0, 3, 1, 2) + #timage = torch.from_numpy(timage).to(device) + #timage = repeat(timage, '1 ... -> b ...', b=1) + #timage = 2.*timage - 1. + #forcedlatent = model.encode(timage).latent_dist.sample().cpu().detach().numpy().flatten() + #basic_new_fl = np.sqrt(len(forcedlatent) / np.sum(forcedlatent**2)) * forcedlatent + basic_new_fl = img_to_latent(voronoi_name) + basic_new_fl = 0.8 * np.sqrt(len(forcedlatent) / np.sum(forcedlatent**2)) * forcedlatent + if len(good) > 1: + print("Directly copying latent vars !!!") + forcedlatents += [4.6 * basic_new_fl] + else: + epsilon = 1.0 * (((a + .5 - len(good)) / (llambda - len(good) - 1)) ** 2) + forcedlatent = (1. - epsilon) * basic_new_fl.flatten() + epsilon * np.random.randn(4*64*64) + forcedlatent = np.sqrt(len(forcedlatent) / np.sum(forcedlatent**2)) * forcedlatent + forcedlatents += [4.6 * forcedlatent] + else: + for a in range(llambda): + print(f"Voronoi in the latent space! {a} / {llambda}") + forcedlatent = np.zeros((4, 64, 64)) + #print(type(numpy_image)) + #print(numpy_image.shape) + #print(np.max(numpy_image)) + #print(np.min(numpy_image)) + #assert False + for i in range(len(indices)): + coefficients[i] = np.exp(2. * np.random.randn()) + for i in range(64): + x = i / 63. + for j in range(64): + y = j / 63 + mindistances = 10000000000. 
+ for u in range(len(indices)): + #print(a, i, x, j, y, u) + #print(indices[u][1]) + #print(indices[u][2]) + #print(f" {coefficients[u]}* np.linalg.norm({np.array((x, y))}-{np.array((indices[u][1], indices[u][2]))}") + distance = coefficients[u] * np.linalg.norm( np.array((x, y)) - np.array((indices[u][2], indices[u][1])) ) + if distance < mindistances: + mindistances = distance + uu = indices[u][0] + for k in range(4): + assert k < len(forcedlatent), k + assert i < len(forcedlatent[k]), i + assert j < len(forcedlatent[k][i]), j + assert uu < len(latent) + assert k < len(latent[uu]), k + assert i < len(latent[uu][k]), i + assert j < len(latent[uu][k][i]), j + forcedlatent[k][i][j] = float(latent[uu][k][i][j]) + #if a % 2 == 0: + # forcedlatent -= np.random.rand() * sauron + basic_new_fl = np.sqrt(len(forcedlatent) / np.sum(forcedlatent**2)) * forcedlatent + if len(good) > 1: + forcedlatents += [basic_new_fl] + else: + epsilon = 0.1 * (((a + .5 - len(good)) / (llambda - len(good) - 1)) ** 2) + forcedlatent = (1. - epsilon) * basic_new_fl.flatten() + epsilon * np.random.randn(4*64*64) + forcedlatent = np.sqrt(len(forcedlatent) / np.sum(forcedlatent**2)) * forcedlatent + forcedlatents += [forcedlatent] #for uu in range(len(latent)): # print(f"--> latent[{uu}] sum of sq / variable = {np.sum(latent[uu].flatten()**2) / len(latent[uu].flatten())}") From 0be4b4025e24a614a4a5ff7c9d10d7e2ce722e7f Mon Sep 17 00:00:00 2001 From: Olivier Teytaud Date: Tue, 4 Oct 2022 08:47:50 +0200 Subject: [PATCH 63/76] fix --- minisd.py | 27 ++++++++++++++++----------- 1 file changed, 16 insertions(+), 11 deletions(-) diff --git a/minisd.py b/minisd.py index e3ca9695f..bcd789ea5 100644 --- a/minisd.py +++ b/minisd.py @@ -13,6 +13,9 @@ from deep_translator import GoogleTranslator from langdetect import detect from joblib import Parallel, delayed +import torch +from PIL import Image +from RealESRGAN import RealESRGAN os.environ["PYTORCH_ENABLE_MPS_FALLBACK"] = "1" model_id = "CompVis/stable-diffusion-v1-4" @@ -37,7 +40,7 @@ os.environ["bad"] = "[]" num_iterations = 50 gs = 7.5 -voronoi_in_images = True +voronoi_in_images = False @@ -133,7 +136,7 @@ prompt = "A muscular Jesus with and assault rifle, a cap and and a light saber." prompt = "A portrait of a cute smiling woman." prompt = "A woman with black skin, red hair, egyptian dress, yellow eyes." -prompt = "Photo of a young cute black woman." +prompt = "Photo of a red haired man with tilted head." print(f"The prompt is {prompt}") @@ -168,9 +171,6 @@ def latent_to_image(latent): image = pipe(english_prompt, guidance_scale=gs, num_inference_steps=num_iterations)["sample"][0] return image -import torch -from PIL import Image -from RealESRGAN import RealESRGAN sr_device = torch.device('cpu') #device #('mps') #torch.device('cuda' if torch.cuda.is_available() else 'cpu') esrmodel = RealESRGAN(sr_device, scale=4) @@ -327,7 +327,8 @@ def load_img(path): model = pipe.vae def img_to_latent(path): - init_image = 1.8 * load_img(path).to(device) + #init_image = 1.8 * load_img(path).to(device) + init_image = load_img(path).to(device) init_image = repeat(init_image, '1 ... -> b ...', b=1) forced_latent = model.encode(init_image.to(device)).latent_dist.sample() new_fl = forced_latent.cpu().detach().numpy().flatten() @@ -706,7 +707,7 @@ def img_to_latent(path): for a in range(llambda): print(f"Voronoi in the image space! {a} / {llambda}") for i in range(len(indices)): - coefficients[i] = np.exp(2. 
* np.random.randn()) + coefficients[i] = np.exp(np.random.randn()) # Creating a forcedlatent. for i in range(512): x = i / 511. @@ -731,15 +732,18 @@ def img_to_latent(path): #forcedlatent = model.encode(timage).latent_dist.sample().cpu().detach().numpy().flatten() #basic_new_fl = np.sqrt(len(forcedlatent) / np.sum(forcedlatent**2)) * forcedlatent basic_new_fl = img_to_latent(voronoi_name) - basic_new_fl = 0.8 * np.sqrt(len(forcedlatent) / np.sum(forcedlatent**2)) * forcedlatent + basic_new_fl = np.sqrt(len(basic_new_fl) / np.sum(basic_new_fl**2)) * basic_new_fl + #basic_new_fl = 0.8 * np.sqrt(len(basic_new_fl) / np.sum(basic_new_fl**2)) * basic_new_fl if len(good) > 1: print("Directly copying latent vars !!!") - forcedlatents += [4.6 * basic_new_fl] + #forcedlatents += [4.6 * basic_new_fl] + forcedlatents += [basic_new_fl] else: epsilon = 1.0 * (((a + .5 - len(good)) / (llambda - len(good) - 1)) ** 2) forcedlatent = (1. - epsilon) * basic_new_fl.flatten() + epsilon * np.random.randn(4*64*64) forcedlatent = np.sqrt(len(forcedlatent) / np.sum(forcedlatent**2)) * forcedlatent - forcedlatents += [4.6 * forcedlatent] + forcedlatents += [forcedlatent] + #forcedlatents += [4.6 * forcedlatent] else: for a in range(llambda): print(f"Voronoi in the latent space! {a} / {llambda}") @@ -750,7 +754,7 @@ def img_to_latent(path): #print(np.min(numpy_image)) #assert False for i in range(len(indices)): - coefficients[i] = np.exp(2. * np.random.randn()) + coefficients[i] = np.exp(np.random.randn()) for i in range(64): x = i / 63. for j in range(64): @@ -776,6 +780,7 @@ def img_to_latent(path): forcedlatent[k][i][j] = float(latent[uu][k][i][j]) #if a % 2 == 0: # forcedlatent -= np.random.rand() * sauron + forcedlatent = forcedlatent.flatten() basic_new_fl = np.sqrt(len(forcedlatent) / np.sum(forcedlatent**2)) * forcedlatent if len(good) > 1: forcedlatents += [basic_new_fl] From 1ee812c3620867f00bb63830ad03e4533048907a Mon Sep 17 00:00:00 2001 From: Olivier Teytaud Date: Wed, 5 Oct 2022 09:07:17 +0200 Subject: [PATCH 64/76] fix --- geneticsd.py | 730 +++++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 730 insertions(+) create mode 100644 geneticsd.py diff --git a/geneticsd.py b/geneticsd.py new file mode 100644 index 000000000..f53c7aaa4 --- /dev/null +++ b/geneticsd.py @@ -0,0 +1,730 @@ +# A ton of imports. +import random +import os +import time +import torch +import numpy as np +import shutil +import PIL +from PIL import Image +from einops import rearrange, repeat +from torch import autocast +from diffusers import StableDiffusionPipeline +import webbrowser +from deep_translator import GoogleTranslator +from langdetect import detect +from joblib import Parallel, delayed +import torch +from PIL import Image +from RealESRGAN import RealESRGAN +import pyttsx3 +import pyfiglet +import pygame +from os import listdir +from os.path import isfile, join + +# Let's parametrize a few things. +os.environ["PYTORCH_ENABLE_MPS_FALLBACK"] = "1" +model_id = "CompVis/stable-diffusion-v1-4" +device = "mps" #torch.device("mps") + +white = (255, 255, 255) +green = (0, 255, 0) +darkgreen = (0, 128, 0) +red = (255, 0, 0) +blue = (0, 0, 128) +black = (0, 0, 0) + +os.environ["skl"] = "nn" +os.environ["epsilon"] = "0.005" +os.environ["decay"] = "0." 
+os.environ["ngoptim"] = "DiscreteLenglerOnePlusOne" +os.environ["forcedlatent"] = "" +latent_forcing = "" +os.environ["good"] = "[]" +os.environ["bad"] = "[]" +num_iterations = 50 +gs = 7.5 +sentinel = str(random.randint(0,100000)) + "XX" + str(random.randint(0,100000)) +all_files = [] +llambda = 15 + +# Creating the voice engine. +noise = pyttsx3.init() +noise.setProperty("rate", 240) +def speak(text): + noise.say(text) + noise.runAndWait() + + +# Initialization. +all_selected = [] # List of all selected images, over all the run. +all_selected_latent = [] # The corresponding latent variables. +final_selection = [] # Selection of files during the final iteration. +forcedlatents = [] # Latent variables that we want to see soon. +forcedgs = [] # forcedgs[i] is the guidance strength that we want to see for image number i. +assert llambda < 16, "lambda < 16 for convenience in pygame." +bad = [] +five_best = [] +latent = [] +images = [] +onlyfiles = [] + +# Creating the main pipeline. +pipe = StableDiffusionPipeline.from_pretrained(model_id, use_auth_token="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") +pipe = pipe.to(device) + + +# A ton of prompts, for fun. +prompt = "a photo of an astronaut riding a horse on mars" +prompt = "a photo of a red panda with a hat playing table tennis" +prompt = "a photorealistic portrait of " + random.choice(["Mary Cury", "Scarlett Johansson", "Marilyn Monroe", "Poison Ivy", "Black Widow", "Medusa", "Batman", "Albert Einstein", "Louis XIV", "Tarzan"]) + random.choice([" with glasses", " with a hat", " with a cigarette", "with a scarf"]) +prompt = "a photorealistic portrait of " + random.choice(["Nelson Mandela", "Superman", "Superwoman", "Volodymyr Zelenskyy", "Tsai Ing-Wen", "Lzzy Hale", "Meg Myers"]) + random.choice([" with glasses", " with a hat", " with a cigarette", "with a scarf"]) +prompt = random.choice(["A woman with three eyes", "Meg Myers", "The rock band Ankor", "Miley Cyrus", "The man named Rahan", "A murder", "Rambo playing table tennis"]) +prompt = "Photo of a female Terminator." +prompt = random.choice([ + "Photo of Tarzan as a lawyer with a tie", + "Photo of Scarlett Johansson as a sumo-tori", + "Photo of the little mermaid as a young black girl", + "Photo of Schwarzy with tentacles", + "Photo of Meg Myers with an Egyptian dress", + "Photo of Schwarzy as a ballet dancer", + ]) + + +name = random.choice(["Mark Zuckerbeg", "Zendaya", "Yann LeCun", "Scarlett Johansson", "Superman", "Meg Myers"]) +prompt = f"Photo of {name} as a sumo-tori." + +prompt = "Full length portrait of Mark Zuckerberg as a Sumo-Tori." +prompt = "Full length portrait of Scarlett Johansson as a Sumo-Tori." +prompt = "A close up photographic portrait of a young woman with uniformly colored hair." +prompt = "Zombies raising and worshipping a flying human." +prompt = "Zombies trying to kill Meg Myers." +prompt = "Meg Myers with an Egyptian dress killing a vampire with a gun." +prompt = "Meg Myers grabbing a vampire by the scruff of the neck." +prompt = "Mark Zuckerberg chokes a vampire to death." +prompt = "Mark Zuckerberg riding an animal." +prompt = "A giant cute animal worshipped by zombies." +prompt = "Several faces." +prompt = "An armoured Yann LeCun fighting tentacles in the jungle." +prompt = "Tentacles everywhere." +prompt = "A photo of a smiling Medusa." +prompt = "Medusa." +prompt = "Meg Myers in bloody armor fending off tentacles with a sword." +prompt = "A red-haired woman with red hair. Her head is tilted." +prompt = "A bloody heavy-metal zombie with a chainsaw." 
+prompt = "Tentacles attacking a bloody Meg Myers in Eyptian dress. Meg Myers has a chainsaw." +prompt = "Bizarre art." +prompt = "Beautiful bizarre woman." +prompt = "Yann LeCun as the grim reaper: bizarre art." +prompt = "Un chat en sang et en armure joue de la batterie." +prompt = "Photo of a cyberpunk Mark Zuckerberg killing Cthulhu with a light saber." +prompt = "A ferocious cyborg bear." +prompt = "Photo of Mark Zuckerberg killing Cthulhu with a light saber." +prompt = "A bear with horns and blood and big teeth." +prompt = "A photo of a bear and Yoda, good friends." +prompt = "A photo of Yoda on the left, a blue octopus on the right, an explosion in the center." +prompt = "A bird is on a hippo. They fight a black and red octopus. Jungle in the background." +prompt = "A flying white owl above 4 colored pots with fire. The owl has a hat." +prompt = "A flying white owl above 4 colored pots with fire." +prompt = "Yann LeCun rides a dragon which spits fire on a cherry on a cake." +prompt = "An armored Mark Zuckerberg fighting off a monster with bloody tentacles in the jungle with a light saber." +prompt = "Cute woman, portrait, photo, red hair, green eyes, smiling." +prompt = "Photo of Tarzan as a lawyer with a tie and an octopus on his head." +prompt = "An armored bloody Yann Lecun has a lightsabar and fights a red tentacular monster." +prompt = "Photo of a giant armored insect attacking a building. The building is broken. There are flames." +prompt = "Photo of Meg Myers, on the left, in Egyptian dress, fights Cthulhu (on the right) with a light saber. They stare at each other." +prompt = "Photo of a cute red panda." +prompt = "Photo of a cute smiling white-haired woman with pink eyes." +prompt = "A muscular Jesus with and assault rifle, a cap and and a light saber." +prompt = "A portrait of a cute smiling woman." +prompt = "A woman with black skin, red hair, egyptian dress, yellow eyes." +prompt = "Photo of a red haired man with tilted head." +prompt = "A photo of Cleopatra with Egyptian Dress kissing Yoda." +prompt = "A photo of Yoda fighting Meg Myers with light sabers." +prompt = "A photo of Meg Myers, laughing, pulling Gandalf's hair." +prompt = "A photo of Meg Myers laughing and pulling Gandalf's hair. Gandalf is stooping." +prompt = "A star with flashy colors." +prompt = "Portrait of a green haired woman with blue eyes." +prompt = "Portrait of a female kung-fu master." +prompt = "In a dark cave, in the middle of computers, a geek meets the devil." +print(f"The prompt is {prompt}") + + +print(pyfiglet.figlet_format("Welcome in Genetic Stable Diffusion !")) +print(pyfiglet.figlet_format("First, let us choose the text :-)!")) + + + +print(f"Francais: Proposez un nouveau texte si vous ne voulez pas dessiner << {prompt} >>.\n") +speak("Hey!") +user_prompt = input(f"English: Enter a new prompt if you prefer something else than << {prompt} >>.\n") +if len(user_prompt) > 2: + prompt = user_prompt + +# On the fly translation. +language = detect(prompt) +english_prompt = GoogleTranslator(source='auto', target='en').translate(prompt) +def to_native(stri): + return GoogleTranslator(source='en', target=language).translate(stri) + +def pretty_print(stri): + print(pyfiglet.figlet_format(to_native(stri))) + +print(f"{to_native('Working on')} {english_prompt}, a.k.a {prompt}.") + + +# Converting a latent var to an image. 
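The conversion defined below relies on the patched pipeline this project uses: the latent is serialized into the forcedlatent environment variable, the pipeline reads it there instead of drawing fresh noise, and the latent it actually used is written back into latent_sd. A hypothetical call site, assuming a flattened latent with unit variance per coordinate:

z = np.random.randn(4 * 64 * 64)              # a fresh Gaussian latent
z = np.sqrt(len(z) / np.sum(z ** 2)) * z      # renormalize to unit variance per coordinate
image = latent_to_image(z)                    # round-trips through os.environ["forcedlatent"]
image.save("from_custom_latent.png")          # hypothetical output name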
+def latent_to_image(latent): + os.environ["forcedlatent"] = str(list(latent.flatten())) #str(list(forcedlatents[k].flatten())) + with autocast("cuda"): + image = pipe(english_prompt, guidance_scale=gs, num_inference_steps=num_iterations)["sample"][0] + os.environ["forcedlatent"] = "[]" + return image + +# Creating the super-resolution stuff. RealESRGAN is fantastic! +sr_device = torch.device('cpu') #device #('mps') #torch.device('cuda' if torch.cuda.is_available() else 'cpu') +esrmodel = RealESRGAN(sr_device, scale=4) +esrmodel.load_weights('weights/RealESRGAN_x4.pth', download=True) +esrmodel2 = RealESRGAN(sr_device, scale=2) +esrmodel2.load_weights('weights/RealESRGAN_x2.pth', download=True) + +def singleeg(path_to_image): + image = Image.open(path_to_image).convert('RGB') + sr_device = device #('mps') #torch.device('cuda' if torch.cuda.is_available() else 'cpu') + print(f"Type before SR = {type(image)}") + sr_image = esrmodel.predict(image) + print(f"Type after SR = {type(sr_image)}") + output_filename = path_to_image + ".SR.png" + sr_image.save(output_filename) + return output_filename + +# A version with x2. +def singleeg2(path_to_image): + image = Image.open(path_to_image).convert('RGB') + sr_device = device #('mps') #torch.device('cuda' if torch.cuda.is_available() else 'cpu') + print(f"Type before SR = {type(image)}") + sr_image = esrmodel2.predict(image) + print(f"Type after SR = {type(sr_image)}") + output_filename = path_to_image + ".SR.png" + sr_image.save(output_filename) + return output_filename + + +# realESRGan applied to many files. +def eg(list_of_files): + pretty_print("Should I convert images below to high resolution ?") + print(list_of_files) + speak("Go to the text window!") + answer = input(" [y]es / [n]o ?") + if "y" in answer or "Y" in answer: + #images = Parallel(n_jobs=12)(delayed(singleeg)(image) for image in list_of_files) + #print(to_native(f"Created the super-resolution files {images}")) + for path_to_image in list_of_files: + output_filename = singleeg(path_to_image) + print(to_native(f"Created the super-resolution file {output_filename}")) + +# When we stop the run and check and propose to do super-resolution and/or animations. 
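stop_all, defined below, builds each animation by walking a closed curve in latent space: two random perturbations l1 and l2 of a selected latent l span a plane around l, and the frames follow l + cos(u) * (l1 - l) + sin(2u) * (l2 - l) for u in [0, 2*pi), so the path returns to its starting point and the GIF loops cleanly. A stripped-down sketch of that frame schedule (c is the perturbation size, as in the loop below):

def animation_latents(l, c=0.125, steps=30):
    """Yield a closed loop of latents around a selected flattened latent l."""
    n = len(l)
    def renorm(x):
        return np.sqrt(n / np.sum(x ** 2)) * x
    l = renorm(l)
    l1 = renorm(l + c * np.random.randn(n))
    l2 = renorm(l + c * np.random.randn(n))
    for u in np.linspace(0.0, 2 * np.pi * (1 - 1 / steps), steps):
        yield l + np.cos(u) * (l1 - l) + np.sin(2 * u) * (l2 - l)

# frames = [latent_to_image(z) for z in animation_latents(selected_latent)]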
+def stop_all(list_of_files, list_of_latent, last_list_of_latent): + print(to_native("Your selected images and the last generation:")) + print(list_of_files) + eg(list_of_files) + pretty_print("Should we create animations ?") + answer = input(" [y]es or [n]o or [j]ust the selection on the last panel ?") + if "y" in answer or "Y" in answer or "j" in answer or "J" in answer: + assert len(list_of_files) == len(list_of_latent) + if "j" in answer or "J" in answer: + list_of_latent = last_list_of_latent + pretty_print("Let us create animations!") + for c in sorted([0.5, 0.25, 0.125, 0.0625, 0.05, 0.04,0.03125]): + for idx in range(len(list_of_files)): + images = [] + l = list_of_latent[idx].reshape(1,4,64,64) + l = np.sqrt(len(l.flatten()) / np.sum(l**2)) * l + l1 = l + c * np.random.randn(len(l.flatten())).reshape(1,4,64,64) + l1 = np.sqrt(len(l1.flatten()) / np.sum(l1**2)) * l1 + l2 = l + c * np.random.randn(len(l.flatten())).reshape(1,4,64,64) + l2 = np.sqrt(len(l2.flatten()) / np.sum(l2**2)) * l2 + num_animation_steps = 13 + index = 0 + for u in np.linspace(0., 2*3.14159 * (1-1/30), 30): + cc = np.cos(u) + ss = np.sin(u*2) + index += 1 + image = latent_to_image(l + cc * (l1 - l) + ss * (l2 - l)) + image_name = f"imgA{index}.png" + image.save(image_name) + images += [image_name] + + print(to_native(f"Base images created for perturbation={c} and file {list_of_files[idx]}")) + images = Parallel(n_jobs=10)(delayed(singleeg2)(image) for image in images) + frames = [Image.open(image) for image in images] + frame_one = frames[0] + gif_name = list_of_files[idx] + "_" + str(c) + ".gif" + frame_one.save(gif_name, format="GIF", append_images=frames, + save_all=True, duration=100, loop=0) + webbrowser.open(os.environ["PWD"] + "/" + gif_name) + + pretty_print("Should we create a meme ?") + answer = input(" [y]es or [n]o ?") + if "y" in answer or "Y" in answer: + url = 'https://imgflip.com/memegenerator' + webbrowser.open(url) + pretty_print("Good bye!") + exit() + + + + + +pretty_print("Now let us choose (if you want) an image as a start.") +image_name = input(to_native("Name of image for starting ? (enter if no start image)")) + +# activate the pygame library . +pygame.init() +X = 2000 # > 1500 = buttons +Y = 900 +scrn = pygame.display.set_mode((1700, Y + 100)) +font = pygame.font.Font('freesansbold.ttf', 22) +bigfont = pygame.font.Font('freesansbold.ttf', 44) + +def load_img(path): + image = Image.open(path).convert("RGB") + w, h = image.size + print(to_native(f"loaded input image of size ({w}, {h}) from {path}")) + w, h = map(lambda x: x - x % 32, (w, h)) # resize to integer multiple of 32 + image = image.resize((512, 512), resample=PIL.Image.LANCZOS) + #image = image.resize((w, h), resample=PIL.Image.LANCZOS) + image = np.array(image).astype(np.float32) / 255.0 + image = image[None].transpose(0, 3, 1, 2) + image = torch.from_numpy(image) + return 2.*image - 1. + +model = pipe.vae + +def img_to_latent(path): + #init_image = 1.8 * load_img(path).to(device) + init_image = load_img(path).to(device) + init_image = repeat(init_image, '1 ... 
-> b ...', b=1) + forced_latent = model.encode(init_image.to(device)).latent_dist.sample() + new_fl = forced_latent.cpu().detach().numpy().flatten() + new_fl = np.sqrt(len(new_fl)) * new_fl / np.sqrt(np.sum(new_fl ** 2)) + return new_fl + +def randomized_image_to_latent(image_name, scale=None, epsilon=None, c=None, f=None): + base_init_image = load_img(image_name).to(device) + new_base_init_image = base_init_image + c = np.exp(np.random.randn()) if c is None else c + f = np.exp(-3. * np.random.rand()) if f is None else f + init_image_shape = base_init_image.cpu().numpy().shape + init_image = c * new_base_init_image + init_image = repeat(init_image, '1 ... -> b ...', b=1) + forced_latent = 1. * model.encode(init_image.to(device)).latent_dist.sample() + new_fl = forced_latent.cpu().detach().numpy().flatten() + basic_new_fl = new_fl #np.sqrt(len(new_fl) / sum(new_fl ** 2)) * new_fl + basic_new_fl = f * np.sqrt(len(new_fl) / np.sum(basic_new_fl**2)) * basic_new_fl + epsilon = 0.1 * np.exp(-3 * np.random.rand()) if epsilon is None else epsilon + new_fl = (1. - epsilon) * basic_new_fl + epsilon * np.random.randn(1*4*64*64) + scale = 2.8 + 3.6 * np.random.rand() if scale is None else scale + new_fl = scale * np.sqrt(len(new_fl)) * new_fl / np.sqrt(np.sum(new_fl ** 2)) + #image = latent_to_image(np.asarray(new_fl)) #eval(os.environ["forcedlatent"]))) + #image.save(f"rebuild_{f}_{scale}_{epsilon}_{c}.png") + return new_fl + +# In case the user wants to start from a given image. +if len(image_name) > 0: + pretty_print("Importing an image !") + try: + init_image = load_img(image_name).to(device) + except: + pretty_print("Try again!") + pretty_print("Loading failed!!") + image_name = input(to_native("Name of image for starting ? (enter if no start image)")) + + base_init_image = load_img(image_name).to(device) + speak("Image loaded!") + print(base_init_image.shape) + print(np.max(base_init_image.cpu().detach().numpy().flatten())) + print(np.min(base_init_image.cpu().detach().numpy().flatten())) + + forcedlatents = [] + try: + latent_file = image_name + ".latent.txt" + print(to_native(f"Trying to load latent variables in {latent_file}.")) + f = open(latent_file, "r") + print(to_native("File opened.")) + latent_str = f.read() + print("Latent string read.") + latent_found = True + for i in range(llambda): + basic_new_fl = np.asarray(eval(latent_str)) + if i > 0: + basic_new_fl = f * np.sqrt(len(new_fl) / np.sum(basic_new_fl**2)) * basic_new_fl + epsilon = .7 * ((i-1)/(llambda-1)) #1.0 / 2**(2 + (llambda - i) / 6) + #print(f"{i} -- {i % 7} {c} {f} {epsilon}") + new_fl = (1. - epsilon) * basic_new_fl + epsilon * np.random.randn(1*4*64*64) + else: + new_fl = basic_new_fl + new_fl = 6. * np.sqrt(len(new_fl)) * new_fl / np.sqrt(np.sum(new_fl ** 2)) + forcedlatents += [new_fl] + except: + print(to_native("No latent file: guessing.")) + for i in range(llambda): + forcedlatents += [randomized_image_to_latent(image_name)] #img_to_latent(voronoi_name) + +# We start the big time consuming loop! +for iteration in range(3000): # Kind of an infinite loop. 
+ latent = [latent[f] for f in five_best] + images = [images[f] for f in five_best] + onlyfiles = [onlyfiles[f] for f in five_best] + early_stop = [] + speak("Wait!") + final_selection = [] + for k in range(llambda): + if len(early_stop) > 0: + break + max_created_index = k + if k < len(forcedlatents): + latent_forcing = str(list(forcedlatents[k].flatten())) + print(f"We play with {latent_forcing[:20]}") + if k < len(five_best): + imp = pygame.transform.scale(pygame.image.load(onlyfiles[k]).convert(), (300, 300)) + scrn.blit(imp, (300 * (k // 3), 300 * (k % 3))) + pygame.display.flip() + continue + pygame.draw.rect(scrn, black, pygame.Rect(0, Y, 1700, Y+100)) + pygame.draw.rect(scrn, black, pygame.Rect(1500, 0, 2000, Y+100)) + text0 = bigfont.render(to_native(f'Please wait !!! {k} / {llambda}'), True, green, blue) + scrn.blit(text0, ((X*3/4)/2 - X/32, Y/2-Y/4)) + text0 = font.render(to_native(f'Or, for an early stopping,'), True, green, blue) + scrn.blit(text0, ((X*3/4)/3 - X/32, Y/2-Y/8)) + text0 = font.render(to_native(f'click and WAIT a bit'), True, green, blue) + scrn.blit(text0, ((X*3/4)/3 - X/32, Y/2)) + text0 = font.render(to_native(f'... ... ... '), True, green, blue) + scrn.blit(text0, ((X*3/4)/2 - X/32, Y/2+Y/8)) + + # Button for early stopping + text2 = font.render(to_native(f'Total: {len(all_selected)} chosen images! '), True, green, blue) + text2 = pygame.transform.rotate(text2, 90) + scrn.blit(text2, (X*3/4+X/16 - X/32, Y/3)) + text2 = font.render(to_native('Click for stopping,'), True, green, blue) + text2 = pygame.transform.rotate(text2, 90) + scrn.blit(text2, (X*3/4+X/16+X/64 - X/32, Y/3)) + text2 = font.render(to_native('and get the effects.'), True, green, blue) + text2 = pygame.transform.rotate(text2, 90) + scrn.blit(text2, (X*3/4+X/16+X/32 - X/32, Y/3)) + + pygame.display.flip() + os.environ["earlystop"] = "False" if k > len(five_best) else "True" + os.environ["epsilon"] = str(0. if k == len(five_best) else (k - len(five_best)) / llambda) + os.environ["budget"] = str(np.random.randint(400) if k > len(five_best) else 2) + os.environ["skl"] = {0: "nn", 1: "tree", 2: "logit"}[k % 3] + previous_gs = gs + if k < len(forcedgs): + gs = forcedgs[k] + image = latent_to_image(np.asarray(latent_forcing)) #eval(os.environ["forcedlatent"]))) + gs = previous_gs + + images += [image] + filename = f"SD_{prompt.replace(' ','_')}_image_{sentinel}_{iteration:05d}_{k:05d}.png" + image.save(filename) + onlyfiles += [filename] + imp = pygame.transform.scale(pygame.image.load(onlyfiles[-1]).convert(), (300, 300)) + scrn.blit(imp, (300 * (k // 3), 300 * (k % 3))) + pygame.display.flip() + print('\a') # beep! + str_latent = eval((os.environ["latent_sd"])) + array_latent = eval(f"np.array(str_latent).reshape(4, 64, 64)") + print(f"Debug info: array_latent sumsq/var {sum(array_latent.flatten() ** 2) / len(array_latent.flatten())}") + latent += [array_latent] + with open(filename + ".latent.txt", 'w') as f: + f.write(f"{str_latent}") + + # In case of early stopping, we stop the loop. 
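Each image written above gets a sidecar file named <image>.png.latent.txt containing str(list(latent)); that is the file the start-image loader earlier in this script looks for, and it can be reloaded by hand, for example to regenerate or perturb an old result (the file name below is hypothetical):

def load_sidecar_latent(path):
    """Read a *.latent.txt file written by the loop above back into a (4, 64, 64) array."""
    with open(path) as f:
        values = eval(f.read())   # the sidecar stores a plain Python list literal
    return np.array(values).reshape(4, 64, 64)

# image = latent_to_image(load_sidecar_latent("SD_some_prompt_image_42_00001_00003.png.latent.txt"))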
+ first_event = True + for i in pygame.event.get(): + if i.type == pygame.MOUSEBUTTONUP: + if first_event: + speak("Ok I stop!") + first_event = False + pos = pygame.mouse.get_pos() + index = 3 * (pos[0] // 300) + (pos[1] // 300) + if pos[0] > X and pos[1] > Y /3 and pos[1] < 2*Y/3: + stop_all(all_selected, all_selected_latent, final_selection) + exit() + if index <= k: + pretty_print(("You clicked for requesting an early stopping.")) + early_stop = [pos] + break + early_stop = [(1,1)] + satus = False + forcedgs = [] + + speak("Please choose!") + pretty_print("Please choose your images.") + text0 = bigfont.render(to_native(f'Choose your favorite images !!!========='), True, green, blue) + scrn.blit(text0, ((X*3/4)/2 - X/32, Y/2-Y/4)) + text0 = font.render(to_native(f'=================================='), True, green, blue) + scrn.blit(text0, ((X*3/4)/3 - X/32, Y/2-Y/8)) + text0 = font.render(to_native(f'=================================='), True, green, blue) + scrn.blit(text0, ((X*3/4)/3 - X/32, Y/2)) + # Add rectangles + pygame.draw.rect(scrn, red, pygame.Rect(X*3/4, 0, X*3/4+X/16+X/32, Y/3), 2) + pygame.draw.rect(scrn, red, pygame.Rect(X*3/4, Y/3, X*3/4+X/16+X/32, 2*Y/3), 2) + pygame.draw.rect(scrn, red, pygame.Rect(X*3/4, 2*Y/3, X*3/4+X/16+X/32, Y), 2) + pygame.draw.rect(scrn, red, pygame.Rect(0, Y, X/2, Y+100), 2) + + # Button for loading a starting point + text1 = font.render('Manually edit an image.', True, green, blue) + text1 = pygame.transform.rotate(text1, 90) + #scrn.blit(text1, (X*3/4+X/16 - X/32, 0)) + #text1 = font.render('& latent ', True, green, blue) + #text1 = pygame.transform.rotate(text1, 90) + #scrn.blit(text1, (X*3/4+X/16+X/32 - X/32, 0)) + + # Button for creating a meme + text2 = font.render(to_native('Click ,'), True, green, blue) + text2 = pygame.transform.rotate(text2, 90) + scrn.blit(text2, (X*3/4+X/16 - X/32, Y/3+10)) + text2 = font.render(to_native('for finishing with effects.'), True, green, blue) + text2 = pygame.transform.rotate(text2, 90) + scrn.blit(text2, (X*3/4+X/16+X/32 - X/32, Y/3+10)) + # Button for new generation + text3 = font.render(to_native(f"I don't want to select images"), True, green, blue) + text3 = pygame.transform.rotate(text3, 90) + scrn.blit(text3, (X*3/4+X/16 - X/32, Y*2/3+10)) + text3 = font.render(to_native(f"Just rerun."), True, green, blue) + text3 = pygame.transform.rotate(text3, 90) + scrn.blit(text3, (X*3/4+X/16+X/32 - X/32, Y*2/3+10)) + text4 = font.render(to_native(f"Modify parameters or text!"), True, green, blue) + scrn.blit(text4, (300, Y + 30)) + pygame.display.flip() + + for idx in range(max_created_index + 1): + # set the pygame window name + pygame.display.set_caption(prompt) + print(to_native(f"Pasting image {onlyfiles[idx]}...")) + imp = pygame.transform.scale(pygame.image.load(onlyfiles[idx]).convert(), (300, 300)) + scrn.blit(imp, (300 * (idx // 3), 300 * (idx % 3))) + + # paint screen one time + pygame.display.flip() + status = True + indices = [] + good = [] + five_best = [] + for i in pygame.event.get(): + if i.type == pygame.MOUSEBUTTONUP: + print(to_native(".... 
too early for clicking !!!!")) + + + pretty_print("Please click on your favorite elements!") + print(to_native("You might just click on one image and we will provide variations.")) + print(to_native("Or you can click on the top of an image and the bottom of another one.")) + print(to_native("Click on the << new generation >> when you're done.")) + while (status): + + # iterate over the list of Event objects + # that was returned by pygame.event.get() method. + for i in pygame.event.get(): + if hasattr(i, "type") and i.type == pygame.MOUSEBUTTONUP: + pos = pygame.mouse.get_pos() + pretty_print(f"Detected! Click at {pos}") + if pos[1] > Y: + pretty_print("Let us update parameters!") + text4 = font.render(to_native(f"ok, go to text window!"), True, green, blue) + scrn.blit(text4, (300, Y + 30)) + pygame.display.flip() + try: + num_iterations = int(input(to_native(f"Number of iterations ? (current = {num_iterations})\n"))) + except: + num_iterations = int(input(to_native(f"Number of iterations ? (current = {num_iterations})\n"))) + gs = float(input(to_native(f"Guidance scale ? (current = {gs})\n"))) + print(to_native(f"The current text is << {prompt} >>.")) + print(to_native("Start your answer with a symbol << + >> if this is an edit and not a new text.")) + new_prompt = str(input(to_native(f"Enter a text if you want to change from ") + prompt)) + if len(new_prompt) > 2: + if new_prompt[0] == "+": + prompt += new_prompt[1:] + else: + prompt = new_prompt + language = detect(prompt) + english_prompt = GoogleTranslator(source='auto', target='en').translate(prompt) + pretty_print("Ok! Parameters updated.") + pretty_print("==> go back to the window!") + text4 = font.render(to_native(f"Ok! parameters changed!"), True, green, blue) + scrn.blit(text4, (300, Y + 30)) + pygame.display.flip() + elif pos[0] > 1500: # Not in the images. + if pos[1] < Y/3: + #filename = input(to_native("Filename (please provide the latent file, of the format SD*latent*.txt) ?\n")) + #status = False + #with open(filename, 'r') as f: + # latent = f.read() + #break + pretty_print("Easy! I exit now, you edit the file and you save it.") + pretty_print("Then just relaunch me and provide the text and the image.") + exit() + if pos[1] < 2*Y/3: + #onlyfiles = [f for f in listdir(".") if isfile(join(mypath, f))] + #onlyfiles = [str(f) for f in onlyfiles if "SD_" in str(f) and ".png" in str(f) and str(f) not in all_files and sentinel in str(f)] + assert len(onlyfiles) == len(latent) + assert len(all_selected) == len(all_selected_latent) + stop_all(all_selected, all_selected_latent, final_selection) # + onlyfiles, all_selected_latent + latent) + exit() + status = False + break + index = 3 * (pos[0] // 300) + (pos[1] // 300) + pygame.draw.circle(scrn, red, [pos[0], pos[1]], 13, 0) + if index <= max_created_index: + selected_filename = to_native("Selected") + onlyfiles[index] + shutil.copyfile(onlyfiles[index], selected_filename) + assert len(onlyfiles) == len(latent), f"{len(onlyfiles)} != {len(latent)}" + all_selected += [selected_filename] + all_selected_latent += [latent[index]] + final_selection += [latent[index]] + text2 = font.render(to_native(f'==> {len(all_selected)} chosen images! '), True, green, blue) + text2 = pygame.transform.rotate(text2, 90) + scrn.blit(text2, (X*3/4+X/16 - X/32, Y/3)) + if index not in five_best and len(five_best) < 5: + five_best += [index] + indices += [[index, (pos[0] - (pos[0] // 300) * 300) / 300, (pos[1] - (pos[1] // 300) * 300) / 300]] + # Update the button for new generation. 
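For reference, the click handling above maps a mouse position to a tile index plus the fractional position of the click inside that 300x300 tile; a minimal sketch (function name is illustrative, not patch code):

```python
# Minimal sketch of the click bookkeeping above: the panel is a grid of 300x300 tiles
# with 3 tiles per column, and each click is recorded as the tile index plus its
# fractional (x, y) position inside the tile, which the recombination step reuses
# as a Voronoi site.
TILE = 300

def click_to_selection(pos):
    x, y = pos
    index = 3 * (x // TILE) + (y // TILE)        # same indexing as in the loop above
    fx = (x - (x // TILE) * TILE) / TILE         # horizontal offset in [0, 1)
    fy = (y - (y // TILE) * TILE) / TILE         # vertical offset in [0, 1)
    return [index, fx, fy]

# Example: a click at (650, 410) selects tile 7, at roughly (0.17, 0.37) inside it.
print(click_to_selection((650, 410)))
```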
+ pygame.draw.rect(scrn, black, pygame.Rect(X*3/4, 2*Y/3, X*3/4+X/16+X/32, Y)) + pygame.draw.rect(scrn, red, pygame.Rect(X*3/4, 2*Y/3, X*3/4+X/16+X/32, Y), 2) + text3 = font.render(to_native(f" You have chosen {len(indices)} images:"), True, green, blue) + text3 = pygame.transform.rotate(text3, 90) + scrn.blit(text3, (X*3/4+X/16 - X/32, Y*2/3)) + text3 = font.render(to_native(f" Click for new generation!"), True, green, blue) + text3 = pygame.transform.rotate(text3, 90) + scrn.blit(text3, (X*3/4+X/16+X/32 - X/32, Y*2/3)) + pygame.display.flip() + #text3Rect = text3.get_rect() + #text3Rect.center = (750+750*3/4, 1000) + good += [list(latent[index].flatten())] + else: + speak("Bad click ! Click on an image.") + pretty_print("Bad click! Click on image.") + + if i.type == pygame.QUIT: + status = False + + # Covering old images with full circles. + for _ in range(123): + x = np.random.randint(1500) + y = np.random.randint(900) + pygame.draw.circle(scrn, darkgreen, + [x, y], 17, 0) + pygame.display.update() + if len(indices) == 0: + print("The user did not like anything! Rerun :-(") + continue + print(f"Clicks at {indices}") + os.environ["mu"] = str(len(indices)) + forcedlatents = [] + bad += [list(latent[u].flatten()) for u in range(len(onlyfiles)) if u not in [i[0] for i in indices]] + #sauron = 0 * latent[0] + #for u in [u for u in range(len(onlyfiles)) if u not in [i[0] for i in indices]]: + # sauron += latent[u] + #sauron = (1 / len([u for u in range(len(onlyfiles)) if u not in [i[0] for i in indices]])) * sauron + if len(bad) > 500: + bad = bad[(len(bad) - 500):] + print(to_native(f"{len(indices)} indices are selected.")) + #print(f"indices = {indices}") + os.environ["good"] = str(good) + os.environ["bad"] = str(bad) + coefficients = np.zeros(len(indices)) + numpy_images = [np.array(image) for image in images] + for a in range(llambda): + voronoi_in_images = False #(a % 2 == 1) and len(good) > 1 + if voronoi_in_images: + image = np.array(numpy_images[0]) + print(f"Voronoi in the image space! {a} / {llambda}") + for i in range(len(indices)): + coefficients[i] = np.exp(np.random.randn()) + # Creating a forcedlatent. + for i in range(512): + x = i / 511. + for j in range(512): + y = j / 511 + mindistances = 10000000000. + for u in range(len(indices)): + distance = coefficients[u] * np.linalg.norm( np.array((x, y)) - np.array((indices[u][2], indices[u][1])) ) + if distance < mindistances: + mindistances = distance + uu = indices[u][0] + image[i][j][:] = numpy_images[uu][i][j][:] + # Conversion before using img2latent + pil_image = Image.fromarray(image) + voronoi_name = f"voronoi{a}_iteration{iteration}.png" + pil_image.save(voronoi_name) + #timage = np.array([image]).astype(np.float32) / 255.0 + #timage = timage.transpose(0, 3, 1, 2) + #timage = torch.from_numpy(timage).to(device) + #timage = repeat(timage, '1 ... -> b ...', b=1) + #timage = 2.*timage - 1. 
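The nested pixel loops above amount to a multiplicatively weighted Voronoi blend of the clicked images; the following NumPy restatement is only illustrative (the names and the vectorized formulation are not from the patch):

```python
# Illustrative vectorized restatement of the image-space Voronoi crossover above.
import numpy as np

def voronoi_blend(numpy_images, clicks, rng=np.random):
    # numpy_images: list of HxWx3 arrays; clicks: list of [image_index, fx, fy]
    # triples as recorded by the selection loop (fx horizontal, fy vertical, in [0, 1]).
    h, w, _ = numpy_images[0].shape
    coeffs = np.exp(rng.randn(len(clicks)))            # random per-click weights
    ys, xs = np.meshgrid(np.linspace(0, 1, h), np.linspace(0, 1, w), indexing="ij")
    dists = np.stack([c * np.hypot(ys - fy, xs - fx)   # weighted distance to each click
                      for c, (_, fx, fy) in zip(coeffs, clicks)])
    owner = np.argmin(dists, axis=0)                   # nearest click for every pixel
    out = np.zeros_like(numpy_images[0])
    for u, (k, _, _) in enumerate(clicks):
        mask = owner == u
        out[mask] = numpy_images[k][mask]              # copy pixels from the owning image
    return out
```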
+ #forcedlatent = model.encode(timage).latent_dist.sample().cpu().detach().numpy().flatten() + #basic_new_fl = np.sqrt(len(forcedlatent) / np.sum(forcedlatent**2)) * forcedlatent + basic_new_fl = randomized_image_to_latent(voronoi_name) #img_to_latent(voronoi_name) + basic_new_fl = np.sqrt(len(basic_new_fl) / np.sum(basic_new_fl**2)) * basic_new_fl + #basic_new_fl = 0.8 * np.sqrt(len(basic_new_fl) / np.sum(basic_new_fl**2)) * basic_new_fl + if len(good) > 1: + print("Directly copying latent vars !!!") + #forcedlatents += [4.6 * basic_new_fl] + forcedlatents += [basic_new_fl] + else: + epsilon = 1.0 * (((a + .5 - len(good)) / (llambda - len(good) - 1)) ** 2) + forcedlatent = (1. - epsilon) * basic_new_fl.flatten() + epsilon * np.random.randn(4*64*64) + forcedlatent = np.sqrt(len(forcedlatent) / np.sum(forcedlatent**2)) * forcedlatent + forcedlatents += [forcedlatent] + #forcedlatents += [4.6 * forcedlatent] + else: + print(f"Voronoi in the latent space! {a} / {llambda}") + forcedlatent = np.zeros((4, 64, 64)) + #print(type(numpy_image)) + #print(numpy_image.shape) + #print(np.max(numpy_image)) + #print(np.min(numpy_image)) + #assert False + for i in range(len(indices)): + coefficients[i] = np.exp(np.random.randn()) + for i in range(64): + x = i / 63. + for j in range(64): + y = j / 63 + mindistances = 10000000000. + for u in range(len(indices)): + #print(a, i, x, j, y, u) + #print(indices[u][1]) + #print(indices[u][2]) + #print(f" {coefficients[u]}* np.linalg.norm({np.array((x, y))}-{np.array((indices[u][1], indices[u][2]))}") + distance = coefficients[u] * np.linalg.norm( np.array((x, y)) - np.array((indices[u][2], indices[u][1])) ) + if distance < mindistances: + mindistances = distance + uu = indices[u][0] + for k in range(4): + assert k < len(forcedlatent), k + assert i < len(forcedlatent[k]), i + assert j < len(forcedlatent[k][i]), j + assert uu < len(latent) + assert k < len(latent[uu]), k + assert i < len(latent[uu][k]), i + assert j < len(latent[uu][k][i]), j + forcedlatent[k][i][j] = float(latent[uu][k][i][j]) + #if a % 2 == 0: + # forcedlatent -= np.random.rand() * sauron + forcedlatent = forcedlatent.flatten() + basic_new_fl = np.sqrt(len(forcedlatent) / np.sum(forcedlatent**2)) * forcedlatent + if len(good) > 1 or len(forcedlatents) < len(good) + 1: + forcedlatents += [basic_new_fl] + else: + epsilon = ((0.5 * (a + .5 - len(good)) / (llambda - len(good) - 1)) ** 2) + forcedlatent = (1. - epsilon) * basic_new_fl.flatten() + epsilon * np.random.randn(4*64*64) + forcedlatent = np.sqrt(len(forcedlatent) / np.sum(forcedlatent**2)) * forcedlatent + forcedlatents += [forcedlatent] + #for uu in range(len(latent)): + # print(f"--> latent[{uu}] sum of sq / variable = {np.sum(latent[uu].flatten()**2) / len(latent[uu].flatten())}") + os.environ["good"] = "[]" + os.environ["bad"] = "[]" + +pygame.quit() From 19ab206810e7dcea87a92102870e850900970f26 Mon Sep 17 00:00:00 2001 From: Olivier Teytaud Date: Wed, 5 Oct 2022 14:35:45 +0200 Subject: [PATCH 65/76] fix --- geneticsd.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/geneticsd.py b/geneticsd.py index f53c7aaa4..85ca9dea2 100644 --- a/geneticsd.py +++ b/geneticsd.py @@ -147,7 +147,7 @@ def speak(text): prompt = "A star with flashy colors." prompt = "Portrait of a green haired woman with blue eyes." prompt = "Portrait of a female kung-fu master." -prompt = "In a dark cave, in the middle of computers, a geek meets the devil." 
+prompt = "In a dark cave, in the middle of computers, a bearded red-haired geek with squared glasses meets the devil." print(f"The prompt is {prompt}") From 133b1e3bb9b679194e86c888349249e530409212 Mon Sep 17 00:00:00 2001 From: Olivier Teytaud Date: Wed, 5 Oct 2022 18:22:18 +0200 Subject: [PATCH 66/76] fix --- README.md | 2 +- .../pipeline_stable_diffusion.py | 14 ++++- edit.sh | 5 +- geneticsd.py | 1 + minisd.py | 58 ++++++++++++++----- view_history.sh | 2 +- 6 files changed, 63 insertions(+), 19 deletions(-) diff --git a/README.md b/README.md index b89cfc055..41140684f 100644 --- a/README.md +++ b/README.md @@ -40,7 +40,7 @@ pip install git+https://github.com/sberbank-ai/Real-ESRGAN.git wget https://github.com/xinntao/Real-ESRGAN/releases/download/v0.1.0/RealESRGAN_x4plus.pth -P weights ``` -## Then run << python minisd.py >>. +## Then run << python geneticsd.py >>. You should be asked for a prompt (just <> if you like the proposed hardcoded prompt), and then a window should be opened. ## Send feedback to [**Nevergrad Users**](https://www.facebook.com/groups/nevergradusers/)
diff --git a/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py b/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py index e8e076137..81c57e216 100644 --- a/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +++ b/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py @@ -215,6 +215,17 @@ def __call__( latents_shape = (batch_size, self.unet.in_channels, height // 8, width // 8) latents_intermediate_shape = (batch_size, self.unet.in_channels, height // 8, width // 8) speedup = 1 +# if len(os.environ["forcedlatent"]) < 5: +# forcedlatent = np.random.randn(4*64*64).reshape(4,64,64) +# for u in range(64): +# for v in range(64): +# if (u-32)**2 + (v-32)**2 > (32*1.2)**2: +# forcedlatent[0][u][v] = 0 +# forcedlatent[1][u][v] = 0 +# forcedlatent[2][u][v] = 0 +# forcedlatent[3][u][v] = 0 +# os.environ["forcedlatent"] = str(list(forcedlatent.flatten())) + if latents is None: latents = torch.randn( latents_intermediate_shape, @@ -363,7 +374,7 @@ def loss(x): latent_model_input = latent_model_input / ((sigma**2 + 1) ** 0.5) # predict the noise residual - print(f"text_embeddings.shape={text_embeddings.shape}") + #print(f"text_embeddings.shape={text_embeddings.shape}") noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=text_embeddings).sample # perform guidance @@ -380,6 +391,7 @@ def loss(x): # scale and decode the image latents with vae #os.environ["latent_sd"] = str(list(latents.flatten().cpu().detach().numpy())) latents = 1 / 0.18215 * latents + #os.environ["latent_sd"] = str(list(latents.flatten().cpu().numpy())) image = self.vae.decode(latents).sample image = (image / 2 + 0.5).clamp(0, 1) diff --git a/edit.sh b/edit.sh index 325be12b4..50234090c 100755 --- a/edit.sh +++ b/edit.sh @@ -1,2 +1,3 @@ -vim /opt/miniconda3/envs/sd_gpu/lib/python3.9/site-packages/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py -cp /opt/miniconda3/envs/sd_gpu/lib/python3.9/site-packages/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py . +vim diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +#vim /opt/miniconda3/envs/sd_gpu/lib/python3.9/site-packages/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +#cp /opt/miniconda3/envs/sd_gpu/lib/python3.9/site-packages/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py . diff --git a/geneticsd.py b/geneticsd.py index 85ca9dea2..2ba67915a 100644 --- a/geneticsd.py +++ b/geneticsd.py @@ -148,6 +148,7 @@ def speak(text): prompt = "Portrait of a green haired woman with blue eyes." prompt = "Portrait of a female kung-fu master." prompt = "In a dark cave, in the middle of computers, a bearded red-haired geek with squared glasses meets the devil." +prompt = "Photo of the devil, with horns. There are flames in the background." print(f"The prompt is {prompt}") diff --git a/minisd.py b/minisd.py index bcd789ea5..3bab87bae 100644 --- a/minisd.py +++ b/minisd.py @@ -1,3 +1,4 @@ +assert False, "Deprecated! Use geneticsd.py instead." import random import os import time @@ -40,7 +41,6 @@ os.environ["bad"] = "[]" num_iterations = 50 gs = 7.5 -voronoi_in_images = False @@ -113,7 +113,6 @@ prompt = "Beautiful bizarre woman." prompt = "Yann LeCun as the grim reaper: bizarre art." -prompt = "A star with flashy colors." prompt = "Un chat en sang et en armure joue de la batterie." prompt = "Photo of a cyberpunk Mark Zuckerberg killing Cthulhu with a light saber." prompt = "A ferocious cyborg bear." 
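The decoding step touched in the pipeline_stable_diffusion.py hunk above (undo the 0.18215 latent scaling, decode with the VAE, map from [-1, 1] to pixels) can be exercised on its own roughly as follows; this assumes a loaded `StableDiffusionPipeline` named `pipe` as in the scripts, and is a sketch rather than patch code:

```python
# Rough standalone sketch of the latent -> image decoding shown in the diff above.
import numpy as np
import torch
from PIL import Image

def decode_latent(pipe, latent: np.ndarray) -> Image.Image:
    p = next(pipe.vae.parameters())                      # borrow the VAE's device/dtype
    with torch.no_grad():
        lat = torch.from_numpy(latent.reshape(1, 4, 64, 64)).to(device=p.device, dtype=p.dtype)
        lat = lat / 0.18215                              # same scaling as in the diff
        img = pipe.vae.decode(lat).sample                # (1, 3, 512, 512), values in [-1, 1]
        img = (img / 2 + 0.5).clamp(0, 1)
        arr = (img[0].permute(1, 2, 0).cpu().float().numpy() * 255).round().astype("uint8")
    return Image.fromarray(arr)
```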
@@ -137,6 +136,14 @@ prompt = "A portrait of a cute smiling woman." prompt = "A woman with black skin, red hair, egyptian dress, yellow eyes." prompt = "Photo of a red haired man with tilted head." +prompt = "A photo of Cleopatra with Egyptian Dress kissing Yoda." +prompt = "A photo of Yoda fighting Meg Myers with light sabers." +prompt = "A photo of Meg Myers, laughing, pulling Gandalf's hair." +prompt = "A photo of Meg Myers laughing and pulling Gandalf's hair. Gandalf is stooping." +prompt = "A star with flashy colors." +prompt = "Portrait of a green haired woman with blue eyes." +prompt = "Portrait of a female kung-fu master." +prompt = "In a dark cave, in the middle of computers, a geek meets the devil." print(f"The prompt is {prompt}") @@ -169,6 +176,7 @@ def latent_to_image(latent): os.environ["forcedlatent"] = str(list(latent.flatten())) #str(list(forcedlatents[k].flatten())) with autocast("cuda"): image = pipe(english_prompt, guidance_scale=gs, num_inference_steps=num_iterations)["sample"][0] + os.environ["forcedlatent"] = "[]" return image @@ -335,6 +343,26 @@ def img_to_latent(path): new_fl = np.sqrt(len(new_fl)) * new_fl / np.sqrt(np.sum(new_fl ** 2)) return new_fl +def randomized_image_to_latent(image_name, scale=None, epsilon=None, c=None, f=None): + base_init_image = load_img(image_name).to(device) + new_base_init_image = base_init_image + c = np.exp(np.random.randn()) if c is None else c + f = np.exp(-3. * np.random.rand()) if f is None else f + init_image_shape = base_init_image.cpu().numpy().shape + init_image = c * new_base_init_image + init_image = repeat(init_image, '1 ... -> b ...', b=1) + forced_latent = 1. * model.encode(init_image.to(device)).latent_dist.sample() + new_fl = forced_latent.cpu().detach().numpy().flatten() + basic_new_fl = new_fl #np.sqrt(len(new_fl) / sum(new_fl ** 2)) * new_fl + basic_new_fl = f * np.sqrt(len(new_fl) / np.sum(basic_new_fl**2)) * basic_new_fl + epsilon = 0.1 * np.exp(-3 * np.random.rand()) if epsilon is None else epsilon + new_fl = (1. - epsilon) * basic_new_fl + epsilon * np.random.randn(1*4*64*64) + scale = 2.8 + 3.6 * np.random.rand() if scale is None else scale + new_fl = scale * np.sqrt(len(new_fl)) * new_fl / np.sqrt(np.sum(new_fl ** 2)) + #image = latent_to_image(np.asarray(new_fl)) #eval(os.environ["forcedlatent"]))) + #image.save(f"rebuild_{f}_{scale}_{epsilon}_{c}.png") + return new_fl + if len(image_name) > 0: pretty_print("Importing an image !") try: @@ -694,17 +722,18 @@ def img_to_latent(path): #for u in [u for u in range(len(onlyfiles)) if u not in [i[0] for i in indices]]: # sauron += latent[u] #sauron = (1 / len([u for u in range(len(onlyfiles)) if u not in [i[0] for i in indices]])) * sauron - if len(bad) > 300: - bad = bad[(len(bad) - 300):] + if len(bad) > 500: + bad = bad[(len(bad) - 500):] print(to_native(f"{len(indices)} indices are selected.")) #print(f"indices = {indices}") os.environ["good"] = str(good) os.environ["bad"] = str(bad) coefficients = np.zeros(len(indices)) - if voronoi_in_images: - numpy_images = [np.array(image) for image in images] - image = np.array(numpy_images[0]) - for a in range(llambda): + numpy_images = [np.array(image) for image in images] + for a in range(llambda): + voronoi_in_images = False #(a % 2 == 1) and len(good) > 1 + if voronoi_in_images: + image = np.array(numpy_images[0]) print(f"Voronoi in the image space! {a} / {llambda}") for i in range(len(indices)): coefficients[i] = np.exp(np.random.randn()) @@ -731,7 +760,7 @@ def img_to_latent(path): #timage = 2.*timage - 1. 
#forcedlatent = model.encode(timage).latent_dist.sample().cpu().detach().numpy().flatten() #basic_new_fl = np.sqrt(len(forcedlatent) / np.sum(forcedlatent**2)) * forcedlatent - basic_new_fl = img_to_latent(voronoi_name) + basic_new_fl = randomized_image_to_latent(voronoi_name) #img_to_latent(voronoi_name) basic_new_fl = np.sqrt(len(basic_new_fl) / np.sum(basic_new_fl**2)) * basic_new_fl #basic_new_fl = 0.8 * np.sqrt(len(basic_new_fl) / np.sum(basic_new_fl**2)) * basic_new_fl if len(good) > 1: @@ -744,8 +773,7 @@ def img_to_latent(path): forcedlatent = np.sqrt(len(forcedlatent) / np.sum(forcedlatent**2)) * forcedlatent forcedlatents += [forcedlatent] #forcedlatents += [4.6 * forcedlatent] - else: - for a in range(llambda): + else: print(f"Voronoi in the latent space! {a} / {llambda}") forcedlatent = np.zeros((4, 64, 64)) #print(type(numpy_image)) @@ -782,14 +810,16 @@ def img_to_latent(path): # forcedlatent -= np.random.rand() * sauron forcedlatent = forcedlatent.flatten() basic_new_fl = np.sqrt(len(forcedlatent) / np.sum(forcedlatent**2)) * forcedlatent - if len(good) > 1: + if len(good) > 1 or len(forcedlatents) < len(good) + 1: forcedlatents += [basic_new_fl] else: - epsilon = 0.1 * (((a + .5 - len(good)) / (llambda - len(good) - 1)) ** 2) + epsilon = ((0.5 * (a + .5 - len(good)) / (llambda - len(good) - 1)) ** 2) forcedlatent = (1. - epsilon) * basic_new_fl.flatten() + epsilon * np.random.randn(4*64*64) - forcedlatent = np.sqrt(len(forcedlatent) / np.sum(forcedlatent**2)) * forcedlatent + #forcedlatent = np.sqrt(len(forcedlatent) / np.sum(forcedlatent**2)) * forcedlatent forcedlatents += [forcedlatent] #for uu in range(len(latent)): # print(f"--> latent[{uu}] sum of sq / variable = {np.sum(latent[uu].flatten()**2) / len(latent[uu].flatten())}") + os.environ["good"] = "[]" + os.environ["bad"] = "[]" pygame.quit() diff --git a/view_history.sh b/view_history.sh index 3f801d2e1..09f5c1618 100755 --- a/view_history.sh +++ b/view_history.sh @@ -6,5 +6,5 @@ #montage $( ls *`ls -ctr SD*.png | sed 's/.*image_//g' | tail -n 1 | sed 's/_.*//g'`_0_11.png | sort ) $( ls *`ls -ctr SD*.png | sed 's/.*image_//g' | tail -n 1 | sed 's/_.*//g'`_0_4.png | sort ) $( ls *`ls -ctr SD*.png | sed 's/.*image_//g' | tail -n 1 | sed 's/_.*//g'`_1_?.png | sort -n ) $( ls *`ls -ctr SD*.png | sed 's/.*image_//g' | tail -n 1 | sed 's/_.*//g'`_1_??.png | sort -n ) -mode concatenate -tile 5x history.png #montage $( ls *`ls -ctr SD*.png | sed 's/.*image_//g' | tail -n 1 | sed 's/_.*//g'`*.png | sort ) -mode concatenate -tile 5x history.png #open history.png -open $( ls *`ls -ctr SD*.png | sed 's/.*image_//g' | tail -n 1 | sed 's/_.*//g'`*.png | sort ) +open $( ls *`ls -ctr SD*.png | sed 's/.*image_//g' | tail -n 1 | sed 's/_.*//g'`*.png | tail -n 15 | sort ) #cp history.png zuck3.png From be15417c7043bd9ca7b5a42a0e218736c8e11bb8 Mon Sep 17 00:00:00 2001 From: Olivier Teytaud Date: Thu, 6 Oct 2022 10:39:27 +0200 Subject: [PATCH 67/76] fix --- geneticsd.py | 27 +- minisd.py | 1653 +++++++++++++++++++++++++------------------------- 2 files changed, 847 insertions(+), 833 deletions(-) diff --git a/geneticsd.py b/geneticsd.py index 2ba67915a..f30a493ee 100644 --- a/geneticsd.py +++ b/geneticsd.py @@ -61,6 +61,7 @@ def speak(text): all_selected = [] # List of all selected images, over all the run. all_selected_latent = [] # The corresponding latent variables. final_selection = [] # Selection of files during the final iteration. +final_selection_latent = [] # Selection of files during the final iteration. 
forcedlatents = [] # Latent variables that we want to see soon. forcedgs = [] # forcedgs[i] is the guidance strength that we want to see for image number i. assert llambda < 16, "lambda < 16 for convenience in pygame." @@ -149,6 +150,8 @@ def speak(text): prompt = "Portrait of a female kung-fu master." prompt = "In a dark cave, in the middle of computers, a bearded red-haired geek with squared glasses meets the devil." prompt = "Photo of the devil, with horns. There are flames in the background." +prompt = "Yann LeCun fighting Pinocchio with light sabers." +prompt = "Yann LeCun attacks a triceratops with a lightsaber." print(f"The prompt is {prompt}") @@ -213,12 +216,16 @@ def singleeg2(path_to_image): # realESRGan applied to many files. -def eg(list_of_files): +def eg(list_of_files, last_list_of_files): pretty_print("Should I convert images below to high resolution ?") print(list_of_files) + print("Last iteration:") + print(last_list_of_files) speak("Go to the text window!") - answer = input(" [y]es / [n]o ?") - if "y" in answer or "Y" in answer: + answer = input(" [y]es / [n]o / [j]ust the ones in last iteration ?") + if "y" in answer or "Y" in answer or "j" in answer or "J" in answer: + if j in answer or "J" in answer: + list_of_files = last_list_of_files #images = Parallel(n_jobs=12)(delayed(singleeg)(image) for image in list_of_files) #print(to_native(f"Created the super-resolution files {images}")) for path_to_image in list_of_files: @@ -226,10 +233,10 @@ def eg(list_of_files): print(to_native(f"Created the super-resolution file {output_filename}")) # When we stop the run and check and propose to do super-resolution and/or animations. -def stop_all(list_of_files, list_of_latent, last_list_of_latent): +def stop_all(list_of_files, list_of_latent, last_list_of_files, last_list_of_latent): print(to_native("Your selected images and the last generation:")) print(list_of_files) - eg(list_of_files) + eg(list_of_files, last_list_of_files) pretty_print("Should we create animations ?") answer = input(" [y]es or [n]o or [j]ust the selection on the last panel ?") if "y" in answer or "Y" in answer or "j" in answer or "J" in answer: @@ -237,7 +244,7 @@ def stop_all(list_of_files, list_of_latent, last_list_of_latent): if "j" in answer or "J" in answer: list_of_latent = last_list_of_latent pretty_print("Let us create animations!") - for c in sorted([0.5, 0.25, 0.125, 0.0625, 0.05, 0.04,0.03125]): + for c in sorted([0.0025, 0.005, 0.01, 0.02]): for idx in range(len(list_of_files)): images = [] l = list_of_latent[idx].reshape(1,4,64,64) @@ -381,6 +388,7 @@ def randomized_image_to_latent(image_name, scale=None, epsilon=None, c=None, f=N early_stop = [] speak("Wait!") final_selection = [] + final_selection_latent = [] for k in range(llambda): if len(early_stop) > 0: break @@ -451,7 +459,7 @@ def randomized_image_to_latent(image_name, scale=None, epsilon=None, c=None, f=N pos = pygame.mouse.get_pos() index = 3 * (pos[0] // 300) + (pos[1] // 300) if pos[0] > X and pos[1] > Y /3 and pos[1] < 2*Y/3: - stop_all(all_selected, all_selected_latent, final_selection) + stop_all(all_selected, all_selected_latent, final_selection, final_selection_latent) exit() if index <= k: pretty_print(("You clicked for requesting an early stopping.")) @@ -571,7 +579,7 @@ def randomized_image_to_latent(image_name, scale=None, epsilon=None, c=None, f=N #onlyfiles = [str(f) for f in onlyfiles if "SD_" in str(f) and ".png" in str(f) and str(f) not in all_files and sentinel in str(f)] assert len(onlyfiles) == len(latent) assert 
len(all_selected) == len(all_selected_latent) - stop_all(all_selected, all_selected_latent, final_selection) # + onlyfiles, all_selected_latent + latent) + stop_all(all_selected, all_selected_latent, final_selection, final_selection_latent) # + onlyfiles, all_selected_latent + latent) exit() status = False break @@ -583,7 +591,8 @@ def randomized_image_to_latent(image_name, scale=None, epsilon=None, c=None, f=N assert len(onlyfiles) == len(latent), f"{len(onlyfiles)} != {len(latent)}" all_selected += [selected_filename] all_selected_latent += [latent[index]] - final_selection += [latent[index]] + final_selection += [selected_filename] + final_selection_latent += [latent[index]] text2 = font.render(to_native(f'==> {len(all_selected)} chosen images! '), True, green, blue) text2 = pygame.transform.rotate(text2, 90) scrn.blit(text2, (X*3/4+X/16 - X/32, Y/3)) diff --git a/minisd.py b/minisd.py index 3bab87bae..83699367c 100644 --- a/minisd.py +++ b/minisd.py @@ -1,825 +1,830 @@ assert False, "Deprecated! Use geneticsd.py instead." -import random -import os -import time -import torch -import numpy as np -import shutil -import PIL -from PIL import Image -from einops import rearrange, repeat -from torch import autocast -from diffusers import StableDiffusionPipeline -import webbrowser -from deep_translator import GoogleTranslator -from langdetect import detect -from joblib import Parallel, delayed -import torch -from PIL import Image -from RealESRGAN import RealESRGAN - -os.environ["PYTORCH_ENABLE_MPS_FALLBACK"] = "1" -model_id = "CompVis/stable-diffusion-v1-4" -#device = "cuda" -device = "mps" #torch.device("mps") - -white = (255, 255, 255) -green = (0, 255, 0) -darkgreen = (0, 128, 0) -red = (255, 0, 0) -blue = (0, 0, 128) -black = (0, 0, 0) - -os.environ["skl"] = "nn" -os.environ["epsilon"] = "0.005" -os.environ["decay"] = "0." -os.environ["ngoptim"] = "DiscreteLenglerOnePlusOne" -os.environ["forcedlatent"] = "" -latent_forcing = "" -#os.environ["enforcedlatent"] = "" -os.environ["good"] = "[]" -os.environ["bad"] = "[]" -num_iterations = 50 -gs = 7.5 - - - -import pyttsx3 - -noise = pyttsx3.init() -noise.setProperty("rate", 240) -noise.setProperty('voice', 'mb-us1') - -#voice = noise.getProperty('voices') -#for v in voice: -# if v.name == "Kyoko": -# noise.setProperty('voice', v.id) - - -all_selected = [] -all_selected_latent = [] -final_selection = [] -forcedlatents = [] -forcedgs = [] - - - -pipe = StableDiffusionPipeline.from_pretrained(model_id, use_auth_token="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") -pipe = pipe.to(device) - -prompt = "a photo of an astronaut riding a horse on mars" -prompt = "a photo of a red panda with a hat playing table tennis" -prompt = "a photorealistic portrait of " + random.choice(["Mary Cury", "Scarlett Johansson", "Marilyn Monroe", "Poison Ivy", "Black Widow", "Medusa", "Batman", "Albert Einstein", "Louis XIV", "Tarzan"]) + random.choice([" with glasses", " with a hat", " with a cigarette", "with a scarf"]) -prompt = "a photorealistic portrait of " + random.choice(["Nelson Mandela", "Superman", "Superwoman", "Volodymyr Zelenskyy", "Tsai Ing-Wen", "Lzzy Hale", "Meg Myers"]) + random.choice([" with glasses", " with a hat", " with a cigarette", "with a scarf"]) -prompt = random.choice(["A woman with three eyes", "Meg Myers", "The rock band Ankor", "Miley Cyrus", "The man named Rahan", "A murder", "Rambo playing table tennis"]) -prompt = "Photo of a female Terminator." 
-prompt = random.choice([ - "Photo of Tarzan as a lawyer with a tie", - "Photo of Scarlett Johansson as a sumo-tori", - "Photo of the little mermaid as a young black girl", - "Photo of Schwarzy with tentacles", - "Photo of Meg Myers with an Egyptian dress", - "Photo of Schwarzy as a ballet dancer", - ]) - - -name = random.choice(["Mark Zuckerbeg", "Zendaya", "Yann LeCun", "Scarlett Johansson", "Superman", "Meg Myers"]) -name = "Zendaya" -prompt = f"Photo of {name} as a sumo-tori." - -prompt = "Full length portrait of Mark Zuckerberg as a Sumo-Tori." -prompt = "Full length portrait of Scarlett Johansson as a Sumo-Tori." -prompt = "A close up photographic portrait of a young woman with uniformly colored hair." -prompt = "Zombies raising and worshipping a flying human." -prompt = "Zombies trying to kill Meg Myers." -prompt = "Meg Myers with an Egyptian dress killing a vampire with a gun." -prompt = "Meg Myers grabbing a vampire by the scruff of the neck." -prompt = "Mark Zuckerberg chokes a vampire to death." -prompt = "Mark Zuckerberg riding an animal." -prompt = "A giant cute animal worshipped by zombies." - - -prompt = "Several faces." - -prompt = "An armoured Yann LeCun fighting tentacles in the jungle." -prompt = "Tentacles everywhere." -prompt = "A photo of a smiling Medusa." -prompt = "Medusa." -prompt = "Meg Myers in bloody armor fending off tentacles with a sword." -prompt = "A red-haired woman with red hair. Her head is tilted." -prompt = "A bloody heavy-metal zombie with a chainsaw." -prompt = "Tentacles attacking a bloody Meg Myers in Eyptian dress. Meg Myers has a chainsaw." -prompt = "Bizarre art." - -prompt = "Beautiful bizarre woman." -prompt = "Yann LeCun as the grim reaper: bizarre art." -prompt = "Un chat en sang et en armure joue de la batterie." -prompt = "Photo of a cyberpunk Mark Zuckerberg killing Cthulhu with a light saber." -prompt = "A ferocious cyborg bear." -prompt = "Photo of Mark Zuckerberg killing Cthulhu with a light saber." -prompt = "A bear with horns and blood and big teeth." -prompt = "A photo of a bear and Yoda, good friends." -prompt = "A photo of Yoda on the left, a blue octopus on the right, an explosion in the center." -prompt = "A bird is on a hippo. They fight a black and red octopus. Jungle in the background." -prompt = "A flying white owl above 4 colored pots with fire. The owl has a hat." -prompt = "A flying white owl above 4 colored pots with fire." -prompt = "Yann LeCun rides a dragon which spits fire on a cherry on a cake." -prompt = "An armored Mark Zuckerberg fighting off a monster with bloody tentacles in the jungle with a light saber." -prompt = "Cute woman, portrait, photo, red hair, green eyes, smiling." -prompt = "Photo of Tarzan as a lawyer with a tie and an octopus on his head." -prompt = "An armored bloody Yann Lecun has a lightsabar and fights a red tentacular monster." -prompt = "Photo of a giant armored insect attacking a building. The building is broken. There are flames." -prompt = "Photo of Meg Myers, on the left, in Egyptian dress, fights Cthulhu (on the right) with a light saber. They stare at each other." -prompt = "Photo of a cute red panda." -prompt = "Photo of a cute smiling white-haired woman with pink eyes." -prompt = "A muscular Jesus with and assault rifle, a cap and and a light saber." -prompt = "A portrait of a cute smiling woman." -prompt = "A woman with black skin, red hair, egyptian dress, yellow eyes." -prompt = "Photo of a red haired man with tilted head." 
-prompt = "A photo of Cleopatra with Egyptian Dress kissing Yoda." -prompt = "A photo of Yoda fighting Meg Myers with light sabers." -prompt = "A photo of Meg Myers, laughing, pulling Gandalf's hair." -prompt = "A photo of Meg Myers laughing and pulling Gandalf's hair. Gandalf is stooping." -prompt = "A star with flashy colors." -prompt = "Portrait of a green haired woman with blue eyes." -prompt = "Portrait of a female kung-fu master." -prompt = "In a dark cave, in the middle of computers, a geek meets the devil." -print(f"The prompt is {prompt}") - - -import pyfiglet -print(pyfiglet.figlet_format("Welcome in Genetic Stable Diffusion !")) -print(pyfiglet.figlet_format("First, let us choose the text :-)!")) - - - -print(f"Francais: Proposez un nouveau texte si vous ne voulez pas dessiner << {prompt} >>.\n") -noise.say("Hey!") -noise.runAndWait() -user_prompt = input(f"English: Enter a new prompt if you prefer something else than << {prompt} >>.\n") -if len(user_prompt) > 2: - prompt = user_prompt - -# On the fly translation. -language = detect(prompt) -english_prompt = GoogleTranslator(source='auto', target='en').translate(prompt) - -def to_native(stri): - return GoogleTranslator(source='en', target=language).translate(stri) - -def pretty_print(stri): - print(pyfiglet.figlet_format(to_native(stri))) - -print(f"{to_native('Working on')} {english_prompt}, a.k.a {prompt}.") - -def latent_to_image(latent): - os.environ["forcedlatent"] = str(list(latent.flatten())) #str(list(forcedlatents[k].flatten())) - with autocast("cuda"): - image = pipe(english_prompt, guidance_scale=gs, num_inference_steps=num_iterations)["sample"][0] - os.environ["forcedlatent"] = "[]" - return image - - -sr_device = torch.device('cpu') #device #('mps') #torch.device('cuda' if torch.cuda.is_available() else 'cpu') -esrmodel = RealESRGAN(sr_device, scale=4) -esrmodel.load_weights('weights/RealESRGAN_x4.pth', download=True) -esrmodel2 = RealESRGAN(sr_device, scale=2) -esrmodel2.load_weights('weights/RealESRGAN_x2.pth', download=True) - -def singleeg(path_to_image): - image = Image.open(path_to_image).convert('RGB') - sr_device = device #('mps') #torch.device('cuda' if torch.cuda.is_available() else 'cpu') - print(f"Type before SR = {type(image)}") - sr_image = esrmodel.predict(image) - print(f"Type after SR = {type(sr_image)}") - output_filename = path_to_image + ".SR.png" - sr_image.save(output_filename) - return output_filename - -def singleeg2(path_to_image): - time.sleep(0.5*np.random.rand()) - image = Image.open(path_to_image).convert('RGB') - sr_device = device #('mps') #torch.device('cuda' if torch.cuda.is_available() else 'cpu') - print(f"Type before SR = {type(image)}") - sr_image = esrmodel2.predict(image) - print(f"Type after SR = {type(sr_image)}") - output_filename = path_to_image + ".SR.png" - sr_image.save(output_filename) - return output_filename - - -def eg(list_of_files): - pretty_print("Should I convert images below to high resolution ?") - print(list_of_files) - noise.say("Go to the text window!") - noise.runAndWait() - answer = input(" [y]es / [n]o ?") - if "y" in answer or "Y" in answer: - #images = Parallel(n_jobs=12)(delayed(singleeg)(image) for image in list_of_files) - #print(to_native(f"Created the super-resolution files {images}")) - for path_to_image in list_of_files: - output_filename = singleeg(path_to_image) - print(to_native(f"Created the super-resolution file {output_filename}")) - -def stop_all(list_of_files, list_of_latent, last_list_of_latent): - print(to_native("Your selected images 
and the last generation:")) - print(list_of_files) - eg(list_of_files) - pretty_print("Should we create animations ?") - answer = input(" [y]es or [n]o or [j]ust the selection on the last panel ?") - if "y" in answer or "Y" in answer or "j" in answer or "J" in answer: - assert len(list_of_files) == len(list_of_latent) - if "j" in answer or "J" in answer: - list_of_latent = last_list_of_latent - pretty_print("Let us create animations!") - for c in sorted([0.5, 0.25, 0.125, 0.0625, 0.05, 0.04,0.03125]): - for idx in range(len(list_of_files)): - images = [] - l = list_of_latent[idx].reshape(1,4,64,64) - l = np.sqrt(len(l.flatten()) / np.sum(l**2)) * l - l1 = l + c * np.random.randn(len(l.flatten())).reshape(1,4,64,64) - l1 = np.sqrt(len(l1.flatten()) / np.sum(l1**2)) * l1 - l2 = l + c * np.random.randn(len(l.flatten())).reshape(1,4,64,64) - l2 = np.sqrt(len(l2.flatten()) / np.sum(l2**2)) * l2 - num_animation_steps = 13 - index = 0 - for u in np.linspace(0., 2*3.14159 * (1-1/30), 30): - cc = np.cos(u) - ss = np.sin(u*2) - index += 1 - image = latent_to_image(l + cc * (l1 - l) + ss * (l2 - l)) - image_name = f"imgA{index}.png" - image.save(image_name) - images += [image_name] - -# for u in np.linspace(0., 1., num_animation_steps): -# index += 1 -# image = latent_to_image(u*l1 + (1-u)*l) -# image_name = f"imgA{index}.png" -# image.save(image_name) -# images += [image_name] -# for u in np.linspace(0., 1., num_animation_steps): -# index += 1 -# image = latent_to_image(u*l2 + (1-u)*l1) -# image_name = f"imgB{index}.png" -# image.save(image_name) -# images += [image_name] -# for u in np.linspace(0., 1.,num_animation_steps): -# index += 1 -# image = latent_to_image(u*l + (1-u)*l2) -# image_name = f"imgC{index}.png" -# image.save(image_name) -# images += [image_name] - print(to_native(f"Base images created for perturbation={c} and file {list_of_files[idx]}")) - #images = Parallel(n_jobs=8)(delayed(process)(i) for i in range(10)) - images = Parallel(n_jobs=10)(delayed(singleeg2)(image) for image in images) - - frames = [Image.open(image) for image in images] - frame_one = frames[0] - gif_name = list_of_files[idx] + "_" + str(c) + ".gif" - frame_one.save(gif_name, format="GIF", append_images=frames, - save_all=True, duration=100, loop=0) - webbrowser.open(os.environ["PWD"] + "/" + gif_name) - - pretty_print("Should we create a meme ?") - answer = input(" [y]es or [n]o ?") - if "y" in answer or "Y" in answer: - url = 'https://imgflip.com/memegenerator' - webbrowser.open(url) - pretty_print("Good bye!") - exit() - - -import os -import pygame -from os import listdir -from os.path import isfile, join - -sentinel = str(random.randint(0,100000)) + "XX" + str(random.randint(0,100000)) - -all_files = [] - -llambda = 15 - -assert llambda < 16, "lambda < 16 for convenience in pygame." - -bad = [] -five_best = [] -latent = [] -images = [] -onlyfiles = [] - -pretty_print("Now let us choose (if you want) an image as a start.") -image_name = input(to_native("Name of image for starting ? (enter if no start image)")) - -# activate the pygame library . 
-pygame.init() -X = 2000 # > 1500 = buttons -Y = 900 -scrn = pygame.display.set_mode((1700, Y + 100)) -font = pygame.font.Font('freesansbold.ttf', 22) -bigfont = pygame.font.Font('freesansbold.ttf', 44) - -def load_img(path): - image = Image.open(path).convert("RGB") - w, h = image.size - print(to_native(f"loaded input image of size ({w}, {h}) from {path}")) - w, h = map(lambda x: x - x % 32, (w, h)) # resize to integer multiple of 32 - image = image.resize((512, 512), resample=PIL.Image.LANCZOS) - #image = image.resize((w, h), resample=PIL.Image.LANCZOS) - image = np.array(image).astype(np.float32) / 255.0 - image = image[None].transpose(0, 3, 1, 2) - image = torch.from_numpy(image) - return 2.*image - 1. - -model = pipe.vae - -def img_to_latent(path): - #init_image = 1.8 * load_img(path).to(device) - init_image = load_img(path).to(device) - init_image = repeat(init_image, '1 ... -> b ...', b=1) - forced_latent = model.encode(init_image.to(device)).latent_dist.sample() - new_fl = forced_latent.cpu().detach().numpy().flatten() - new_fl = np.sqrt(len(new_fl)) * new_fl / np.sqrt(np.sum(new_fl ** 2)) - return new_fl - -def randomized_image_to_latent(image_name, scale=None, epsilon=None, c=None, f=None): - base_init_image = load_img(image_name).to(device) - new_base_init_image = base_init_image - c = np.exp(np.random.randn()) if c is None else c - f = np.exp(-3. * np.random.rand()) if f is None else f - init_image_shape = base_init_image.cpu().numpy().shape - init_image = c * new_base_init_image - init_image = repeat(init_image, '1 ... -> b ...', b=1) - forced_latent = 1. * model.encode(init_image.to(device)).latent_dist.sample() - new_fl = forced_latent.cpu().detach().numpy().flatten() - basic_new_fl = new_fl #np.sqrt(len(new_fl) / sum(new_fl ** 2)) * new_fl - basic_new_fl = f * np.sqrt(len(new_fl) / np.sum(basic_new_fl**2)) * basic_new_fl - epsilon = 0.1 * np.exp(-3 * np.random.rand()) if epsilon is None else epsilon - new_fl = (1. - epsilon) * basic_new_fl + epsilon * np.random.randn(1*4*64*64) - scale = 2.8 + 3.6 * np.random.rand() if scale is None else scale - new_fl = scale * np.sqrt(len(new_fl)) * new_fl / np.sqrt(np.sum(new_fl ** 2)) - #image = latent_to_image(np.asarray(new_fl)) #eval(os.environ["forcedlatent"]))) - #image.save(f"rebuild_{f}_{scale}_{epsilon}_{c}.png") - return new_fl - -if len(image_name) > 0: - pretty_print("Importing an image !") - try: - init_image = load_img(image_name).to(device) - except: - pretty_print("Try again!") - pretty_print("Loading failed!!") - image_name = input(to_native("Name of image for starting ? (enter if no start image)")) - - base_init_image = load_img(image_name).to(device) - noise.say("Image loaded") - noise.runAndWait() - print(base_init_image.shape) - print(np.max(base_init_image.cpu().detach().numpy().flatten())) - print(np.min(base_init_image.cpu().detach().numpy().flatten())) - - forcedlatents = [] - divider = 1.5 - latent_found = False - try: - latent_file = image_name + ".latent.txt" - print(to_native(f"Trying to load latent variables in {latent_file}.")) - f = open(latent_file, "r") - print(to_native("File opened.")) - latent_str = f.read() - print("Latent string read.") - latent_found = True - except: - print(to_native("No latent file: guessing.")) - for i in range(llambda): - new_base_init_image = base_init_image - if not latent_found: # In case of latent vars we need less exploration. 
- if (i % 7) == 1: - new_base_init_image[0,0,:,:] /= divider - if (i % 7) == 2: - new_base_init_image[0,1,:,:] /= divider - if (i % 7) == 3: - new_base_init_image[0,2,:,:] /= divider - if (i % 7) == 4: - new_base_init_image[0,0,:,:] /= divider - new_base_init_image[0,1,:,:] /= divider - if (i % 7) == 5: - new_base_init_image[0,1,:,:] /= divider - new_base_init_image[0,2,:,:] /= divider - if (i % 7) == 6: - new_base_init_image[0,0,:,:] /= divider - new_base_init_image[0,2,:,:] /= divider - - c = np.exp(np.random.randn() - 5) - f = np.exp(-3. * np.random.rand()) - init_image_shape = base_init_image.cpu().numpy().shape - if i > 0 and not latent_found: - init_image = new_base_init_image + torch.from_numpy(c * np.random.randn(np.prod(init_image_shape))).reshape(init_image_shape).float().to(device) - else: - init_image = new_base_init_image - init_image = repeat(init_image, '1 ... -> b ...', b=1) - if latent_found: - new_fl = np.asarray(eval(latent_str)) - assert len(new_fl) > 1 - else: - forced_latent = 1. * model.encode(init_image.to(device)).latent_dist.sample() - new_fl = forced_latent.cpu().detach().numpy().flatten() - basic_new_fl = new_fl #np.sqrt(len(new_fl) / sum(new_fl ** 2)) * new_fl - #new_fl = forced_latent + (1. / 1.1**(llambda-i)) * torch.from_numpy(np.random.randn(1*4*64*64).reshape(1,4,64,64)).float().to(device) - #forcedlatents += [new_fl.cpu().detach().numpy()] - if i > 0: - #epsilon = 0.3 / 1.1**i - basic_new_fl = f * np.sqrt(len(new_fl) / np.sum(basic_new_fl**2)) * basic_new_fl - epsilon = .7 * ((i-1)/(llambda-1)) #1.0 / 2**(2 + (llambda - i) / 6) - print(f"{i} -- {i % 7} {c} {f} {epsilon}") - # 1 -- 1 0.050020045300292804 0.0790648688521246 0.0 - new_fl = (1. - epsilon) * basic_new_fl + epsilon * np.random.randn(1*4*64*64) - else: - new_fl = basic_new_fl - new_fl = 6. * np.sqrt(len(new_fl)) * new_fl / np.sqrt(np.sum(new_fl ** 2)) - forcedlatents += [new_fl] #np.clip(new_fl, -3., 3.)] #np.sqrt(len(new_fl) / sum(new_fl ** 2)) * new_fl] - forcedgs += [7.5] #np.random.choice([7.5, 15.0, 30.0, 60.0])] TODO - #forcedlatents += [np.sqrt(len(new_fl) / sum(new_fl ** 2)) * new_fl] - #print(f"{i} --> {forcedlatents[i][:10]}") - -# We start the big time consuming loop! -for iteration in range(30): - latent = [latent[f] for f in five_best] - images = [images[f] for f in five_best] - onlyfiles = [onlyfiles[f] for f in five_best] - early_stop = [] - noise.say("WAIT!") - noise.runAndWait() - final_selection = [] - for k in range(llambda): - if len(early_stop) > 0: - break - max_created_index = k - if len(forcedlatents) > 0 and k < len(forcedlatents): - #os.environ["forcedlatent"] = str(list(forcedlatents[k].flatten())) - latent_forcing = str(list(forcedlatents[k].flatten())) - print(f"We play with {latent_forcing[:20]}") - if k < len(five_best): - imp = pygame.transform.scale(pygame.image.load(onlyfiles[k]).convert(), (300, 300)) - # Using blit to copy content from one surface to other - scrn.blit(imp, (300 * (k // 3), 300 * (k % 3))) - pygame.display.flip() - continue - pygame.draw.rect(scrn, black, pygame.Rect(0, Y, 1700, Y+100)) - pygame.draw.rect(scrn, black, pygame.Rect(1500, 0, 2000, Y+100)) - text0 = bigfont.render(to_native(f'Please wait !!! 
{k} / {llambda}'), True, green, blue) - scrn.blit(text0, ((X*3/4)/2 - X/32, Y/2-Y/4)) - text0 = font.render(to_native(f'Or, for an early stopping,'), True, green, blue) - scrn.blit(text0, ((X*3/4)/3 - X/32, Y/2-Y/8)) - text0 = font.render(to_native(f'click and WAIT a bit'), True, green, blue) - scrn.blit(text0, ((X*3/4)/3 - X/32, Y/2)) - text0 = font.render(to_native(f'... ... ... '), True, green, blue) - scrn.blit(text0, ((X*3/4)/2 - X/32, Y/2+Y/8)) - - # Button for early stopping - text2 = font.render(to_native(f'Total: {len(all_selected)} chosen images! '), True, green, blue) - text2 = pygame.transform.rotate(text2, 90) - scrn.blit(text2, (X*3/4+X/16 - X/32, Y/3)) - text2 = font.render(to_native('Click for stopping,'), True, green, blue) - text2 = pygame.transform.rotate(text2, 90) - scrn.blit(text2, (X*3/4+X/16+X/64 - X/32, Y/3)) - text2 = font.render(to_native('and get the effects.'), True, green, blue) - text2 = pygame.transform.rotate(text2, 90) - scrn.blit(text2, (X*3/4+X/16+X/32 - X/32, Y/3)) - - pygame.display.flip() - os.environ["earlystop"] = "False" if k > len(five_best) else "True" - os.environ["epsilon"] = str(0. if k == len(five_best) else (k - len(five_best)) / llambda) - os.environ["budget"] = str(np.random.randint(400) if k > len(five_best) else 2) - os.environ["skl"] = {0: "nn", 1: "tree", 2: "logit"}[k % 3] - #enforcedlatent = os.environ.get("enforcedlatent", "") - #if len(enforcedlatent) > 2: - # os.environ["forcedlatent"] = enforcedlatent - # os.environ["enforcedlatent"] = "" - #with autocast("cuda"): - # image = pipe(english_prompt, guidance_scale=gs, num_inference_steps=num_iterations)["sample"][0] - previous_gs = gs - if k < len(forcedgs): - gs = forcedgs[k] - image = latent_to_image(np.asarray(latent_forcing)) #eval(os.environ["forcedlatent"]))) - gs = previous_gs - - images += [image] - filename = f"SD_{prompt.replace(' ','_')}_image_{sentinel}_{iteration:05d}_{k:05d}.png" - image.save(filename) - onlyfiles += [filename] - imp = pygame.transform.scale(pygame.image.load(onlyfiles[-1]).convert(), (300, 300)) - # Using blit to copy content from one surface to other - scrn.blit(imp, (300 * (k // 3), 300 * (k % 3))) - pygame.display.flip() - #noise.say("Dong") - #noise.runAndWait() - print('\a') - str_latent = eval((os.environ["latent_sd"])) - array_latent = eval(f"np.array(str_latent).reshape(4, 64, 64)") - print(f"Debug info: array_latent sumsq/var {sum(array_latent.flatten() ** 2) / len(array_latent.flatten())}") - latent += [array_latent] - with open(filename + ".latent.txt", 'w') as f: - f.write(f"{str_latent}") - # In case of early stopping. - first_event = True - for i in pygame.event.get(): - if i.type == pygame.MOUSEBUTTONUP: - if first_event: - noise.say("Ok I stop") - noise.runAndWait() - first_event = False - pos = pygame.mouse.get_pos() - index = 3 * (pos[0] // 300) + (pos[1] // 300) - if pos[0] > X and pos[1] > Y /3 and pos[1] < 2*Y/3: - stop_all(all_selected, all_selected_latent, final_selection) - exit() - if index <= k: - pretty_print(("You clicked for requesting an early stopping.")) - early_stop = [pos] - break - early_stop = [(1,1)] - satus = False - forcedgs = [] - # Stop the forcing from disk! 
- #os.environ["enforcedlatent"] = "" - # importing required library - - #mypath = "./" - #onlyfiles = [f for f in listdir(mypath) if isfile(join(mypath, f))] - #onlyfiles = [str(f) for f in onlyfiles if "SD_" in str(f) and ".png" in str(f) and str(f) not in all_files and sentinel in str(f)] - #print() - - # create the display surface object - # of specific dimension..e(X, Y). - noise.say("Ok I'm ready! Choose") - noise.runAndWait() - pretty_print("Please choose your images.") - text0 = bigfont.render(to_native(f'Choose your favorite images !!!========='), True, green, blue) - scrn.blit(text0, ((X*3/4)/2 - X/32, Y/2-Y/4)) - text0 = font.render(to_native(f'=================================='), True, green, blue) - scrn.blit(text0, ((X*3/4)/3 - X/32, Y/2-Y/8)) - text0 = font.render(to_native(f'=================================='), True, green, blue) - scrn.blit(text0, ((X*3/4)/3 - X/32, Y/2)) - # Add rectangles - pygame.draw.rect(scrn, red, pygame.Rect(X*3/4, 0, X*3/4+X/16+X/32, Y/3), 2) - pygame.draw.rect(scrn, red, pygame.Rect(X*3/4, Y/3, X*3/4+X/16+X/32, 2*Y/3), 2) - pygame.draw.rect(scrn, red, pygame.Rect(X*3/4, 2*Y/3, X*3/4+X/16+X/32, Y), 2) - pygame.draw.rect(scrn, red, pygame.Rect(0, Y, X/2, Y+100), 2) - - # Button for loading a starting point - text1 = font.render('Manually edit an image.', True, green, blue) - text1 = pygame.transform.rotate(text1, 90) - #scrn.blit(text1, (X*3/4+X/16 - X/32, 0)) - #text1 = font.render('& latent ', True, green, blue) - #text1 = pygame.transform.rotate(text1, 90) - #scrn.blit(text1, (X*3/4+X/16+X/32 - X/32, 0)) - - # Button for creating a meme - text2 = font.render(to_native('Click ,'), True, green, blue) - text2 = pygame.transform.rotate(text2, 90) - scrn.blit(text2, (X*3/4+X/16 - X/32, Y/3+10)) - text2 = font.render(to_native('for finishing with effects.'), True, green, blue) - text2 = pygame.transform.rotate(text2, 90) - scrn.blit(text2, (X*3/4+X/16+X/32 - X/32, Y/3+10)) - # Button for new generation - text3 = font.render(to_native(f"I don't want to select images"), True, green, blue) - text3 = pygame.transform.rotate(text3, 90) - scrn.blit(text3, (X*3/4+X/16 - X/32, Y*2/3+10)) - text3 = font.render(to_native(f"Just rerun."), True, green, blue) - text3 = pygame.transform.rotate(text3, 90) - scrn.blit(text3, (X*3/4+X/16+X/32 - X/32, Y*2/3+10)) - text4 = font.render(to_native(f"Modify parameters or text!"), True, green, blue) - scrn.blit(text4, (300, Y + 30)) - pygame.display.flip() - - for idx in range(max_created_index + 1): - # set the pygame window name - pygame.display.set_caption(prompt) - print(to_native(f"Pasting image {onlyfiles[idx]}...")) - imp = pygame.transform.scale(pygame.image.load(onlyfiles[idx]).convert(), (300, 300)) - scrn.blit(imp, (300 * (idx // 3), 300 * (idx % 3))) - - # paint screen one time - pygame.display.flip() - status = True - indices = [] - good = [] - five_best = [] - for i in pygame.event.get(): - if i.type == pygame.MOUSEBUTTONUP: - print(to_native(".... too early for clicking !!!!")) - - - pretty_print("Please click on your favorite elements!") - print(to_native("You might just click on one image and we will provide variations.")) - print(to_native("Or you can click on the top of an image and the bottom of another one.")) - print(to_native("Click on the << new generation >> when you're done.")) - while (status): - - # iterate over the list of Event objects - # that was returned by pygame.event.get() method. 
- for i in pygame.event.get(): - if hasattr(i, "type") and i.type == pygame.MOUSEBUTTONUP: - pos = pygame.mouse.get_pos() - pretty_print(f"Detected! Click at {pos}") - if pos[1] > Y: - pretty_print("Let us update parameters!") - text4 = font.render(to_native(f"ok, go to text window!"), True, green, blue) - scrn.blit(text4, (300, Y + 30)) - pygame.display.flip() - try: - num_iterations = int(input(to_native(f"Number of iterations ? (current = {num_iterations})\n"))) - except: - num_iterations = int(input(to_native(f"Number of iterations ? (current = {num_iterations})\n"))) - gs = float(input(to_native(f"Guidance scale ? (current = {gs})\n"))) - print(to_native(f"The current text is << {prompt} >>.")) - print(to_native("Start your answer with a symbol << + >> if this is an edit and not a new text.")) - new_prompt = str(input(to_native(f"Enter a text if you want to change from ") + prompt)) - if len(new_prompt) > 2: - if new_prompt[0] == "+": - prompt += new_prompt[1:] - else: - prompt = new_prompt - language = detect(prompt) - english_prompt = GoogleTranslator(source='auto', target='en').translate(prompt) - pretty_print("Ok! Parameters updated.") - pretty_print("==> go back to the window!") - text4 = font.render(to_native(f"Ok! parameters changed!"), True, green, blue) - scrn.blit(text4, (300, Y + 30)) - pygame.display.flip() - elif pos[0] > 1500: # Not in the images. - if pos[1] < Y/3: - #filename = input(to_native("Filename (please provide the latent file, of the format SD*latent*.txt) ?\n")) - #status = False - #with open(filename, 'r') as f: - # latent = f.read() - #break - pretty_print("Easy! I exit now, you edit the file and you save it.") - pretty_print("Then just relaunch me and provide the text and the image.") - exit() - if pos[1] < 2*Y/3: - #onlyfiles = [f for f in listdir(".") if isfile(join(mypath, f))] - #onlyfiles = [str(f) for f in onlyfiles if "SD_" in str(f) and ".png" in str(f) and str(f) not in all_files and sentinel in str(f)] - assert len(onlyfiles) == len(latent) - assert len(all_selected) == len(all_selected_latent) - stop_all(all_selected, all_selected_latent, final_selection) # + onlyfiles, all_selected_latent + latent) - exit() - status = False - break - index = 3 * (pos[0] // 300) + (pos[1] // 300) - pygame.draw.circle(scrn, red, [pos[0], pos[1]], 13, 0) - if index <= max_created_index: - selected_filename = to_native("Selected") + onlyfiles[index] - shutil.copyfile(onlyfiles[index], selected_filename) - assert len(onlyfiles) == len(latent), f"{len(onlyfiles)} != {len(latent)}" - all_selected += [selected_filename] - all_selected_latent += [latent[index]] - final_selection += [latent[index]] - text2 = font.render(to_native(f'==> {len(all_selected)} chosen images! '), True, green, blue) - text2 = pygame.transform.rotate(text2, 90) - scrn.blit(text2, (X*3/4+X/16 - X/32, Y/3)) - if index not in five_best and len(five_best) < 5: - five_best += [index] - indices += [[index, (pos[0] - (pos[0] // 300) * 300) / 300, (pos[1] - (pos[1] // 300) * 300) / 300]] - # Update the button for new generation. 
- pygame.draw.rect(scrn, black, pygame.Rect(X*3/4, 2*Y/3, X*3/4+X/16+X/32, Y)) - pygame.draw.rect(scrn, red, pygame.Rect(X*3/4, 2*Y/3, X*3/4+X/16+X/32, Y), 2) - text3 = font.render(to_native(f" You have chosen {len(indices)} images:"), True, green, blue) - text3 = pygame.transform.rotate(text3, 90) - scrn.blit(text3, (X*3/4+X/16 - X/32, Y*2/3)) - text3 = font.render(to_native(f" Click for new generation!"), True, green, blue) - text3 = pygame.transform.rotate(text3, 90) - scrn.blit(text3, (X*3/4+X/16+X/32 - X/32, Y*2/3)) - pygame.display.flip() - #text3Rect = text3.get_rect() - #text3Rect.center = (750+750*3/4, 1000) - good += [list(latent[index].flatten())] - else: - noise.say("Bad click! Click on image.") - noise.runAndWait() - pretty_print("Bad click! Click on image.") - - if i.type == pygame.QUIT: - status = False - - # Covering old images with full circles. - for _ in range(123): - x = np.random.randint(1500) - y = np.random.randint(900) - pygame.draw.circle(scrn, darkgreen, - [x, y], 17, 0) - pygame.display.update() - if len(indices) == 0: - print("The user did not like anything! Rerun :-(") - continue - print(f"Clicks at {indices}") - os.environ["mu"] = str(len(indices)) - forcedlatents = [] - bad += [list(latent[u].flatten()) for u in range(len(onlyfiles)) if u not in [i[0] for i in indices]] - #sauron = 0 * latent[0] - #for u in [u for u in range(len(onlyfiles)) if u not in [i[0] for i in indices]]: - # sauron += latent[u] - #sauron = (1 / len([u for u in range(len(onlyfiles)) if u not in [i[0] for i in indices]])) * sauron - if len(bad) > 500: - bad = bad[(len(bad) - 500):] - print(to_native(f"{len(indices)} indices are selected.")) - #print(f"indices = {indices}") - os.environ["good"] = str(good) - os.environ["bad"] = str(bad) - coefficients = np.zeros(len(indices)) - numpy_images = [np.array(image) for image in images] - for a in range(llambda): - voronoi_in_images = False #(a % 2 == 1) and len(good) > 1 - if voronoi_in_images: - image = np.array(numpy_images[0]) - print(f"Voronoi in the image space! {a} / {llambda}") - for i in range(len(indices)): - coefficients[i] = np.exp(np.random.randn()) - # Creating a forcedlatent. - for i in range(512): - x = i / 511. - for j in range(512): - y = j / 511 - mindistances = 10000000000. - for u in range(len(indices)): - distance = coefficients[u] * np.linalg.norm( np.array((x, y)) - np.array((indices[u][2], indices[u][1])) ) - if distance < mindistances: - mindistances = distance - uu = indices[u][0] - image[i][j][:] = numpy_images[uu][i][j][:] - # Conversion before using img2latent - pil_image = Image.fromarray(image) - voronoi_name = f"voronoi{a}_iteration{iteration}.png" - pil_image.save(voronoi_name) - #timage = np.array([image]).astype(np.float32) / 255.0 - #timage = timage.transpose(0, 3, 1, 2) - #timage = torch.from_numpy(timage).to(device) - #timage = repeat(timage, '1 ... -> b ...', b=1) - #timage = 2.*timage - 1. 
- #forcedlatent = model.encode(timage).latent_dist.sample().cpu().detach().numpy().flatten() - #basic_new_fl = np.sqrt(len(forcedlatent) / np.sum(forcedlatent**2)) * forcedlatent - basic_new_fl = randomized_image_to_latent(voronoi_name) #img_to_latent(voronoi_name) - basic_new_fl = np.sqrt(len(basic_new_fl) / np.sum(basic_new_fl**2)) * basic_new_fl - #basic_new_fl = 0.8 * np.sqrt(len(basic_new_fl) / np.sum(basic_new_fl**2)) * basic_new_fl - if len(good) > 1: - print("Directly copying latent vars !!!") - #forcedlatents += [4.6 * basic_new_fl] - forcedlatents += [basic_new_fl] - else: - epsilon = 1.0 * (((a + .5 - len(good)) / (llambda - len(good) - 1)) ** 2) - forcedlatent = (1. - epsilon) * basic_new_fl.flatten() + epsilon * np.random.randn(4*64*64) - forcedlatent = np.sqrt(len(forcedlatent) / np.sum(forcedlatent**2)) * forcedlatent - forcedlatents += [forcedlatent] - #forcedlatents += [4.6 * forcedlatent] - else: - print(f"Voronoi in the latent space! {a} / {llambda}") - forcedlatent = np.zeros((4, 64, 64)) - #print(type(numpy_image)) - #print(numpy_image.shape) - #print(np.max(numpy_image)) - #print(np.min(numpy_image)) - #assert False - for i in range(len(indices)): - coefficients[i] = np.exp(np.random.randn()) - for i in range(64): - x = i / 63. - for j in range(64): - y = j / 63 - mindistances = 10000000000. - for u in range(len(indices)): - #print(a, i, x, j, y, u) - #print(indices[u][1]) - #print(indices[u][2]) - #print(f" {coefficients[u]}* np.linalg.norm({np.array((x, y))}-{np.array((indices[u][1], indices[u][2]))}") - distance = coefficients[u] * np.linalg.norm( np.array((x, y)) - np.array((indices[u][2], indices[u][1])) ) - if distance < mindistances: - mindistances = distance - uu = indices[u][0] - for k in range(4): - assert k < len(forcedlatent), k - assert i < len(forcedlatent[k]), i - assert j < len(forcedlatent[k][i]), j - assert uu < len(latent) - assert k < len(latent[uu]), k - assert i < len(latent[uu][k]), i - assert j < len(latent[uu][k][i]), j - forcedlatent[k][i][j] = float(latent[uu][k][i][j]) - #if a % 2 == 0: - # forcedlatent -= np.random.rand() * sauron - forcedlatent = forcedlatent.flatten() - basic_new_fl = np.sqrt(len(forcedlatent) / np.sum(forcedlatent**2)) * forcedlatent - if len(good) > 1 or len(forcedlatents) < len(good) + 1: - forcedlatents += [basic_new_fl] - else: - epsilon = ((0.5 * (a + .5 - len(good)) / (llambda - len(good) - 1)) ** 2) - forcedlatent = (1. 
- epsilon) * basic_new_fl.flatten() + epsilon * np.random.randn(4*64*64) - #forcedlatent = np.sqrt(len(forcedlatent) / np.sum(forcedlatent**2)) * forcedlatent - forcedlatents += [forcedlatent] - #for uu in range(len(latent)): - # print(f"--> latent[{uu}] sum of sq / variable = {np.sum(latent[uu].flatten()**2) / len(latent[uu].flatten())}") - os.environ["good"] = "[]" - os.environ["bad"] = "[]" - -pygame.quit() +############### DEPRECATED: see geneticsd.py import random +############### DEPRECATED: see geneticsd.py import os +############### DEPRECATED: see geneticsd.py import time +############### DEPRECATED: see geneticsd.py import torch +############### DEPRECATED: see geneticsd.py import numpy as np +############### DEPRECATED: see geneticsd.py import shutil +############### DEPRECATED: see geneticsd.py import PIL +############### DEPRECATED: see geneticsd.py from PIL import Image +############### DEPRECATED: see geneticsd.py from einops import rearrange, repeat +############### DEPRECATED: see geneticsd.py from torch import autocast +############### DEPRECATED: see geneticsd.py from diffusers import StableDiffusionPipeline +############### DEPRECATED: see geneticsd.py import webbrowser +############### DEPRECATED: see geneticsd.py from deep_translator import GoogleTranslator +############### DEPRECATED: see geneticsd.py from langdetect import detect +############### DEPRECATED: see geneticsd.py from joblib import Parallel, delayed +############### DEPRECATED: see geneticsd.py import torch +############### DEPRECATED: see geneticsd.py from PIL import Image +############### DEPRECATED: see geneticsd.py from RealESRGAN import RealESRGAN +############### DEPRECATED: see geneticsd.py +############### DEPRECATED: see geneticsd.py os.environ["PYTORCH_ENABLE_MPS_FALLBACK"] = "1" +############### DEPRECATED: see geneticsd.py model_id = "CompVis/stable-diffusion-v1-4" +############### DEPRECATED: see geneticsd.py #device = "cuda" +############### DEPRECATED: see geneticsd.py device = "mps" #torch.device("mps") +############### DEPRECATED: see geneticsd.py +############### DEPRECATED: see geneticsd.py white = (255, 255, 255) +############### DEPRECATED: see geneticsd.py green = (0, 255, 0) +############### DEPRECATED: see geneticsd.py darkgreen = (0, 128, 0) +############### DEPRECATED: see geneticsd.py red = (255, 0, 0) +############### DEPRECATED: see geneticsd.py blue = (0, 0, 128) +############### DEPRECATED: see geneticsd.py black = (0, 0, 0) +############### DEPRECATED: see geneticsd.py +############### DEPRECATED: see geneticsd.py os.environ["skl"] = "nn" +############### DEPRECATED: see geneticsd.py os.environ["epsilon"] = "0.005" +############### DEPRECATED: see geneticsd.py os.environ["decay"] = "0." 
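The latent-space branch of the removed loop above performs a Voronoi-style crossover: each offspring copies, cell by cell, the latent values of whichever selected parent owns that cell under randomly weighted distances to the click positions, and is then renormalized. A standalone sketch of that step, assuming (4, 64, 64) latents and seed points in [0, 1]^2 (voronoi_crossover is an illustrative name, not part of this patch):

import numpy as np

def voronoi_crossover(parent_latents, seed_points, rng=None):
    """Blend parent latents along a randomly weighted Voronoi partition.

    parent_latents: one (4, 64, 64) array per selected image.
    seed_points:    one (x, y) pair in [0, 1]^2 per parent (the click positions).
    Returns a single offspring latent, rescaled so its mean squared value is 1.
    """
    rng = np.random.default_rng() if rng is None else rng
    coeffs = np.exp(rng.standard_normal(len(parent_latents)))  # per-parent distance weights
    child = np.zeros((4, 64, 64))
    for i in range(64):
        for j in range(64):
            p = np.array([i / 63.0, j / 63.0])
            d = [coeffs[u] * np.linalg.norm(p - np.asarray(seed_points[u]))
                 for u in range(len(parent_latents))]
            child[:, i, j] = parent_latents[int(np.argmin(d))][:, i, j]
    flat = child.flatten()
    return (np.sqrt(flat.size / np.sum(flat ** 2)) * flat).reshape(4, 64, 64)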
+############### DEPRECATED: see geneticsd.py os.environ["ngoptim"] = "DiscreteLenglerOnePlusOne" +############### DEPRECATED: see geneticsd.py os.environ["forcedlatent"] = "" +############### DEPRECATED: see geneticsd.py latent_forcing = "" +############### DEPRECATED: see geneticsd.py #os.environ["enforcedlatent"] = "" +############### DEPRECATED: see geneticsd.py os.environ["good"] = "[]" +############### DEPRECATED: see geneticsd.py os.environ["bad"] = "[]" +############### DEPRECATED: see geneticsd.py num_iterations = 50 +############### DEPRECATED: see geneticsd.py gs = 7.5 +############### DEPRECATED: see geneticsd.py +############### DEPRECATED: see geneticsd.py +############### DEPRECATED: see geneticsd.py +############### DEPRECATED: see geneticsd.py import pyttsx3 +############### DEPRECATED: see geneticsd.py +############### DEPRECATED: see geneticsd.py noise = pyttsx3.init() +############### DEPRECATED: see geneticsd.py noise.setProperty("rate", 240) +############### DEPRECATED: see geneticsd.py noise.setProperty('voice', 'mb-us1') +############### DEPRECATED: see geneticsd.py +############### DEPRECATED: see geneticsd.py #voice = noise.getProperty('voices') +############### DEPRECATED: see geneticsd.py #for v in voice: +############### DEPRECATED: see geneticsd.py # if v.name == "Kyoko": +############### DEPRECATED: see geneticsd.py # noise.setProperty('voice', v.id) +############### DEPRECATED: see geneticsd.py +############### DEPRECATED: see geneticsd.py +############### DEPRECATED: see geneticsd.py all_selected = [] +############### DEPRECATED: see geneticsd.py all_selected_latent = [] +############### DEPRECATED: see geneticsd.py final_selection = [] +############### DEPRECATED: see geneticsd.py final_selection_latent = [] +############### DEPRECATED: see geneticsd.py forcedlatents = [] +############### DEPRECATED: see geneticsd.py forcedgs = [] +############### DEPRECATED: see geneticsd.py +############### DEPRECATED: see geneticsd.py +############### DEPRECATED: see geneticsd.py +############### DEPRECATED: see geneticsd.py pipe = StableDiffusionPipeline.from_pretrained(model_id, use_auth_token="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") +############### DEPRECATED: see geneticsd.py pipe = pipe.to(device) +############### DEPRECATED: see geneticsd.py +############### DEPRECATED: see geneticsd.py prompt = "a photo of an astronaut riding a horse on mars" +############### DEPRECATED: see geneticsd.py prompt = "a photo of a red panda with a hat playing table tennis" +############### DEPRECATED: see geneticsd.py prompt = "a photorealistic portrait of " + random.choice(["Mary Cury", "Scarlett Johansson", "Marilyn Monroe", "Poison Ivy", "Black Widow", "Medusa", "Batman", "Albert Einstein", "Louis XIV", "Tarzan"]) + random.choice([" with glasses", " with a hat", " with a cigarette", "with a scarf"]) +############### DEPRECATED: see geneticsd.py prompt = "a photorealistic portrait of " + random.choice(["Nelson Mandela", "Superman", "Superwoman", "Volodymyr Zelenskyy", "Tsai Ing-Wen", "Lzzy Hale", "Meg Myers"]) + random.choice([" with glasses", " with a hat", " with a cigarette", "with a scarf"]) +############### DEPRECATED: see geneticsd.py prompt = random.choice(["A woman with three eyes", "Meg Myers", "The rock band Ankor", "Miley Cyrus", "The man named Rahan", "A murder", "Rambo playing table tennis"]) +############### DEPRECATED: see geneticsd.py prompt = "Photo of a female Terminator." 
+############### DEPRECATED: see geneticsd.py prompt = random.choice([ +############### DEPRECATED: see geneticsd.py "Photo of Tarzan as a lawyer with a tie", +############### DEPRECATED: see geneticsd.py "Photo of Scarlett Johansson as a sumo-tori", +############### DEPRECATED: see geneticsd.py "Photo of the little mermaid as a young black girl", +############### DEPRECATED: see geneticsd.py "Photo of Schwarzy with tentacles", +############### DEPRECATED: see geneticsd.py "Photo of Meg Myers with an Egyptian dress", +############### DEPRECATED: see geneticsd.py "Photo of Schwarzy as a ballet dancer", +############### DEPRECATED: see geneticsd.py ]) +############### DEPRECATED: see geneticsd.py +############### DEPRECATED: see geneticsd.py +############### DEPRECATED: see geneticsd.py name = random.choice(["Mark Zuckerbeg", "Zendaya", "Yann LeCun", "Scarlett Johansson", "Superman", "Meg Myers"]) +############### DEPRECATED: see geneticsd.py name = "Zendaya" +############### DEPRECATED: see geneticsd.py prompt = f"Photo of {name} as a sumo-tori." +############### DEPRECATED: see geneticsd.py +############### DEPRECATED: see geneticsd.py prompt = "Full length portrait of Mark Zuckerberg as a Sumo-Tori." +############### DEPRECATED: see geneticsd.py prompt = "Full length portrait of Scarlett Johansson as a Sumo-Tori." +############### DEPRECATED: see geneticsd.py prompt = "A close up photographic portrait of a young woman with uniformly colored hair." +############### DEPRECATED: see geneticsd.py prompt = "Zombies raising and worshipping a flying human." +############### DEPRECATED: see geneticsd.py prompt = "Zombies trying to kill Meg Myers." +############### DEPRECATED: see geneticsd.py prompt = "Meg Myers with an Egyptian dress killing a vampire with a gun." +############### DEPRECATED: see geneticsd.py prompt = "Meg Myers grabbing a vampire by the scruff of the neck." +############### DEPRECATED: see geneticsd.py prompt = "Mark Zuckerberg chokes a vampire to death." +############### DEPRECATED: see geneticsd.py prompt = "Mark Zuckerberg riding an animal." +############### DEPRECATED: see geneticsd.py prompt = "A giant cute animal worshipped by zombies." +############### DEPRECATED: see geneticsd.py +############### DEPRECATED: see geneticsd.py +############### DEPRECATED: see geneticsd.py prompt = "Several faces." +############### DEPRECATED: see geneticsd.py +############### DEPRECATED: see geneticsd.py prompt = "An armoured Yann LeCun fighting tentacles in the jungle." +############### DEPRECATED: see geneticsd.py prompt = "Tentacles everywhere." +############### DEPRECATED: see geneticsd.py prompt = "A photo of a smiling Medusa." +############### DEPRECATED: see geneticsd.py prompt = "Medusa." +############### DEPRECATED: see geneticsd.py prompt = "Meg Myers in bloody armor fending off tentacles with a sword." +############### DEPRECATED: see geneticsd.py prompt = "A red-haired woman with red hair. Her head is tilted." +############### DEPRECATED: see geneticsd.py prompt = "A bloody heavy-metal zombie with a chainsaw." +############### DEPRECATED: see geneticsd.py prompt = "Tentacles attacking a bloody Meg Myers in Eyptian dress. Meg Myers has a chainsaw." +############### DEPRECATED: see geneticsd.py prompt = "Bizarre art." +############### DEPRECATED: see geneticsd.py +############### DEPRECATED: see geneticsd.py prompt = "Beautiful bizarre woman." +############### DEPRECATED: see geneticsd.py prompt = "Yann LeCun as the grim reaper: bizarre art." 
+############### DEPRECATED: see geneticsd.py prompt = "Un chat en sang et en armure joue de la batterie." +############### DEPRECATED: see geneticsd.py prompt = "Photo of a cyberpunk Mark Zuckerberg killing Cthulhu with a light saber." +############### DEPRECATED: see geneticsd.py prompt = "A ferocious cyborg bear." +############### DEPRECATED: see geneticsd.py prompt = "Photo of Mark Zuckerberg killing Cthulhu with a light saber." +############### DEPRECATED: see geneticsd.py prompt = "A bear with horns and blood and big teeth." +############### DEPRECATED: see geneticsd.py prompt = "A photo of a bear and Yoda, good friends." +############### DEPRECATED: see geneticsd.py prompt = "A photo of Yoda on the left, a blue octopus on the right, an explosion in the center." +############### DEPRECATED: see geneticsd.py prompt = "A bird is on a hippo. They fight a black and red octopus. Jungle in the background." +############### DEPRECATED: see geneticsd.py prompt = "A flying white owl above 4 colored pots with fire. The owl has a hat." +############### DEPRECATED: see geneticsd.py prompt = "A flying white owl above 4 colored pots with fire." +############### DEPRECATED: see geneticsd.py prompt = "Yann LeCun rides a dragon which spits fire on a cherry on a cake." +############### DEPRECATED: see geneticsd.py prompt = "An armored Mark Zuckerberg fighting off a monster with bloody tentacles in the jungle with a light saber." +############### DEPRECATED: see geneticsd.py prompt = "Cute woman, portrait, photo, red hair, green eyes, smiling." +############### DEPRECATED: see geneticsd.py prompt = "Photo of Tarzan as a lawyer with a tie and an octopus on his head." +############### DEPRECATED: see geneticsd.py prompt = "An armored bloody Yann Lecun has a lightsabar and fights a red tentacular monster." +############### DEPRECATED: see geneticsd.py prompt = "Photo of a giant armored insect attacking a building. The building is broken. There are flames." +############### DEPRECATED: see geneticsd.py prompt = "Photo of Meg Myers, on the left, in Egyptian dress, fights Cthulhu (on the right) with a light saber. They stare at each other." +############### DEPRECATED: see geneticsd.py prompt = "Photo of a cute red panda." +############### DEPRECATED: see geneticsd.py prompt = "Photo of a cute smiling white-haired woman with pink eyes." +############### DEPRECATED: see geneticsd.py prompt = "A muscular Jesus with and assault rifle, a cap and and a light saber." +############### DEPRECATED: see geneticsd.py prompt = "A portrait of a cute smiling woman." +############### DEPRECATED: see geneticsd.py prompt = "A woman with black skin, red hair, egyptian dress, yellow eyes." +############### DEPRECATED: see geneticsd.py prompt = "Photo of a red haired man with tilted head." +############### DEPRECATED: see geneticsd.py prompt = "A photo of Cleopatra with Egyptian Dress kissing Yoda." +############### DEPRECATED: see geneticsd.py prompt = "A photo of Yoda fighting Meg Myers with light sabers." +############### DEPRECATED: see geneticsd.py prompt = "A photo of Meg Myers, laughing, pulling Gandalf's hair." +############### DEPRECATED: see geneticsd.py prompt = "A photo of Meg Myers laughing and pulling Gandalf's hair. Gandalf is stooping." +############### DEPRECATED: see geneticsd.py prompt = "A star with flashy colors." +############### DEPRECATED: see geneticsd.py prompt = "Portrait of a green haired woman with blue eyes." 
+############### DEPRECATED: see geneticsd.py prompt = "Portrait of a female kung-fu master." +############### DEPRECATED: see geneticsd.py prompt = "In a dark cave, in the middle of computers, a geek meets the devil." +############### DEPRECATED: see geneticsd.py print(f"The prompt is {prompt}") +############### DEPRECATED: see geneticsd.py +############### DEPRECATED: see geneticsd.py +############### DEPRECATED: see geneticsd.py import pyfiglet +############### DEPRECATED: see geneticsd.py print(pyfiglet.figlet_format("Welcome in Genetic Stable Diffusion !")) +############### DEPRECATED: see geneticsd.py print(pyfiglet.figlet_format("First, let us choose the text :-)!")) +############### DEPRECATED: see geneticsd.py +############### DEPRECATED: see geneticsd.py +############### DEPRECATED: see geneticsd.py +############### DEPRECATED: see geneticsd.py print(f"Francais: Proposez un nouveau texte si vous ne voulez pas dessiner << {prompt} >>.\n") +############### DEPRECATED: see geneticsd.py noise.say("Hey!") +############### DEPRECATED: see geneticsd.py noise.runAndWait() +############### DEPRECATED: see geneticsd.py user_prompt = input(f"English: Enter a new prompt if you prefer something else than << {prompt} >>.\n") +############### DEPRECATED: see geneticsd.py if len(user_prompt) > 2: +############### DEPRECATED: see geneticsd.py prompt = user_prompt +############### DEPRECATED: see geneticsd.py +############### DEPRECATED: see geneticsd.py # On the fly translation. +############### DEPRECATED: see geneticsd.py language = detect(prompt) +############### DEPRECATED: see geneticsd.py english_prompt = GoogleTranslator(source='auto', target='en').translate(prompt) +############### DEPRECATED: see geneticsd.py +############### DEPRECATED: see geneticsd.py def to_native(stri): +############### DEPRECATED: see geneticsd.py return GoogleTranslator(source='en', target=language).translate(stri) +############### DEPRECATED: see geneticsd.py +############### DEPRECATED: see geneticsd.py def pretty_print(stri): +############### DEPRECATED: see geneticsd.py print(pyfiglet.figlet_format(to_native(stri))) +############### DEPRECATED: see geneticsd.py +############### DEPRECATED: see geneticsd.py print(f"{to_native('Working on')} {english_prompt}, a.k.a {prompt}.") +############### DEPRECATED: see geneticsd.py +############### DEPRECATED: see geneticsd.py def latent_to_image(latent): +############### DEPRECATED: see geneticsd.py os.environ["forcedlatent"] = str(list(latent.flatten())) #str(list(forcedlatents[k].flatten())) +############### DEPRECATED: see geneticsd.py with autocast("cuda"): +############### DEPRECATED: see geneticsd.py image = pipe(english_prompt, guidance_scale=gs, num_inference_steps=num_iterations)["sample"][0] +############### DEPRECATED: see geneticsd.py os.environ["forcedlatent"] = "[]" +############### DEPRECATED: see geneticsd.py return image +############### DEPRECATED: see geneticsd.py +############### DEPRECATED: see geneticsd.py +############### DEPRECATED: see geneticsd.py sr_device = torch.device('cpu') #device #('mps') #torch.device('cuda' if torch.cuda.is_available() else 'cpu') +############### DEPRECATED: see geneticsd.py esrmodel = RealESRGAN(sr_device, scale=4) +############### DEPRECATED: see geneticsd.py esrmodel.load_weights('weights/RealESRGAN_x4.pth', download=True) +############### DEPRECATED: see geneticsd.py esrmodel2 = RealESRGAN(sr_device, scale=2) +############### DEPRECATED: see geneticsd.py esrmodel2.load_weights('weights/RealESRGAN_x2.pth', 
download=True) +############### DEPRECATED: see geneticsd.py +############### DEPRECATED: see geneticsd.py def singleeg(path_to_image): +############### DEPRECATED: see geneticsd.py image = Image.open(path_to_image).convert('RGB') +############### DEPRECATED: see geneticsd.py sr_device = device #('mps') #torch.device('cuda' if torch.cuda.is_available() else 'cpu') +############### DEPRECATED: see geneticsd.py print(f"Type before SR = {type(image)}") +############### DEPRECATED: see geneticsd.py sr_image = esrmodel.predict(image) +############### DEPRECATED: see geneticsd.py print(f"Type after SR = {type(sr_image)}") +############### DEPRECATED: see geneticsd.py output_filename = path_to_image + ".SR.png" +############### DEPRECATED: see geneticsd.py sr_image.save(output_filename) +############### DEPRECATED: see geneticsd.py return output_filename +############### DEPRECATED: see geneticsd.py +############### DEPRECATED: see geneticsd.py def singleeg2(path_to_image): +############### DEPRECATED: see geneticsd.py time.sleep(0.5*np.random.rand()) +############### DEPRECATED: see geneticsd.py image = Image.open(path_to_image).convert('RGB') +############### DEPRECATED: see geneticsd.py sr_device = device #('mps') #torch.device('cuda' if torch.cuda.is_available() else 'cpu') +############### DEPRECATED: see geneticsd.py print(f"Type before SR = {type(image)}") +############### DEPRECATED: see geneticsd.py sr_image = esrmodel2.predict(image) +############### DEPRECATED: see geneticsd.py print(f"Type after SR = {type(sr_image)}") +############### DEPRECATED: see geneticsd.py output_filename = path_to_image + ".SR.png" +############### DEPRECATED: see geneticsd.py sr_image.save(output_filename) +############### DEPRECATED: see geneticsd.py return output_filename +############### DEPRECATED: see geneticsd.py +############### DEPRECATED: see geneticsd.py +############### DEPRECATED: see geneticsd.py def eg(list_of_files, last_list_of_files): +############### DEPRECATED: see geneticsd.py pretty_print("Should I convert images below to high resolution ?") +############### DEPRECATED: see geneticsd.py print(list_of_files) +############### DEPRECATED: see geneticsd.py noise.say("Go to the text window!") +############### DEPRECATED: see geneticsd.py noise.runAndWait() +############### DEPRECATED: see geneticsd.py answer = input(" [y]es / [n]o / [j]ust the last batch of {len(last_list_of_files)} images ?") +############### DEPRECATED: see geneticsd.py if "y" in answer or "Y" in answer or "j" in answer or "J" in answer: +############### DEPRECATED: see geneticsd.py if j in answer or "J" in answer: +############### DEPRECATED: see geneticsd.py list_of_files = last_list_of_files +############### DEPRECATED: see geneticsd.py #images = Parallel(n_jobs=12)(delayed(singleeg)(image) for image in list_of_files) +############### DEPRECATED: see geneticsd.py #print(to_native(f"Created the super-resolution files {images}")) +############### DEPRECATED: see geneticsd.py for path_to_image in list_of_files: +############### DEPRECATED: see geneticsd.py output_filename = singleeg(path_to_image) +############### DEPRECATED: see geneticsd.py print(to_native(f"Created the super-resolution file {output_filename}")) +############### DEPRECATED: see geneticsd.py +############### DEPRECATED: see geneticsd.py def stop_all(list_of_files, list_of_latent, last_list_of_files, last_list_of_latent): +############### DEPRECATED: see geneticsd.py print(to_native("Your selected images and the last generation:")) +############### 
DEPRECATED: see geneticsd.py print(list_of_files) +############### DEPRECATED: see geneticsd.py eg(list_of_files, last_list_of_files) +############### DEPRECATED: see geneticsd.py pretty_print("Should we create animations ?") +############### DEPRECATED: see geneticsd.py answer = input(" [y]es or [n]o or [j]ust the selection on the last panel ?") +############### DEPRECATED: see geneticsd.py if "y" in answer or "Y" in answer or "j" in answer or "J" in answer: +############### DEPRECATED: see geneticsd.py assert len(list_of_files) == len(list_of_latent) +############### DEPRECATED: see geneticsd.py if "j" in answer or "J" in answer: +############### DEPRECATED: see geneticsd.py list_of_latent = last_list_of_latent +############### DEPRECATED: see geneticsd.py pretty_print("Let us create animations!") +############### DEPRECATED: see geneticsd.py for c in sorted([0.05, 0.04,0.03,0.02,0.01]): +############### DEPRECATED: see geneticsd.py for idx in range(len(list_of_files)): +############### DEPRECATED: see geneticsd.py images = [] +############### DEPRECATED: see geneticsd.py l = list_of_latent[idx].reshape(1,4,64,64) +############### DEPRECATED: see geneticsd.py l = np.sqrt(len(l.flatten()) / np.sum(l**2)) * l +############### DEPRECATED: see geneticsd.py l1 = l + c * np.random.randn(len(l.flatten())).reshape(1,4,64,64) +############### DEPRECATED: see geneticsd.py l1 = np.sqrt(len(l1.flatten()) / np.sum(l1**2)) * l1 +############### DEPRECATED: see geneticsd.py l2 = l + c * np.random.randn(len(l.flatten())).reshape(1,4,64,64) +############### DEPRECATED: see geneticsd.py l2 = np.sqrt(len(l2.flatten()) / np.sum(l2**2)) * l2 +############### DEPRECATED: see geneticsd.py num_animation_steps = 13 +############### DEPRECATED: see geneticsd.py index = 0 +############### DEPRECATED: see geneticsd.py for u in np.linspace(0., 2*3.14159 * (1-1/30), 30): +############### DEPRECATED: see geneticsd.py cc = np.cos(u) +############### DEPRECATED: see geneticsd.py ss = np.sin(u*2) +############### DEPRECATED: see geneticsd.py index += 1 +############### DEPRECATED: see geneticsd.py image = latent_to_image(l + cc * (l1 - l) + ss * (l2 - l)) +############### DEPRECATED: see geneticsd.py image_name = f"imgA{index}.png" +############### DEPRECATED: see geneticsd.py image.save(image_name) +############### DEPRECATED: see geneticsd.py images += [image_name] +############### DEPRECATED: see geneticsd.py +############### DEPRECATED: see geneticsd.py # for u in np.linspace(0., 1., num_animation_steps): +############### DEPRECATED: see geneticsd.py # index += 1 +############### DEPRECATED: see geneticsd.py # image = latent_to_image(u*l1 + (1-u)*l) +############### DEPRECATED: see geneticsd.py # image_name = f"imgA{index}.png" +############### DEPRECATED: see geneticsd.py # image.save(image_name) +############### DEPRECATED: see geneticsd.py # images += [image_name] +############### DEPRECATED: see geneticsd.py # for u in np.linspace(0., 1., num_animation_steps): +############### DEPRECATED: see geneticsd.py # index += 1 +############### DEPRECATED: see geneticsd.py # image = latent_to_image(u*l2 + (1-u)*l1) +############### DEPRECATED: see geneticsd.py # image_name = f"imgB{index}.png" +############### DEPRECATED: see geneticsd.py # image.save(image_name) +############### DEPRECATED: see geneticsd.py # images += [image_name] +############### DEPRECATED: see geneticsd.py # for u in np.linspace(0., 1.,num_animation_steps): +############### DEPRECATED: see geneticsd.py # index += 1 +############### DEPRECATED: see 
geneticsd.py # image = latent_to_image(u*l + (1-u)*l2) +############### DEPRECATED: see geneticsd.py # image_name = f"imgC{index}.png" +############### DEPRECATED: see geneticsd.py # image.save(image_name) +############### DEPRECATED: see geneticsd.py # images += [image_name] +############### DEPRECATED: see geneticsd.py print(to_native(f"Base images created for perturbation={c} and file {list_of_files[idx]}")) +############### DEPRECATED: see geneticsd.py #images = Parallel(n_jobs=8)(delayed(process)(i) for i in range(10)) +############### DEPRECATED: see geneticsd.py images = Parallel(n_jobs=10)(delayed(singleeg2)(image) for image in images) +############### DEPRECATED: see geneticsd.py +############### DEPRECATED: see geneticsd.py frames = [Image.open(image) for image in images] +############### DEPRECATED: see geneticsd.py frame_one = frames[0] +############### DEPRECATED: see geneticsd.py gif_name = list_of_files[idx] + "_" + str(c) + ".gif" +############### DEPRECATED: see geneticsd.py frame_one.save(gif_name, format="GIF", append_images=frames, +############### DEPRECATED: see geneticsd.py save_all=True, duration=100, loop=0) +############### DEPRECATED: see geneticsd.py webbrowser.open(os.environ["PWD"] + "/" + gif_name) +############### DEPRECATED: see geneticsd.py +############### DEPRECATED: see geneticsd.py pretty_print("Should we create a meme ?") +############### DEPRECATED: see geneticsd.py answer = input(" [y]es or [n]o ?") +############### DEPRECATED: see geneticsd.py if "y" in answer or "Y" in answer: +############### DEPRECATED: see geneticsd.py url = 'https://imgflip.com/memegenerator' +############### DEPRECATED: see geneticsd.py webbrowser.open(url) +############### DEPRECATED: see geneticsd.py pretty_print("Good bye!") +############### DEPRECATED: see geneticsd.py exit() +############### DEPRECATED: see geneticsd.py +############### DEPRECATED: see geneticsd.py +############### DEPRECATED: see geneticsd.py import os +############### DEPRECATED: see geneticsd.py import pygame +############### DEPRECATED: see geneticsd.py from os import listdir +############### DEPRECATED: see geneticsd.py from os.path import isfile, join +############### DEPRECATED: see geneticsd.py +############### DEPRECATED: see geneticsd.py sentinel = str(random.randint(0,100000)) + "XX" + str(random.randint(0,100000)) +############### DEPRECATED: see geneticsd.py +############### DEPRECATED: see geneticsd.py all_files = [] +############### DEPRECATED: see geneticsd.py +############### DEPRECATED: see geneticsd.py llambda = 15 +############### DEPRECATED: see geneticsd.py +############### DEPRECATED: see geneticsd.py assert llambda < 16, "lambda < 16 for convenience in pygame." +############### DEPRECATED: see geneticsd.py +############### DEPRECATED: see geneticsd.py bad = [] +############### DEPRECATED: see geneticsd.py five_best = [] +############### DEPRECATED: see geneticsd.py latent = [] +############### DEPRECATED: see geneticsd.py images = [] +############### DEPRECATED: see geneticsd.py onlyfiles = [] +############### DEPRECATED: see geneticsd.py +############### DEPRECATED: see geneticsd.py pretty_print("Now let us choose (if you want) an image as a start.") +############### DEPRECATED: see geneticsd.py image_name = input(to_native("Name of image for starting ? (enter if no start image)")) +############### DEPRECATED: see geneticsd.py +############### DEPRECATED: see geneticsd.py # activate the pygame library . 
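The animation step in stop_all() above walks a small closed curve in latent space around each selected latent: two random perturbations of strength c give the cos/sin directions, one frame is rendered per point, and the frames are assembled into a looping GIF. A simplified sketch of the latent path and the GIF assembly, assuming the (1, 4, 64, 64) latent shape and the latent_to_image helper defined above (animation_latents and save_gif are illustrative names):

import numpy as np

def animation_latents(base_latent, c=0.03, n_frames=30, rng=None):
    """Yield latents along a closed loop around base_latent, as in stop_all() above."""
    rng = np.random.default_rng() if rng is None else rng

    def unit_rms(x):
        return np.sqrt(x.size / np.sum(x ** 2)) * x

    l = unit_rms(np.asarray(base_latent, dtype=np.float64))
    l1 = unit_rms(l + c * rng.standard_normal(l.shape))   # first perturbed anchor
    l2 = unit_rms(l + c * rng.standard_normal(l.shape))   # second perturbed anchor
    for u in np.linspace(0.0, 2 * np.pi * (1 - 1 / n_frames), n_frames):
        yield l + np.cos(u) * (l1 - l) + np.sin(2 * u) * (l2 - l)

def save_gif(frames, path, duration_ms=100):
    """frames: PIL images, e.g. [latent_to_image(l) for l in animation_latents(...)]."""
    frames[0].save(path, format="GIF", append_images=frames[1:],
                   save_all=True, duration=duration_ms, loop=0)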
+############### DEPRECATED: see geneticsd.py pygame.init() +############### DEPRECATED: see geneticsd.py X = 2000 # > 1500 = buttons +############### DEPRECATED: see geneticsd.py Y = 900 +############### DEPRECATED: see geneticsd.py scrn = pygame.display.set_mode((1700, Y + 100)) +############### DEPRECATED: see geneticsd.py font = pygame.font.Font('freesansbold.ttf', 22) +############### DEPRECATED: see geneticsd.py bigfont = pygame.font.Font('freesansbold.ttf', 44) +############### DEPRECATED: see geneticsd.py +############### DEPRECATED: see geneticsd.py def load_img(path): +############### DEPRECATED: see geneticsd.py image = Image.open(path).convert("RGB") +############### DEPRECATED: see geneticsd.py w, h = image.size +############### DEPRECATED: see geneticsd.py print(to_native(f"loaded input image of size ({w}, {h}) from {path}")) +############### DEPRECATED: see geneticsd.py w, h = map(lambda x: x - x % 32, (w, h)) # resize to integer multiple of 32 +############### DEPRECATED: see geneticsd.py image = image.resize((512, 512), resample=PIL.Image.LANCZOS) +############### DEPRECATED: see geneticsd.py #image = image.resize((w, h), resample=PIL.Image.LANCZOS) +############### DEPRECATED: see geneticsd.py image = np.array(image).astype(np.float32) / 255.0 +############### DEPRECATED: see geneticsd.py image = image[None].transpose(0, 3, 1, 2) +############### DEPRECATED: see geneticsd.py image = torch.from_numpy(image) +############### DEPRECATED: see geneticsd.py return 2.*image - 1. +############### DEPRECATED: see geneticsd.py +############### DEPRECATED: see geneticsd.py model = pipe.vae +############### DEPRECATED: see geneticsd.py +############### DEPRECATED: see geneticsd.py def img_to_latent(path): +############### DEPRECATED: see geneticsd.py #init_image = 1.8 * load_img(path).to(device) +############### DEPRECATED: see geneticsd.py init_image = load_img(path).to(device) +############### DEPRECATED: see geneticsd.py init_image = repeat(init_image, '1 ... -> b ...', b=1) +############### DEPRECATED: see geneticsd.py forced_latent = model.encode(init_image.to(device)).latent_dist.sample() +############### DEPRECATED: see geneticsd.py new_fl = forced_latent.cpu().detach().numpy().flatten() +############### DEPRECATED: see geneticsd.py new_fl = np.sqrt(len(new_fl)) * new_fl / np.sqrt(np.sum(new_fl ** 2)) +############### DEPRECATED: see geneticsd.py return new_fl +############### DEPRECATED: see geneticsd.py +############### DEPRECATED: see geneticsd.py def randomized_image_to_latent(image_name, scale=None, epsilon=None, c=None, f=None): +############### DEPRECATED: see geneticsd.py base_init_image = load_img(image_name).to(device) +############### DEPRECATED: see geneticsd.py new_base_init_image = base_init_image +############### DEPRECATED: see geneticsd.py c = np.exp(np.random.randn()) if c is None else c +############### DEPRECATED: see geneticsd.py f = np.exp(-3. * np.random.rand()) if f is None else f +############### DEPRECATED: see geneticsd.py init_image_shape = base_init_image.cpu().numpy().shape +############### DEPRECATED: see geneticsd.py init_image = c * new_base_init_image +############### DEPRECATED: see geneticsd.py init_image = repeat(init_image, '1 ... -> b ...', b=1) +############### DEPRECATED: see geneticsd.py forced_latent = 1. 
* model.encode(init_image.to(device)).latent_dist.sample() +############### DEPRECATED: see geneticsd.py new_fl = forced_latent.cpu().detach().numpy().flatten() +############### DEPRECATED: see geneticsd.py basic_new_fl = new_fl #np.sqrt(len(new_fl) / sum(new_fl ** 2)) * new_fl +############### DEPRECATED: see geneticsd.py basic_new_fl = f * np.sqrt(len(new_fl) / np.sum(basic_new_fl**2)) * basic_new_fl +############### DEPRECATED: see geneticsd.py epsilon = 0.1 * np.exp(-3 * np.random.rand()) if epsilon is None else epsilon +############### DEPRECATED: see geneticsd.py new_fl = (1. - epsilon) * basic_new_fl + epsilon * np.random.randn(1*4*64*64) +############### DEPRECATED: see geneticsd.py scale = 2.8 + 3.6 * np.random.rand() if scale is None else scale +############### DEPRECATED: see geneticsd.py new_fl = scale * np.sqrt(len(new_fl)) * new_fl / np.sqrt(np.sum(new_fl ** 2)) +############### DEPRECATED: see geneticsd.py #image = latent_to_image(np.asarray(new_fl)) #eval(os.environ["forcedlatent"]))) +############### DEPRECATED: see geneticsd.py #image.save(f"rebuild_{f}_{scale}_{epsilon}_{c}.png") +############### DEPRECATED: see geneticsd.py return new_fl +############### DEPRECATED: see geneticsd.py +############### DEPRECATED: see geneticsd.py if len(image_name) > 0: +############### DEPRECATED: see geneticsd.py pretty_print("Importing an image !") +############### DEPRECATED: see geneticsd.py try: +############### DEPRECATED: see geneticsd.py init_image = load_img(image_name).to(device) +############### DEPRECATED: see geneticsd.py except: +############### DEPRECATED: see geneticsd.py pretty_print("Try again!") +############### DEPRECATED: see geneticsd.py pretty_print("Loading failed!!") +############### DEPRECATED: see geneticsd.py image_name = input(to_native("Name of image for starting ? 
(enter if no start image)")) +############### DEPRECATED: see geneticsd.py +############### DEPRECATED: see geneticsd.py base_init_image = load_img(image_name).to(device) +############### DEPRECATED: see geneticsd.py noise.say("Image loaded") +############### DEPRECATED: see geneticsd.py noise.runAndWait() +############### DEPRECATED: see geneticsd.py print(base_init_image.shape) +############### DEPRECATED: see geneticsd.py print(np.max(base_init_image.cpu().detach().numpy().flatten())) +############### DEPRECATED: see geneticsd.py print(np.min(base_init_image.cpu().detach().numpy().flatten())) +############### DEPRECATED: see geneticsd.py +############### DEPRECATED: see geneticsd.py forcedlatents = [] +############### DEPRECATED: see geneticsd.py divider = 1.5 +############### DEPRECATED: see geneticsd.py latent_found = False +############### DEPRECATED: see geneticsd.py try: +############### DEPRECATED: see geneticsd.py latent_file = image_name + ".latent.txt" +############### DEPRECATED: see geneticsd.py print(to_native(f"Trying to load latent variables in {latent_file}.")) +############### DEPRECATED: see geneticsd.py f = open(latent_file, "r") +############### DEPRECATED: see geneticsd.py print(to_native("File opened.")) +############### DEPRECATED: see geneticsd.py latent_str = f.read() +############### DEPRECATED: see geneticsd.py print("Latent string read.") +############### DEPRECATED: see geneticsd.py latent_found = True +############### DEPRECATED: see geneticsd.py except: +############### DEPRECATED: see geneticsd.py print(to_native("No latent file: guessing.")) +############### DEPRECATED: see geneticsd.py for i in range(llambda): +############### DEPRECATED: see geneticsd.py new_base_init_image = base_init_image +############### DEPRECATED: see geneticsd.py if not latent_found: # In case of latent vars we need less exploration. +############### DEPRECATED: see geneticsd.py if (i % 7) == 1: +############### DEPRECATED: see geneticsd.py new_base_init_image[0,0,:,:] /= divider +############### DEPRECATED: see geneticsd.py if (i % 7) == 2: +############### DEPRECATED: see geneticsd.py new_base_init_image[0,1,:,:] /= divider +############### DEPRECATED: see geneticsd.py if (i % 7) == 3: +############### DEPRECATED: see geneticsd.py new_base_init_image[0,2,:,:] /= divider +############### DEPRECATED: see geneticsd.py if (i % 7) == 4: +############### DEPRECATED: see geneticsd.py new_base_init_image[0,0,:,:] /= divider +############### DEPRECATED: see geneticsd.py new_base_init_image[0,1,:,:] /= divider +############### DEPRECATED: see geneticsd.py if (i % 7) == 5: +############### DEPRECATED: see geneticsd.py new_base_init_image[0,1,:,:] /= divider +############### DEPRECATED: see geneticsd.py new_base_init_image[0,2,:,:] /= divider +############### DEPRECATED: see geneticsd.py if (i % 7) == 6: +############### DEPRECATED: see geneticsd.py new_base_init_image[0,0,:,:] /= divider +############### DEPRECATED: see geneticsd.py new_base_init_image[0,2,:,:] /= divider +############### DEPRECATED: see geneticsd.py +############### DEPRECATED: see geneticsd.py c = np.exp(np.random.randn() - 5) +############### DEPRECATED: see geneticsd.py f = np.exp(-3. 
* np.random.rand()) +############### DEPRECATED: see geneticsd.py init_image_shape = base_init_image.cpu().numpy().shape +############### DEPRECATED: see geneticsd.py if i > 0 and not latent_found: +############### DEPRECATED: see geneticsd.py init_image = new_base_init_image + torch.from_numpy(c * np.random.randn(np.prod(init_image_shape))).reshape(init_image_shape).float().to(device) +############### DEPRECATED: see geneticsd.py else: +############### DEPRECATED: see geneticsd.py init_image = new_base_init_image +############### DEPRECATED: see geneticsd.py init_image = repeat(init_image, '1 ... -> b ...', b=1) +############### DEPRECATED: see geneticsd.py if latent_found: +############### DEPRECATED: see geneticsd.py new_fl = np.asarray(eval(latent_str)) +############### DEPRECATED: see geneticsd.py assert len(new_fl) > 1 +############### DEPRECATED: see geneticsd.py else: +############### DEPRECATED: see geneticsd.py forced_latent = 1. * model.encode(init_image.to(device)).latent_dist.sample() +############### DEPRECATED: see geneticsd.py new_fl = forced_latent.cpu().detach().numpy().flatten() +############### DEPRECATED: see geneticsd.py basic_new_fl = new_fl #np.sqrt(len(new_fl) / sum(new_fl ** 2)) * new_fl +############### DEPRECATED: see geneticsd.py #new_fl = forced_latent + (1. / 1.1**(llambda-i)) * torch.from_numpy(np.random.randn(1*4*64*64).reshape(1,4,64,64)).float().to(device) +############### DEPRECATED: see geneticsd.py #forcedlatents += [new_fl.cpu().detach().numpy()] +############### DEPRECATED: see geneticsd.py if i > 0: +############### DEPRECATED: see geneticsd.py #epsilon = 0.3 / 1.1**i +############### DEPRECATED: see geneticsd.py basic_new_fl = f * np.sqrt(len(new_fl) / np.sum(basic_new_fl**2)) * basic_new_fl +############### DEPRECATED: see geneticsd.py epsilon = .7 * ((i-1)/(llambda-1)) #1.0 / 2**(2 + (llambda - i) / 6) +############### DEPRECATED: see geneticsd.py print(f"{i} -- {i % 7} {c} {f} {epsilon}") +############### DEPRECATED: see geneticsd.py # 1 -- 1 0.050020045300292804 0.0790648688521246 0.0 +############### DEPRECATED: see geneticsd.py new_fl = (1. - epsilon) * basic_new_fl + epsilon * np.random.randn(1*4*64*64) +############### DEPRECATED: see geneticsd.py else: +############### DEPRECATED: see geneticsd.py new_fl = basic_new_fl +############### DEPRECATED: see geneticsd.py new_fl = 6. * np.sqrt(len(new_fl)) * new_fl / np.sqrt(np.sum(new_fl ** 2)) +############### DEPRECATED: see geneticsd.py forcedlatents += [new_fl] #np.clip(new_fl, -3., 3.)] #np.sqrt(len(new_fl) / sum(new_fl ** 2)) * new_fl] +############### DEPRECATED: see geneticsd.py forcedgs += [7.5] #np.random.choice([7.5, 15.0, 30.0, 60.0])] TODO +############### DEPRECATED: see geneticsd.py #forcedlatents += [np.sqrt(len(new_fl) / sum(new_fl ** 2)) * new_fl] +############### DEPRECATED: see geneticsd.py #print(f"{i} --> {forcedlatents[i][:10]}") +############### DEPRECATED: see geneticsd.py +############### DEPRECATED: see geneticsd.py # We start the big time consuming loop! 
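The loop above seeds the first generation from a single imported image: the VAE encoding (or the stored .latent.txt content, when present) is renormalized and then mixed with progressively more Gaussian noise across the llambda slots, so early members stay close to the source image and later members explore more. A simplified standalone sketch of that schedule, leaving out the per-channel divider tricks and the random c/f factors (seed_population is an illustrative name; the flat 4*64*64 encoding is assumed to come from pipe.vae.encode as above):

import numpy as np

def seed_population(encoded_latent, llambda=15, scale=6.0, rng=None):
    """Turn one encoded latent (flat, 4*64*64 values) into llambda starting latents."""
    rng = np.random.default_rng() if rng is None else rng

    def unit_rms(x):
        return np.sqrt(x.size / np.sum(x ** 2)) * x

    base = unit_rms(np.asarray(encoded_latent, dtype=np.float64))
    population = [scale * base]                      # slot 0: the source image itself
    for i in range(1, llambda):
        eps = 0.7 * (i - 1) / (llambda - 1)          # mutation strength grows with i
        child = (1.0 - eps) * base + eps * rng.standard_normal(base.size)
        population.append(scale * unit_rms(child))
    return population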
+############### DEPRECATED: see geneticsd.py for iteration in range(30): +############### DEPRECATED: see geneticsd.py latent = [latent[f] for f in five_best] +############### DEPRECATED: see geneticsd.py images = [images[f] for f in five_best] +############### DEPRECATED: see geneticsd.py onlyfiles = [onlyfiles[f] for f in five_best] +############### DEPRECATED: see geneticsd.py early_stop = [] +############### DEPRECATED: see geneticsd.py noise.say("WAIT!") +############### DEPRECATED: see geneticsd.py noise.runAndWait() +############### DEPRECATED: see geneticsd.py final_selection = [] +############### DEPRECATED: see geneticsd.py final_selection_latent = [] +############### DEPRECATED: see geneticsd.py for k in range(llambda): +############### DEPRECATED: see geneticsd.py if len(early_stop) > 0: +############### DEPRECATED: see geneticsd.py break +############### DEPRECATED: see geneticsd.py max_created_index = k +############### DEPRECATED: see geneticsd.py if len(forcedlatents) > 0 and k < len(forcedlatents): +############### DEPRECATED: see geneticsd.py #os.environ["forcedlatent"] = str(list(forcedlatents[k].flatten())) +############### DEPRECATED: see geneticsd.py latent_forcing = str(list(forcedlatents[k].flatten())) +############### DEPRECATED: see geneticsd.py print(f"We play with {latent_forcing[:20]}") +############### DEPRECATED: see geneticsd.py if k < len(five_best): +############### DEPRECATED: see geneticsd.py imp = pygame.transform.scale(pygame.image.load(onlyfiles[k]).convert(), (300, 300)) +############### DEPRECATED: see geneticsd.py # Using blit to copy content from one surface to other +############### DEPRECATED: see geneticsd.py scrn.blit(imp, (300 * (k // 3), 300 * (k % 3))) +############### DEPRECATED: see geneticsd.py pygame.display.flip() +############### DEPRECATED: see geneticsd.py continue +############### DEPRECATED: see geneticsd.py pygame.draw.rect(scrn, black, pygame.Rect(0, Y, 1700, Y+100)) +############### DEPRECATED: see geneticsd.py pygame.draw.rect(scrn, black, pygame.Rect(1500, 0, 2000, Y+100)) +############### DEPRECATED: see geneticsd.py text0 = bigfont.render(to_native(f'Please wait !!! {k} / {llambda}'), True, green, blue) +############### DEPRECATED: see geneticsd.py scrn.blit(text0, ((X*3/4)/2 - X/32, Y/2-Y/4)) +############### DEPRECATED: see geneticsd.py text0 = font.render(to_native(f'Or, for an early stopping,'), True, green, blue) +############### DEPRECATED: see geneticsd.py scrn.blit(text0, ((X*3/4)/3 - X/32, Y/2-Y/8)) +############### DEPRECATED: see geneticsd.py text0 = font.render(to_native(f'click and WAIT a bit'), True, green, blue) +############### DEPRECATED: see geneticsd.py scrn.blit(text0, ((X*3/4)/3 - X/32, Y/2)) +############### DEPRECATED: see geneticsd.py text0 = font.render(to_native(f'... ... ... '), True, green, blue) +############### DEPRECATED: see geneticsd.py scrn.blit(text0, ((X*3/4)/2 - X/32, Y/2+Y/8)) +############### DEPRECATED: see geneticsd.py +############### DEPRECATED: see geneticsd.py # Button for early stopping +############### DEPRECATED: see geneticsd.py text2 = font.render(to_native(f'Total: {len(all_selected)} chosen images! 
'), True, green, blue) +############### DEPRECATED: see geneticsd.py text2 = pygame.transform.rotate(text2, 90) +############### DEPRECATED: see geneticsd.py scrn.blit(text2, (X*3/4+X/16 - X/32, Y/3)) +############### DEPRECATED: see geneticsd.py text2 = font.render(to_native('Click for stopping,'), True, green, blue) +############### DEPRECATED: see geneticsd.py text2 = pygame.transform.rotate(text2, 90) +############### DEPRECATED: see geneticsd.py scrn.blit(text2, (X*3/4+X/16+X/64 - X/32, Y/3)) +############### DEPRECATED: see geneticsd.py text2 = font.render(to_native('and get the effects.'), True, green, blue) +############### DEPRECATED: see geneticsd.py text2 = pygame.transform.rotate(text2, 90) +############### DEPRECATED: see geneticsd.py scrn.blit(text2, (X*3/4+X/16+X/32 - X/32, Y/3)) +############### DEPRECATED: see geneticsd.py +############### DEPRECATED: see geneticsd.py pygame.display.flip() +############### DEPRECATED: see geneticsd.py os.environ["earlystop"] = "False" if k > len(five_best) else "True" +############### DEPRECATED: see geneticsd.py os.environ["epsilon"] = str(0. if k == len(five_best) else (k - len(five_best)) / llambda) +############### DEPRECATED: see geneticsd.py os.environ["budget"] = str(np.random.randint(400) if k > len(five_best) else 2) +############### DEPRECATED: see geneticsd.py os.environ["skl"] = {0: "nn", 1: "tree", 2: "logit"}[k % 3] +############### DEPRECATED: see geneticsd.py #enforcedlatent = os.environ.get("enforcedlatent", "") +############### DEPRECATED: see geneticsd.py #if len(enforcedlatent) > 2: +############### DEPRECATED: see geneticsd.py # os.environ["forcedlatent"] = enforcedlatent +############### DEPRECATED: see geneticsd.py # os.environ["enforcedlatent"] = "" +############### DEPRECATED: see geneticsd.py #with autocast("cuda"): +############### DEPRECATED: see geneticsd.py # image = pipe(english_prompt, guidance_scale=gs, num_inference_steps=num_iterations)["sample"][0] +############### DEPRECATED: see geneticsd.py previous_gs = gs +############### DEPRECATED: see geneticsd.py if k < len(forcedgs): +############### DEPRECATED: see geneticsd.py gs = forcedgs[k] +############### DEPRECATED: see geneticsd.py image = latent_to_image(np.asarray(latent_forcing)) #eval(os.environ["forcedlatent"]))) +############### DEPRECATED: see geneticsd.py gs = previous_gs +############### DEPRECATED: see geneticsd.py +############### DEPRECATED: see geneticsd.py images += [image] +############### DEPRECATED: see geneticsd.py filename = f"SD_{prompt.replace(' ','_')}_image_{sentinel}_{iteration:05d}_{k:05d}.png" +############### DEPRECATED: see geneticsd.py image.save(filename) +############### DEPRECATED: see geneticsd.py onlyfiles += [filename] +############### DEPRECATED: see geneticsd.py imp = pygame.transform.scale(pygame.image.load(onlyfiles[-1]).convert(), (300, 300)) +############### DEPRECATED: see geneticsd.py # Using blit to copy content from one surface to other +############### DEPRECATED: see geneticsd.py scrn.blit(imp, (300 * (k // 3), 300 * (k % 3))) +############### DEPRECATED: see geneticsd.py pygame.display.flip() +############### DEPRECATED: see geneticsd.py #noise.say("Dong") +############### DEPRECATED: see geneticsd.py #noise.runAndWait() +############### DEPRECATED: see geneticsd.py print('\a') +############### DEPRECATED: see geneticsd.py str_latent = eval((os.environ["latent_sd"])) +############### DEPRECATED: see geneticsd.py array_latent = eval(f"np.array(str_latent).reshape(4, 64, 64)") +############### DEPRECATED: 
see geneticsd.py print(f"Debug info: array_latent sumsq/var {sum(array_latent.flatten() ** 2) / len(array_latent.flatten())}") +############### DEPRECATED: see geneticsd.py latent += [array_latent] +############### DEPRECATED: see geneticsd.py with open(filename + ".latent.txt", 'w') as f: +############### DEPRECATED: see geneticsd.py f.write(f"{str_latent}") +############### DEPRECATED: see geneticsd.py # In case of early stopping. +############### DEPRECATED: see geneticsd.py first_event = True +############### DEPRECATED: see geneticsd.py for i in pygame.event.get(): +############### DEPRECATED: see geneticsd.py if i.type == pygame.MOUSEBUTTONUP: +############### DEPRECATED: see geneticsd.py if first_event: +############### DEPRECATED: see geneticsd.py noise.say("Ok I stop") +############### DEPRECATED: see geneticsd.py noise.runAndWait() +############### DEPRECATED: see geneticsd.py first_event = False +############### DEPRECATED: see geneticsd.py pos = pygame.mouse.get_pos() +############### DEPRECATED: see geneticsd.py index = 3 * (pos[0] // 300) + (pos[1] // 300) +############### DEPRECATED: see geneticsd.py if pos[0] > X and pos[1] > Y /3 and pos[1] < 2*Y/3: +############### DEPRECATED: see geneticsd.py stop_all(all_selected, all_selected_latent, final_selection, final_selection_latent) +############### DEPRECATED: see geneticsd.py exit() +############### DEPRECATED: see geneticsd.py if index <= k: +############### DEPRECATED: see geneticsd.py pretty_print(("You clicked for requesting an early stopping.")) +############### DEPRECATED: see geneticsd.py early_stop = [pos] +############### DEPRECATED: see geneticsd.py break +############### DEPRECATED: see geneticsd.py early_stop = [(1,1)] +############### DEPRECATED: see geneticsd.py satus = False +############### DEPRECATED: see geneticsd.py forcedgs = [] +############### DEPRECATED: see geneticsd.py # Stop the forcing from disk! +############### DEPRECATED: see geneticsd.py #os.environ["enforcedlatent"] = "" +############### DEPRECATED: see geneticsd.py # importing required library +############### DEPRECATED: see geneticsd.py +############### DEPRECATED: see geneticsd.py #mypath = "./" +############### DEPRECATED: see geneticsd.py #onlyfiles = [f for f in listdir(mypath) if isfile(join(mypath, f))] +############### DEPRECATED: see geneticsd.py #onlyfiles = [str(f) for f in onlyfiles if "SD_" in str(f) and ".png" in str(f) and str(f) not in all_files and sentinel in str(f)] +############### DEPRECATED: see geneticsd.py #print() +############### DEPRECATED: see geneticsd.py +############### DEPRECATED: see geneticsd.py # create the display surface object +############### DEPRECATED: see geneticsd.py # of specific dimension..e(X, Y). +############### DEPRECATED: see geneticsd.py noise.say("Ok I'm ready! 
Choose") +############### DEPRECATED: see geneticsd.py noise.runAndWait() +############### DEPRECATED: see geneticsd.py pretty_print("Please choose your images.") +############### DEPRECATED: see geneticsd.py text0 = bigfont.render(to_native(f'Choose your favorite images !!!========='), True, green, blue) +############### DEPRECATED: see geneticsd.py scrn.blit(text0, ((X*3/4)/2 - X/32, Y/2-Y/4)) +############### DEPRECATED: see geneticsd.py text0 = font.render(to_native(f'=================================='), True, green, blue) +############### DEPRECATED: see geneticsd.py scrn.blit(text0, ((X*3/4)/3 - X/32, Y/2-Y/8)) +############### DEPRECATED: see geneticsd.py text0 = font.render(to_native(f'=================================='), True, green, blue) +############### DEPRECATED: see geneticsd.py scrn.blit(text0, ((X*3/4)/3 - X/32, Y/2)) +############### DEPRECATED: see geneticsd.py # Add rectangles +############### DEPRECATED: see geneticsd.py pygame.draw.rect(scrn, red, pygame.Rect(X*3/4, 0, X*3/4+X/16+X/32, Y/3), 2) +############### DEPRECATED: see geneticsd.py pygame.draw.rect(scrn, red, pygame.Rect(X*3/4, Y/3, X*3/4+X/16+X/32, 2*Y/3), 2) +############### DEPRECATED: see geneticsd.py pygame.draw.rect(scrn, red, pygame.Rect(X*3/4, 2*Y/3, X*3/4+X/16+X/32, Y), 2) +############### DEPRECATED: see geneticsd.py pygame.draw.rect(scrn, red, pygame.Rect(0, Y, X/2, Y+100), 2) +############### DEPRECATED: see geneticsd.py +############### DEPRECATED: see geneticsd.py # Button for loading a starting point +############### DEPRECATED: see geneticsd.py text1 = font.render('Manually edit an image.', True, green, blue) +############### DEPRECATED: see geneticsd.py text1 = pygame.transform.rotate(text1, 90) +############### DEPRECATED: see geneticsd.py #scrn.blit(text1, (X*3/4+X/16 - X/32, 0)) +############### DEPRECATED: see geneticsd.py #text1 = font.render('& latent ', True, green, blue) +############### DEPRECATED: see geneticsd.py #text1 = pygame.transform.rotate(text1, 90) +############### DEPRECATED: see geneticsd.py #scrn.blit(text1, (X*3/4+X/16+X/32 - X/32, 0)) +############### DEPRECATED: see geneticsd.py +############### DEPRECATED: see geneticsd.py # Button for creating a meme +############### DEPRECATED: see geneticsd.py text2 = font.render(to_native('Click ,'), True, green, blue) +############### DEPRECATED: see geneticsd.py text2 = pygame.transform.rotate(text2, 90) +############### DEPRECATED: see geneticsd.py scrn.blit(text2, (X*3/4+X/16 - X/32, Y/3+10)) +############### DEPRECATED: see geneticsd.py text2 = font.render(to_native('for finishing with effects.'), True, green, blue) +############### DEPRECATED: see geneticsd.py text2 = pygame.transform.rotate(text2, 90) +############### DEPRECATED: see geneticsd.py scrn.blit(text2, (X*3/4+X/16+X/32 - X/32, Y/3+10)) +############### DEPRECATED: see geneticsd.py # Button for new generation +############### DEPRECATED: see geneticsd.py text3 = font.render(to_native(f"I don't want to select images"), True, green, blue) +############### DEPRECATED: see geneticsd.py text3 = pygame.transform.rotate(text3, 90) +############### DEPRECATED: see geneticsd.py scrn.blit(text3, (X*3/4+X/16 - X/32, Y*2/3+10)) +############### DEPRECATED: see geneticsd.py text3 = font.render(to_native(f"Just rerun."), True, green, blue) +############### DEPRECATED: see geneticsd.py text3 = pygame.transform.rotate(text3, 90) +############### DEPRECATED: see geneticsd.py scrn.blit(text3, (X*3/4+X/16+X/32 - X/32, Y*2/3+10)) +############### DEPRECATED: see geneticsd.py text4 
= font.render(to_native(f"Modify parameters or text!"), True, green, blue) +############### DEPRECATED: see geneticsd.py scrn.blit(text4, (300, Y + 30)) +############### DEPRECATED: see geneticsd.py pygame.display.flip() +############### DEPRECATED: see geneticsd.py +############### DEPRECATED: see geneticsd.py for idx in range(max_created_index + 1): +############### DEPRECATED: see geneticsd.py # set the pygame window name +############### DEPRECATED: see geneticsd.py pygame.display.set_caption(prompt) +############### DEPRECATED: see geneticsd.py print(to_native(f"Pasting image {onlyfiles[idx]}...")) +############### DEPRECATED: see geneticsd.py imp = pygame.transform.scale(pygame.image.load(onlyfiles[idx]).convert(), (300, 300)) +############### DEPRECATED: see geneticsd.py scrn.blit(imp, (300 * (idx // 3), 300 * (idx % 3))) +############### DEPRECATED: see geneticsd.py +############### DEPRECATED: see geneticsd.py # paint screen one time +############### DEPRECATED: see geneticsd.py pygame.display.flip() +############### DEPRECATED: see geneticsd.py status = True +############### DEPRECATED: see geneticsd.py indices = [] +############### DEPRECATED: see geneticsd.py good = [] +############### DEPRECATED: see geneticsd.py five_best = [] +############### DEPRECATED: see geneticsd.py for i in pygame.event.get(): +############### DEPRECATED: see geneticsd.py if i.type == pygame.MOUSEBUTTONUP: +############### DEPRECATED: see geneticsd.py print(to_native(".... too early for clicking !!!!")) +############### DEPRECATED: see geneticsd.py +############### DEPRECATED: see geneticsd.py +############### DEPRECATED: see geneticsd.py pretty_print("Please click on your favorite elements!") +############### DEPRECATED: see geneticsd.py print(to_native("You might just click on one image and we will provide variations.")) +############### DEPRECATED: see geneticsd.py print(to_native("Or you can click on the top of an image and the bottom of another one.")) +############### DEPRECATED: see geneticsd.py print(to_native("Click on the << new generation >> when you're done.")) +############### DEPRECATED: see geneticsd.py while (status): +############### DEPRECATED: see geneticsd.py +############### DEPRECATED: see geneticsd.py # iterate over the list of Event objects +############### DEPRECATED: see geneticsd.py # that was returned by pygame.event.get() method. +############### DEPRECATED: see geneticsd.py for i in pygame.event.get(): +############### DEPRECATED: see geneticsd.py if hasattr(i, "type") and i.type == pygame.MOUSEBUTTONUP: +############### DEPRECATED: see geneticsd.py pos = pygame.mouse.get_pos() +############### DEPRECATED: see geneticsd.py pretty_print(f"Detected! Click at {pos}") +############### DEPRECATED: see geneticsd.py if pos[1] > Y: +############### DEPRECATED: see geneticsd.py pretty_print("Let us update parameters!") +############### DEPRECATED: see geneticsd.py text4 = font.render(to_native(f"ok, go to text window!"), True, green, blue) +############### DEPRECATED: see geneticsd.py scrn.blit(text4, (300, Y + 30)) +############### DEPRECATED: see geneticsd.py pygame.display.flip() +############### DEPRECATED: see geneticsd.py try: +############### DEPRECATED: see geneticsd.py num_iterations = int(input(to_native(f"Number of iterations ? (current = {num_iterations})\n"))) +############### DEPRECATED: see geneticsd.py except: +############### DEPRECATED: see geneticsd.py num_iterations = int(input(to_native(f"Number of iterations ? 
(current = {num_iterations})\n"))) +############### DEPRECATED: see geneticsd.py gs = float(input(to_native(f"Guidance scale ? (current = {gs})\n"))) +############### DEPRECATED: see geneticsd.py print(to_native(f"The current text is << {prompt} >>.")) +############### DEPRECATED: see geneticsd.py print(to_native("Start your answer with a symbol << + >> if this is an edit and not a new text.")) +############### DEPRECATED: see geneticsd.py new_prompt = str(input(to_native(f"Enter a text if you want to change from ") + prompt)) +############### DEPRECATED: see geneticsd.py if len(new_prompt) > 2: +############### DEPRECATED: see geneticsd.py if new_prompt[0] == "+": +############### DEPRECATED: see geneticsd.py prompt += new_prompt[1:] +############### DEPRECATED: see geneticsd.py else: +############### DEPRECATED: see geneticsd.py prompt = new_prompt +############### DEPRECATED: see geneticsd.py language = detect(prompt) +############### DEPRECATED: see geneticsd.py english_prompt = GoogleTranslator(source='auto', target='en').translate(prompt) +############### DEPRECATED: see geneticsd.py pretty_print("Ok! Parameters updated.") +############### DEPRECATED: see geneticsd.py pretty_print("==> go back to the window!") +############### DEPRECATED: see geneticsd.py text4 = font.render(to_native(f"Ok! parameters changed!"), True, green, blue) +############### DEPRECATED: see geneticsd.py scrn.blit(text4, (300, Y + 30)) +############### DEPRECATED: see geneticsd.py pygame.display.flip() +############### DEPRECATED: see geneticsd.py elif pos[0] > 1500: # Not in the images. +############### DEPRECATED: see geneticsd.py if pos[1] < Y/3: +############### DEPRECATED: see geneticsd.py #filename = input(to_native("Filename (please provide the latent file, of the format SD*latent*.txt) ?\n")) +############### DEPRECATED: see geneticsd.py #status = False +############### DEPRECATED: see geneticsd.py #with open(filename, 'r') as f: +############### DEPRECATED: see geneticsd.py # latent = f.read() +############### DEPRECATED: see geneticsd.py #break +############### DEPRECATED: see geneticsd.py pretty_print("Easy! 
I exit now, you edit the file and you save it.") +############### DEPRECATED: see geneticsd.py pretty_print("Then just relaunch me and provide the text and the image.") +############### DEPRECATED: see geneticsd.py exit() +############### DEPRECATED: see geneticsd.py if pos[1] < 2*Y/3: +############### DEPRECATED: see geneticsd.py #onlyfiles = [f for f in listdir(".") if isfile(join(mypath, f))] +############### DEPRECATED: see geneticsd.py #onlyfiles = [str(f) for f in onlyfiles if "SD_" in str(f) and ".png" in str(f) and str(f) not in all_files and sentinel in str(f)] +############### DEPRECATED: see geneticsd.py assert len(onlyfiles) == len(latent) +############### DEPRECATED: see geneticsd.py assert len(all_selected) == len(all_selected_latent) +############### DEPRECATED: see geneticsd.py stop_all(all_selected, all_selected_latent, final_selection, final_selection_latent) # + onlyfiles, all_selected_latent + latent) +############### DEPRECATED: see geneticsd.py exit() +############### DEPRECATED: see geneticsd.py status = False +############### DEPRECATED: see geneticsd.py break +############### DEPRECATED: see geneticsd.py index = 3 * (pos[0] // 300) + (pos[1] // 300) +############### DEPRECATED: see geneticsd.py pygame.draw.circle(scrn, red, [pos[0], pos[1]], 13, 0) +############### DEPRECATED: see geneticsd.py if index <= max_created_index: +############### DEPRECATED: see geneticsd.py selected_filename = to_native("Selected") + onlyfiles[index] +############### DEPRECATED: see geneticsd.py shutil.copyfile(onlyfiles[index], selected_filename) +############### DEPRECATED: see geneticsd.py assert len(onlyfiles) == len(latent), f"{len(onlyfiles)} != {len(latent)}" +############### DEPRECATED: see geneticsd.py all_selected += [selected_filename] +############### DEPRECATED: see geneticsd.py all_selected_latent += [latent[index]] +############### DEPRECATED: see geneticsd.py final_selection += [selected_filename] +############### DEPRECATED: see geneticsd.py final_selection_latent += [latent[index]] +############### DEPRECATED: see geneticsd.py text2 = font.render(to_native(f'==> {len(all_selected)} chosen images! '), True, green, blue) +############### DEPRECATED: see geneticsd.py text2 = pygame.transform.rotate(text2, 90) +############### DEPRECATED: see geneticsd.py scrn.blit(text2, (X*3/4+X/16 - X/32, Y/3)) +############### DEPRECATED: see geneticsd.py if index not in five_best and len(five_best) < 5: +############### DEPRECATED: see geneticsd.py five_best += [index] +############### DEPRECATED: see geneticsd.py indices += [[index, (pos[0] - (pos[0] // 300) * 300) / 300, (pos[1] - (pos[1] // 300) * 300) / 300]] +############### DEPRECATED: see geneticsd.py # Update the button for new generation. 
+############### DEPRECATED: see geneticsd.py pygame.draw.rect(scrn, black, pygame.Rect(X*3/4, 2*Y/3, X*3/4+X/16+X/32, Y)) +############### DEPRECATED: see geneticsd.py pygame.draw.rect(scrn, red, pygame.Rect(X*3/4, 2*Y/3, X*3/4+X/16+X/32, Y), 2) +############### DEPRECATED: see geneticsd.py text3 = font.render(to_native(f" You have chosen {len(indices)} images:"), True, green, blue) +############### DEPRECATED: see geneticsd.py text3 = pygame.transform.rotate(text3, 90) +############### DEPRECATED: see geneticsd.py scrn.blit(text3, (X*3/4+X/16 - X/32, Y*2/3)) +############### DEPRECATED: see geneticsd.py text3 = font.render(to_native(f" Click for new generation!"), True, green, blue) +############### DEPRECATED: see geneticsd.py text3 = pygame.transform.rotate(text3, 90) +############### DEPRECATED: see geneticsd.py scrn.blit(text3, (X*3/4+X/16+X/32 - X/32, Y*2/3)) +############### DEPRECATED: see geneticsd.py pygame.display.flip() +############### DEPRECATED: see geneticsd.py #text3Rect = text3.get_rect() +############### DEPRECATED: see geneticsd.py #text3Rect.center = (750+750*3/4, 1000) +############### DEPRECATED: see geneticsd.py good += [list(latent[index].flatten())] +############### DEPRECATED: see geneticsd.py else: +############### DEPRECATED: see geneticsd.py noise.say("Bad click! Click on image.") +############### DEPRECATED: see geneticsd.py noise.runAndWait() +############### DEPRECATED: see geneticsd.py pretty_print("Bad click! Click on image.") +############### DEPRECATED: see geneticsd.py +############### DEPRECATED: see geneticsd.py if i.type == pygame.QUIT: +############### DEPRECATED: see geneticsd.py status = False +############### DEPRECATED: see geneticsd.py +############### DEPRECATED: see geneticsd.py # Covering old images with full circles. +############### DEPRECATED: see geneticsd.py for _ in range(123): +############### DEPRECATED: see geneticsd.py x = np.random.randint(1500) +############### DEPRECATED: see geneticsd.py y = np.random.randint(900) +############### DEPRECATED: see geneticsd.py pygame.draw.circle(scrn, darkgreen, +############### DEPRECATED: see geneticsd.py [x, y], 17, 0) +############### DEPRECATED: see geneticsd.py pygame.display.update() +############### DEPRECATED: see geneticsd.py if len(indices) == 0: +############### DEPRECATED: see geneticsd.py print("The user did not like anything! 
Rerun :-(") +############### DEPRECATED: see geneticsd.py continue +############### DEPRECATED: see geneticsd.py print(f"Clicks at {indices}") +############### DEPRECATED: see geneticsd.py os.environ["mu"] = str(len(indices)) +############### DEPRECATED: see geneticsd.py forcedlatents = [] +############### DEPRECATED: see geneticsd.py bad += [list(latent[u].flatten()) for u in range(len(onlyfiles)) if u not in [i[0] for i in indices]] +############### DEPRECATED: see geneticsd.py #sauron = 0 * latent[0] +############### DEPRECATED: see geneticsd.py #for u in [u for u in range(len(onlyfiles)) if u not in [i[0] for i in indices]]: +############### DEPRECATED: see geneticsd.py # sauron += latent[u] +############### DEPRECATED: see geneticsd.py #sauron = (1 / len([u for u in range(len(onlyfiles)) if u not in [i[0] for i in indices]])) * sauron +############### DEPRECATED: see geneticsd.py if len(bad) > 500: +############### DEPRECATED: see geneticsd.py bad = bad[(len(bad) - 500):] +############### DEPRECATED: see geneticsd.py print(to_native(f"{len(indices)} indices are selected.")) +############### DEPRECATED: see geneticsd.py #print(f"indices = {indices}") +############### DEPRECATED: see geneticsd.py os.environ["good"] = str(good) +############### DEPRECATED: see geneticsd.py os.environ["bad"] = str(bad) +############### DEPRECATED: see geneticsd.py coefficients = np.zeros(len(indices)) +############### DEPRECATED: see geneticsd.py numpy_images = [np.array(image) for image in images] +############### DEPRECATED: see geneticsd.py for a in range(llambda): +############### DEPRECATED: see geneticsd.py voronoi_in_images = False #(a % 2 == 1) and len(good) > 1 +############### DEPRECATED: see geneticsd.py if voronoi_in_images: +############### DEPRECATED: see geneticsd.py image = np.array(numpy_images[0]) +############### DEPRECATED: see geneticsd.py print(f"Voronoi in the image space! {a} / {llambda}") +############### DEPRECATED: see geneticsd.py for i in range(len(indices)): +############### DEPRECATED: see geneticsd.py coefficients[i] = np.exp(np.random.randn()) +############### DEPRECATED: see geneticsd.py # Creating a forcedlatent. +############### DEPRECATED: see geneticsd.py for i in range(512): +############### DEPRECATED: see geneticsd.py x = i / 511. +############### DEPRECATED: see geneticsd.py for j in range(512): +############### DEPRECATED: see geneticsd.py y = j / 511 +############### DEPRECATED: see geneticsd.py mindistances = 10000000000. 
+############### DEPRECATED: see geneticsd.py for u in range(len(indices)): +############### DEPRECATED: see geneticsd.py distance = coefficients[u] * np.linalg.norm( np.array((x, y)) - np.array((indices[u][2], indices[u][1])) ) +############### DEPRECATED: see geneticsd.py if distance < mindistances: +############### DEPRECATED: see geneticsd.py mindistances = distance +############### DEPRECATED: see geneticsd.py uu = indices[u][0] +############### DEPRECATED: see geneticsd.py image[i][j][:] = numpy_images[uu][i][j][:] +############### DEPRECATED: see geneticsd.py # Conversion before using img2latent +############### DEPRECATED: see geneticsd.py pil_image = Image.fromarray(image) +############### DEPRECATED: see geneticsd.py voronoi_name = f"voronoi{a}_iteration{iteration}.png" +############### DEPRECATED: see geneticsd.py pil_image.save(voronoi_name) +############### DEPRECATED: see geneticsd.py #timage = np.array([image]).astype(np.float32) / 255.0 +############### DEPRECATED: see geneticsd.py #timage = timage.transpose(0, 3, 1, 2) +############### DEPRECATED: see geneticsd.py #timage = torch.from_numpy(timage).to(device) +############### DEPRECATED: see geneticsd.py #timage = repeat(timage, '1 ... -> b ...', b=1) +############### DEPRECATED: see geneticsd.py #timage = 2.*timage - 1. +############### DEPRECATED: see geneticsd.py #forcedlatent = model.encode(timage).latent_dist.sample().cpu().detach().numpy().flatten() +############### DEPRECATED: see geneticsd.py #basic_new_fl = np.sqrt(len(forcedlatent) / np.sum(forcedlatent**2)) * forcedlatent +############### DEPRECATED: see geneticsd.py basic_new_fl = randomized_image_to_latent(voronoi_name) #img_to_latent(voronoi_name) +############### DEPRECATED: see geneticsd.py basic_new_fl = np.sqrt(len(basic_new_fl) / np.sum(basic_new_fl**2)) * basic_new_fl +############### DEPRECATED: see geneticsd.py #basic_new_fl = 0.8 * np.sqrt(len(basic_new_fl) / np.sum(basic_new_fl**2)) * basic_new_fl +############### DEPRECATED: see geneticsd.py if len(good) > 1: +############### DEPRECATED: see geneticsd.py print("Directly copying latent vars !!!") +############### DEPRECATED: see geneticsd.py #forcedlatents += [4.6 * basic_new_fl] +############### DEPRECATED: see geneticsd.py forcedlatents += [basic_new_fl] +############### DEPRECATED: see geneticsd.py else: +############### DEPRECATED: see geneticsd.py epsilon = 1.0 * (((a + .5 - len(good)) / (llambda - len(good) - 1)) ** 2) +############### DEPRECATED: see geneticsd.py forcedlatent = (1. - epsilon) * basic_new_fl.flatten() + epsilon * np.random.randn(4*64*64) +############### DEPRECATED: see geneticsd.py forcedlatent = np.sqrt(len(forcedlatent) / np.sum(forcedlatent**2)) * forcedlatent +############### DEPRECATED: see geneticsd.py forcedlatents += [forcedlatent] +############### DEPRECATED: see geneticsd.py #forcedlatents += [4.6 * forcedlatent] +############### DEPRECATED: see geneticsd.py else: +############### DEPRECATED: see geneticsd.py print(f"Voronoi in the latent space! 
{a} / {llambda}") +############### DEPRECATED: see geneticsd.py forcedlatent = np.zeros((4, 64, 64)) +############### DEPRECATED: see geneticsd.py #print(type(numpy_image)) +############### DEPRECATED: see geneticsd.py #print(numpy_image.shape) +############### DEPRECATED: see geneticsd.py #print(np.max(numpy_image)) +############### DEPRECATED: see geneticsd.py #print(np.min(numpy_image)) +############### DEPRECATED: see geneticsd.py #assert False +############### DEPRECATED: see geneticsd.py for i in range(len(indices)): +############### DEPRECATED: see geneticsd.py coefficients[i] = np.exp(np.random.randn()) +############### DEPRECATED: see geneticsd.py for i in range(64): +############### DEPRECATED: see geneticsd.py x = i / 63. +############### DEPRECATED: see geneticsd.py for j in range(64): +############### DEPRECATED: see geneticsd.py y = j / 63 +############### DEPRECATED: see geneticsd.py mindistances = 10000000000. +############### DEPRECATED: see geneticsd.py for u in range(len(indices)): +############### DEPRECATED: see geneticsd.py #print(a, i, x, j, y, u) +############### DEPRECATED: see geneticsd.py #print(indices[u][1]) +############### DEPRECATED: see geneticsd.py #print(indices[u][2]) +############### DEPRECATED: see geneticsd.py #print(f" {coefficients[u]}* np.linalg.norm({np.array((x, y))}-{np.array((indices[u][1], indices[u][2]))}") +############### DEPRECATED: see geneticsd.py distance = coefficients[u] * np.linalg.norm( np.array((x, y)) - np.array((indices[u][2], indices[u][1])) ) +############### DEPRECATED: see geneticsd.py if distance < mindistances: +############### DEPRECATED: see geneticsd.py mindistances = distance +############### DEPRECATED: see geneticsd.py uu = indices[u][0] +############### DEPRECATED: see geneticsd.py for k in range(4): +############### DEPRECATED: see geneticsd.py assert k < len(forcedlatent), k +############### DEPRECATED: see geneticsd.py assert i < len(forcedlatent[k]), i +############### DEPRECATED: see geneticsd.py assert j < len(forcedlatent[k][i]), j +############### DEPRECATED: see geneticsd.py assert uu < len(latent) +############### DEPRECATED: see geneticsd.py assert k < len(latent[uu]), k +############### DEPRECATED: see geneticsd.py assert i < len(latent[uu][k]), i +############### DEPRECATED: see geneticsd.py assert j < len(latent[uu][k][i]), j +############### DEPRECATED: see geneticsd.py forcedlatent[k][i][j] = float(latent[uu][k][i][j]) +############### DEPRECATED: see geneticsd.py #if a % 2 == 0: +############### DEPRECATED: see geneticsd.py # forcedlatent -= np.random.rand() * sauron +############### DEPRECATED: see geneticsd.py forcedlatent = forcedlatent.flatten() +############### DEPRECATED: see geneticsd.py basic_new_fl = np.sqrt(len(forcedlatent) / np.sum(forcedlatent**2)) * forcedlatent +############### DEPRECATED: see geneticsd.py if len(good) > 1 or len(forcedlatents) < len(good) + 1: +############### DEPRECATED: see geneticsd.py forcedlatents += [basic_new_fl] +############### DEPRECATED: see geneticsd.py else: +############### DEPRECATED: see geneticsd.py epsilon = ((0.5 * (a + .5 - len(good)) / (llambda - len(good) - 1)) ** 2) +############### DEPRECATED: see geneticsd.py forcedlatent = (1. 
- epsilon) * basic_new_fl.flatten() + epsilon * np.random.randn(4*64*64) +############### DEPRECATED: see geneticsd.py #forcedlatent = np.sqrt(len(forcedlatent) / np.sum(forcedlatent**2)) * forcedlatent +############### DEPRECATED: see geneticsd.py forcedlatents += [forcedlatent] +############### DEPRECATED: see geneticsd.py #for uu in range(len(latent)): +############### DEPRECATED: see geneticsd.py # print(f"--> latent[{uu}] sum of sq / variable = {np.sum(latent[uu].flatten()**2) / len(latent[uu].flatten())}") +############### DEPRECATED: see geneticsd.py os.environ["good"] = "[]" +############### DEPRECATED: see geneticsd.py os.environ["bad"] = "[]" +############### DEPRECATED: see geneticsd.py +############### DEPRECATED: see geneticsd.py pygame.quit() From ff1961cd6b3c125ee79cef8aaba28d996639d61c Mon Sep 17 00:00:00 2001 From: Olivier Teytaud Date: Thu, 6 Oct 2022 14:48:57 +0200 Subject: [PATCH 68/76] fix --- geneticsd.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/geneticsd.py b/geneticsd.py index f30a493ee..66e9f4220 100644 --- a/geneticsd.py +++ b/geneticsd.py @@ -152,6 +152,8 @@ def speak(text): prompt = "Photo of the devil, with horns. There are flames in the background." prompt = "Yann LeCun fighting Pinocchio with light sabers." prompt = "Yann LeCun attacks a triceratops with a lightsaber." +prompt = "A cyberpunk man next to a cyberpunk woman." +prompt = "A smiling woman with a Katana and electronic patches." print(f"The prompt is {prompt}") @@ -337,6 +339,7 @@ def randomized_image_to_latent(image_name, scale=None, epsilon=None, c=None, f=N new_fl = scale * np.sqrt(len(new_fl)) * new_fl / np.sqrt(np.sum(new_fl ** 2)) #image = latent_to_image(np.asarray(new_fl)) #eval(os.environ["forcedlatent"]))) #image.save(f"rebuild_{f}_{scale}_{epsilon}_{c}.png") + #gs=7.5, f=0.12, scale=3.7, epsilon=0.01,1 c=2.05 return new_fl # In case the user wants to start from a given image. From 9b96c3c59cc479132a6cca0f79c13c176ae143d6 Mon Sep 17 00:00:00 2001 From: Olivier Teytaud Date: Thu, 6 Oct 2022 15:07:32 +0200 Subject: [PATCH 69/76] voronoi --- geneticsd.py | 1 + 1 file changed, 1 insertion(+) diff --git a/geneticsd.py b/geneticsd.py index 66e9f4220..749ae631e 100644 --- a/geneticsd.py +++ b/geneticsd.py @@ -154,6 +154,7 @@ def speak(text): prompt = "Yann LeCun attacks a triceratops with a lightsaber." prompt = "A cyberpunk man next to a cyberpunk woman." prompt = "A smiling woman with a Katana and electronic patches." +prompt = "Photo of a bearded, long-haired man and a blonde-haired woman. Cats and drums and computers on shelves in the background." print(f"The prompt is {prompt}") From aaa55eabfa7134d0092f091e7c071d15fd896e27 Mon Sep 17 00:00:00 2001 From: Olivier Teytaud Date: Thu, 6 Oct 2022 16:01:34 +0200 Subject: [PATCH 70/76] fix --- README.md | 5 +++++ geneticsd.py | 26 ++++++++++++++++++-------- 2 files changed, 23 insertions(+), 8 deletions(-) diff --git a/README.md b/README.md index 41140684f..94900c306 100644 --- a/README.md +++ b/README.md @@ -20,7 +20,12 @@ Just click here and copy-paste your token: ## Install StableDiffusion as usual, plus a few more stuff. Basically: +You need homebrew. +You need to open a terminal. Then: ``` +mkdir stablediffusion +cd stablediffusion +git clone git@github.com:teytaud/genetic-stable-diffusion.git . brew install wget conda env create -f environment.yaml conda activate ldm # you can change that name in the environment.yaml file... 
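A minimal sanity check for the environment set up by the steps above, assuming PyTorch >= 1.12 (for the MPS backend) and the packages installed via the README; it is not part of the patches, and it only mirrors the "mps"/"cuda" device choice the scripts rely on:

```
# Sanity check for the install above -- not part of the patch series.
# Assumes PyTorch >= 1.12 (MPS backend) and the packages listed in the README.
import torch
import diffusers

# The scripts run on an "mps" or "cuda" device; mirror that choice here.
if torch.backends.mps.is_available():
    device = "mps"
elif torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

print(f"diffusers {diffusers.__version__}, torch {torch.__version__}, device = {device}")
```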
diff --git a/geneticsd.py b/geneticsd.py index 749ae631e..5157656dd 100644 --- a/geneticsd.py +++ b/geneticsd.py @@ -416,7 +416,13 @@ def randomized_image_to_latent(image_name, scale=None, epsilon=None, c=None, f=N text0 = font.render(to_native(f'... ... ... '), True, green, blue) scrn.blit(text0, ((X*3/4)/2 - X/32, Y/2+Y/8)) - # Button for early stopping + text1 = font.render(to_native('Undo: click for '), True, green, blue) + text1 = pygame.transform.rotate(text2, 90) + scrn.blit(text1, (X*3/4+X/16+X/64 - X/32, Y/12)) + text1 = font.render(to_native('resetting your clicks.'), True, green, blue) + text1 = pygame.transform.rotate(text2, 90) + scrn.blit(text1, (X*3/4+X/16+X/32 - X/32, Y/12)) + # Button for quitting and effects text2 = font.render(to_native(f'Total: {len(all_selected)} chosen images! '), True, green, blue) text2 = pygame.transform.rotate(text2, 90) scrn.blit(text2, (X*3/4+X/16 - X/32, Y/3)) @@ -488,20 +494,24 @@ def randomized_image_to_latent(image_name, scale=None, epsilon=None, c=None, f=N pygame.draw.rect(scrn, red, pygame.Rect(0, Y, X/2, Y+100), 2) # Button for loading a starting point - text1 = font.render('Manually edit an image.', True, green, blue) - text1 = pygame.transform.rotate(text1, 90) + #text1 = font.render('Manually edit an image.', True, green, blue) + #text1 = pygame.transform.rotate(text1, 90) #scrn.blit(text1, (X*3/4+X/16 - X/32, 0)) #text1 = font.render('& latent ', True, green, blue) #text1 = pygame.transform.rotate(text1, 90) #scrn.blit(text1, (X*3/4+X/16+X/32 - X/32, 0)) - # Button for creating a meme + # Button for stopping now. text2 = font.render(to_native('Click ,'), True, green, blue) text2 = pygame.transform.rotate(text2, 90) scrn.blit(text2, (X*3/4+X/16 - X/32, Y/3+10)) text2 = font.render(to_native('for finishing with effects.'), True, green, blue) text2 = pygame.transform.rotate(text2, 90) scrn.blit(text2, (X*3/4+X/16+X/32 - X/32, Y/3+10)) + text2 = font.render(to_native('or manually edit.'), True, green, blue) + text2 = pygame.transform.rotate(text2, 90) + scrn.blit(text2, (X*3/4+X/16+X/32 , Y/3+10)) + # Button for new generation text3 = font.render(to_native(f"I don't want to select images"), True, green, blue) text3 = pygame.transform.rotate(text3, 90) @@ -575,9 +585,9 @@ def randomized_image_to_latent(image_name, scale=None, epsilon=None, c=None, f=N #with open(filename, 'r') as f: # latent = f.read() #break - pretty_print("Easy! I exit now, you edit the file and you save it.") - pretty_print("Then just relaunch me and provide the text and the image.") - exit() + #pretty_print("Easy! I exit now, you edit the file and you save it.") + #pretty_print("Then just relaunch me and provide the text and the image.") + #exit() if pos[1] < 2*Y/3: #onlyfiles = [f for f in listdir(".") if isfile(join(mypath, f))] #onlyfiles = [str(f) for f in onlyfiles if "SD_" in str(f) and ".png" in str(f) and str(f) not in all_files and sentinel in str(f)] @@ -734,7 +744,7 @@ def randomized_image_to_latent(image_name, scale=None, epsilon=None, c=None, f=N else: epsilon = ((0.5 * (a + .5 - len(good)) / (llambda - len(good) - 1)) ** 2) forcedlatent = (1. - epsilon) * basic_new_fl.flatten() + epsilon * np.random.randn(4*64*64) - forcedlatent = np.sqrt(len(forcedlatent) / np.sum(forcedlatent**2)) * forcedlatent + #forcedlatent = np.sqrt(len(forcedlatent) / np.sum(forcedlatent**2)) * forcedlatent REMOVED!! 
 forcedlatents += [forcedlatent]
 #for uu in range(len(latent)):
 #    print(f"--> latent[{uu}] sum of sq / variable = {np.sum(latent[uu].flatten()**2) / len(latent[uu].flatten())}")

From 7125a08d0f870ac67cf423b7e92682442ab94af1 Mon Sep 17 00:00:00 2001
From: Olivier Teytaud
Date: Thu, 6 Oct 2022 16:03:05 +0200
Subject: [PATCH 71/76] fix

---
 README.md | 4 ++++
 1 file changed, 4 insertions(+)

diff --git a/README.md b/README.md
index 94900c306..43d84c198 100644
--- a/README.md
+++ b/README.md
@@ -21,6 +21,10 @@ Just click here and copy-paste your token:
 ## Install StableDiffusion as usual, plus a few more stuff.
 Basically:
 You need homebrew.
+On a Mac, you need to do special stuff for the MPS: we recommend
+[**This page**](https://towardsdatascience.com/gpu-acceleration-comes-to-pytorch-on-m1-macs-195c399efcc1)
+ +[ You need to open a terminal. Then: ``` mkdir stablediffusion From f5dea90ff561f51ea432fdf679f8afda56b0ee40 Mon Sep 17 00:00:00 2001 From: Olivier Teytaud Date: Thu, 6 Oct 2022 16:06:47 +0200 Subject: [PATCH 72/76] fix --- geneticsd.py | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/geneticsd.py b/geneticsd.py index 5157656dd..52bf77dd2 100644 --- a/geneticsd.py +++ b/geneticsd.py @@ -580,6 +580,10 @@ def randomized_image_to_latent(image_name, scale=None, epsilon=None, c=None, f=N pygame.display.flip() elif pos[0] > 1500: # Not in the images. if pos[1] < Y/3: + indices = [] + good = [] + final_selection = [] + final_selection_latent = [] #filename = input(to_native("Filename (please provide the latent file, of the format SD*latent*.txt) ?\n")) #status = False #with open(filename, 'r') as f: @@ -738,10 +742,11 @@ def randomized_image_to_latent(image_name, scale=None, epsilon=None, c=None, f=N #if a % 2 == 0: # forcedlatent -= np.random.rand() * sauron forcedlatent = forcedlatent.flatten() - basic_new_fl = np.sqrt(len(forcedlatent) / np.sum(forcedlatent**2)) * forcedlatent if len(good) > 1 or len(forcedlatents) < len(good) + 1: + basic_new_fl = np.sqrt(len(forcedlatent) / np.sum(forcedlatent**2)) * forcedlatent forcedlatents += [basic_new_fl] else: + basic_new_fl = forcedlatent epsilon = ((0.5 * (a + .5 - len(good)) / (llambda - len(good) - 1)) ** 2) forcedlatent = (1. - epsilon) * basic_new_fl.flatten() + epsilon * np.random.randn(4*64*64) #forcedlatent = np.sqrt(len(forcedlatent) / np.sum(forcedlatent**2)) * forcedlatent REMOVED!! From 25300f1be19922903c3791eda105be2d5bcb2a19 Mon Sep 17 00:00:00 2001 From: Olivier Teytaud Date: Thu, 6 Oct 2022 16:08:54 +0200 Subject: [PATCH 73/76] fixdoc --- README.md | 1 - 1 file changed, 1 deletion(-) diff --git a/README.md b/README.md index 43d84c198..d38921943 100644 --- a/README.md +++ b/README.md @@ -24,7 +24,6 @@ You need homebrew. On a Mac, you need to do special stuff for the MPS: we recommend [**This page**](https://towardsdatascience.com/gpu-acceleration-comes-to-pytorch-on-m1-macs-195c399efcc1)
-[ You need to open a terminal. Then: ``` mkdir stablediffusion From 00b8331da91e1a78b60e2854a6153bcb139598ea Mon Sep 17 00:00:00 2001 From: Olivier Teytaud Date: Thu, 6 Oct 2022 16:10:04 +0200 Subject: [PATCH 74/76] fix --- geneticsd.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/geneticsd.py b/geneticsd.py index 52bf77dd2..6a8cec978 100644 --- a/geneticsd.py +++ b/geneticsd.py @@ -417,10 +417,10 @@ def randomized_image_to_latent(image_name, scale=None, epsilon=None, c=None, f=N scrn.blit(text0, ((X*3/4)/2 - X/32, Y/2+Y/8)) text1 = font.render(to_native('Undo: click for '), True, green, blue) - text1 = pygame.transform.rotate(text2, 90) + text1 = pygame.transform.rotate(text1, 90) scrn.blit(text1, (X*3/4+X/16+X/64 - X/32, Y/12)) text1 = font.render(to_native('resetting your clicks.'), True, green, blue) - text1 = pygame.transform.rotate(text2, 90) + text1 = pygame.transform.rotate(text1, 90) scrn.blit(text1, (X*3/4+X/16+X/32 - X/32, Y/12)) # Button for quitting and effects text2 = font.render(to_native(f'Total: {len(all_selected)} chosen images! '), True, green, blue) From 9730d9078f5f200830ec0797142104fd2e61b8bf Mon Sep 17 00:00:00 2001 From: Olivier Teytaud Date: Thu, 6 Oct 2022 16:31:11 +0200 Subject: [PATCH 75/76] fix --- geneticsd.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/geneticsd.py b/geneticsd.py index 6a8cec978..d0e1d42fb 100644 --- a/geneticsd.py +++ b/geneticsd.py @@ -154,7 +154,7 @@ def speak(text): prompt = "Yann LeCun attacks a triceratops with a lightsaber." prompt = "A cyberpunk man next to a cyberpunk woman." prompt = "A smiling woman with a Katana and electronic patches." -prompt = "Photo of a bearded, long-haired man and a blonde-haired woman. Cats and drums and computers on shelves in the background." +prompt = "Photo of a bearded, long-haired man with glasses and a blonde-haired woman. Both are smiling. Cats and drums and computers on shelves in the background." print(f"The prompt is {prompt}") @@ -747,9 +747,10 @@ def randomized_image_to_latent(image_name, scale=None, epsilon=None, c=None, f=N forcedlatents += [basic_new_fl] else: basic_new_fl = forcedlatent + coef = np.sqrt(len(forcedlatent) / np.sum(forcedlatent**2)) + coef = np.exp(np.log(coef) * (a/llambda) ) epsilon = ((0.5 * (a + .5 - len(good)) / (llambda - len(good) - 1)) ** 2) - forcedlatent = (1. - epsilon) * basic_new_fl.flatten() + epsilon * np.random.randn(4*64*64) - #forcedlatent = np.sqrt(len(forcedlatent) / np.sum(forcedlatent**2)) * forcedlatent REMOVED!! + forcedlatent = (1. 
- epsilon) * coef * basic_new_fl.flatten() + epsilon * np.random.randn(4*64*64) forcedlatents += [forcedlatent] #for uu in range(len(latent)): # print(f"--> latent[{uu}] sum of sq / variable = {np.sum(latent[uu].flatten()**2) / len(latent[uu].flatten())}") From 9dd9a5c820c4a7cf3cf81efe9c196d4d3b77e9cc Mon Sep 17 00:00:00 2001 From: Olivier Teytaud Date: Thu, 6 Oct 2022 19:45:34 +0200 Subject: [PATCH 76/76] fix --- README.md | 1 + geneticsd.py | 25 +++++++++++++++++++------ 2 files changed, 20 insertions(+), 6 deletions(-) diff --git a/README.md b/README.md index d38921943..fc92df556 100644 --- a/README.md +++ b/README.md @@ -46,6 +46,7 @@ pip install langdetect pip install deep-translator pip install git+https://github.com/sberbank-ai/Real-ESRGAN.git wget https://github.com/xinntao/Real-ESRGAN/releases/download/v0.1.0/RealESRGAN_x4plus.pth -P weights +wget https://github.com/TencentARC/GFPGAN/releases/download/v1.3.0/GFPGANv1.3.pth ``` ## Then run << python geneticsd.py >>. diff --git a/geneticsd.py b/geneticsd.py index d0e1d42fb..e6f44136c 100644 --- a/geneticsd.py +++ b/geneticsd.py @@ -1,4 +1,6 @@ # A ton of imports. +from gfpgan.utils import GFPGANer +import cv2 import random import os import time @@ -196,6 +198,12 @@ def latent_to_image(latent): esrmodel2 = RealESRGAN(sr_device, scale=2) esrmodel2.load_weights('weights/RealESRGAN_x2.pth', download=True) +def fe(path): + fe = GFPGANer(model_path='GFPGANv1.3.pth', upscale=2, arch='clean', channel_multiplier=2) + img = cv2.imread(path, cv2.IMREAD_UNCHANGED) + _, _, output = fe.enhance(img, has_aligned=False, only_center_face=False, paste_back=True) + cv2.imwrite(path, output) + def singleeg(path_to_image): image = Image.open(path_to_image).convert('RGB') sr_device = device #('mps') #torch.device('cuda' if torch.cuda.is_available() else 'cpu') @@ -204,6 +212,7 @@ def singleeg(path_to_image): print(f"Type after SR = {type(sr_image)}") output_filename = path_to_image + ".SR.png" sr_image.save(output_filename) + fe(output_filename) return output_filename # A version with x2. @@ -215,6 +224,7 @@ def singleeg2(path_to_image): print(f"Type after SR = {type(sr_image)}") output_filename = path_to_image + ".SR.png" sr_image.save(output_filename) + fe(output_filename) return output_filename @@ -265,6 +275,7 @@ def stop_all(list_of_files, list_of_latent, last_list_of_files, last_list_of_lat image = latent_to_image(l + cc * (l1 - l) + ss * (l2 - l)) image_name = f"imgA{index}.png" image.save(image_name) + fe(image_name) images += [image_name] print(to_native(f"Base images created for perturbation={c} and file {list_of_files[idx]}")) @@ -297,6 +308,7 @@ def stop_all(list_of_files, list_of_latent, last_list_of_files, last_list_of_lat Y = 900 scrn = pygame.display.set_mode((1700, Y + 100)) font = pygame.font.Font('freesansbold.ttf', 22) +minifont = pygame.font.Font('freesansbold.ttf', 11) bigfont = pygame.font.Font('freesansbold.ttf', 44) def load_img(path): @@ -416,7 +428,7 @@ def randomized_image_to_latent(image_name, scale=None, epsilon=None, c=None, f=N text0 = font.render(to_native(f'... ... ... 
'), True, green, blue) scrn.blit(text0, ((X*3/4)/2 - X/32, Y/2+Y/8)) - text1 = font.render(to_native('Undo: click for '), True, green, blue) + text1 = minifont.render(to_native('Undo: click for '), True, green, blue) text1 = pygame.transform.rotate(text1, 90) scrn.blit(text1, (X*3/4+X/16+X/64 - X/32, Y/12)) text1 = font.render(to_native('resetting your clicks.'), True, green, blue) @@ -447,6 +459,7 @@ def randomized_image_to_latent(image_name, scale=None, epsilon=None, c=None, f=N images += [image] filename = f"SD_{prompt.replace(' ','_')}_image_{sentinel}_{iteration:05d}_{k:05d}.png" image.save(filename) + fe(filename) onlyfiles += [filename] imp = pygame.transform.scale(pygame.image.load(onlyfiles[-1]).convert(), (300, 300)) scrn.blit(imp, (300 * (k // 3), 300 * (k % 3))) @@ -742,15 +755,15 @@ def randomized_image_to_latent(image_name, scale=None, epsilon=None, c=None, f=N #if a % 2 == 0: # forcedlatent -= np.random.rand() * sauron forcedlatent = forcedlatent.flatten() + basic_new_fl = np.sqrt(len(forcedlatent) / np.sum(forcedlatent**2)) * forcedlatent if len(good) > 1 or len(forcedlatents) < len(good) + 1: - basic_new_fl = np.sqrt(len(forcedlatent) / np.sum(forcedlatent**2)) * forcedlatent forcedlatents += [basic_new_fl] else: - basic_new_fl = forcedlatent + epsilon = (( (a + .5 - len(good)) / (llambda - len(good) - 1))) + forcedlatent = (1. - epsilon) * basic_new_fl + epsilon * np.random.randn(4*64*64) coef = np.sqrt(len(forcedlatent) / np.sum(forcedlatent**2)) - coef = np.exp(np.log(coef) * (a/llambda) ) - epsilon = ((0.5 * (a + .5 - len(good)) / (llambda - len(good) - 1)) ** 2) - forcedlatent = (1. - epsilon) * coef * basic_new_fl.flatten() + epsilon * np.random.randn(4*64*64) + forcedlatent = coef * forcedlatent + print("we get ", sum(forcedlatent) ** 2) forcedlatents += [forcedlatent] #for uu in range(len(latent)): # print(f"--> latent[{uu}] sum of sq / variable = {np.sum(latent[uu].flatten()**2) / len(latent[uu].flatten())}")
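The hunks above repeatedly tune how new latents are built from the user's clicks: a Voronoi-style copy of the selected parents' latents, a rescaling so the result looks like a standard Gaussian latent, and a blend with fresh noise whose weight grows along the batch. The sketch below condenses that mechanism; it is not the script's code, and `parent_latents`, `clicks` and `llambda` are illustrative names.

```
# Condensed sketch (not the script's code) of the offspring construction tuned in the hunks above:
# Voronoi recombination of parent latents, renormalization, and epsilon-blending with fresh noise.
import numpy as np

def make_offspring(parent_latents, clicks, llambda, rng=None):
    """parent_latents: list of (4, 64, 64) arrays; clicks: list of (row_frac, col_frac) in [0, 1]."""
    rng = rng if rng is not None else np.random.default_rng()
    offspring = []
    for a in range(llambda):
        # Randomly weighted Voronoi crossover: each latent cell is copied from the parent
        # whose click position is (weighted-)closest to the cell's normalized coordinates.
        weights = np.exp(rng.standard_normal(len(clicks)))
        child = np.empty((4, 64, 64))
        for i in range(64):
            for j in range(64):
                dists = [w * np.hypot(i / 63. - r, j / 63. - c) for w, (r, c) in zip(weights, clicks)]
                child[:, i, j] = parent_latents[int(np.argmin(dists))][:, i, j]
        child = child.flatten()
        # Rescale so the mean squared entry is 1, like a standard Gaussian latent.
        child = np.sqrt(len(child) / np.sum(child ** 2)) * child
        # Mutate: blend with fresh Gaussian noise, more strongly for later offspring.
        epsilon = (a + 0.5) / llambda
        offspring.append((1. - epsilon) * child + epsilon * rng.standard_normal(4 * 64 * 64))
    return offspring
```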