diff --git a/.dockerignore b/.dockerignore
new file mode 100644
index 000000000..9437aef91
--- /dev/null
+++ b/.dockerignore
@@ -0,0 +1,3 @@
+*Dockerfile*
+docker-compose.yml
+.git
diff --git a/.github/FUNDING.yml b/.github/FUNDING.yml
new file mode 100644
index 000000000..81ba6f6ef
--- /dev/null
+++ b/.github/FUNDING.yml
@@ -0,0 +1,13 @@
+# These are supported funding model platforms
+
+github: # Replace with up to 4 GitHub Sponsors-enabled usernames e.g., [user1, user2]
+patreon: # Replace with a single Patreon username
+open_collective: # Replace with a single Open Collective username
+ko_fi: basuj
+tidelift: # Replace with a single Tidelift platform-name/package-name e.g., npm/babel
+community_bridge: # Replace with a single Community Bridge project-name e.g., cloud-foundry
+liberapay: # Replace with a single Liberapay username
+issuehunt: # Replace with a single IssueHunt username
+otechie: # Replace with a single Otechie username
+lfx_crowdfunding: # Replace with a single LFX Crowdfunding project-name e.g., cloud-foundry
+custom: ['https://paypal.me/basuj']
diff --git a/Dockerfile b/Dockerfile
new file mode 100644
index 000000000..69b1eea68
--- /dev/null
+++ b/Dockerfile
@@ -0,0 +1,47 @@
+FROM continuumio/miniconda3:4.12.0 AS build
+
+# Install git in the build stage; it is needed to fetch pip dependencies from GitHub.
+RUN apt-get update \
+ && apt-get install --no-install-recommends -y git \
+ && apt-get clean
+
+COPY . /root/stable-diffusion/
+
+# Step to install dependencies with conda
+RUN eval "$(conda shell.bash hook)" \
+ && conda install -c conda-forge conda-pack \
+ && conda env create -f /root/stable-diffusion/environment.yaml \
+ && conda activate ldm \
+ && pip install gradio==3.1.7 \
+ && conda activate base
+
+# Pack the conda environment and unpack it into the /venv folder
+RUN conda pack --ignore-missing-files --ignore-editable-packages -n ldm -o /tmp/env.tar \
+ && mkdir /venv \
+ && cd /venv \
+ && tar xf /tmp/env.tar \
+ && rm /tmp/env.tar
+
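+# Runtime stage: a slim CUDA base image; the packed conda environment and the
+# repository are copied in from the build stage above.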
+FROM nvidia/cuda:11.8.0-base-ubuntu22.04 as runtime
+
+ARG OPTIMIZED_FILE=txt2img_gradio.py
+WORKDIR /root/stable-diffusion
+
+COPY --from=build /venv /venv
+COPY --from=build /root/stable-diffusion /root/stable-diffusion
+
+RUN mkdir -p /output /root/stable-diffusion/outputs \
+ && ln -s /data /root/stable-diffusion/models/ldm/stable-diffusion-v1 \
+ && ln -s /output /root/stable-diffusion/outputs/txt2img-samples
+
+ENV PYTHONUNBUFFERED=1
+ENV GRADIO_SERVER_NAME=0.0.0.0
+ENV GRADIO_SERVER_PORT=7860
+ENV APP_MAIN_FILE=${OPTIMIZED_FILE}
+EXPOSE 7860
+
+VOLUME ["/root/.cache", "/data", "/output"]
+
+SHELL ["/bin/bash", "-c"]
+ENTRYPOINT ["/root/stable-diffusion/docker-bootstrap.sh"]
+CMD python optimizedSD/${APP_MAIN_FILE}
\ No newline at end of file
diff --git a/README.md b/README.md
index 63e96be27..a37b471bc 100644
--- a/README.md
+++ b/README.md
@@ -1,181 +1,148 @@
-# Stable Diffusion
-*Stable Diffusion was made possible thanks to a collaboration with [Stability AI](https://stability.ai/) and [Runway](https://runwayml.com/) and builds upon our previous work:*
-
-[**High-Resolution Image Synthesis with Latent Diffusion Models**](https://arxiv.org/abs/2112.10752)
-[Robin Rombach](https://github.com/rromb)\*,
-[Andreas Blattmann](https://github.com/ablattmann)\*,
-[Dominik Lorenz](https://github.com/qp-qp)\,
-[Patrick Esser](https://github.com/pesser),
-[Björn Ommer](https://hci.iwr.uni-heidelberg.de/Staff/bommer)
-
-which is available on [GitHub](https://github.com/CompVis/latent-diffusion).
-
-
-[Stable Diffusion](#stable-diffusion-v1) is a latent text-to-image diffusion
-model.
-Thanks to a generous compute donation from [Stability AI](https://stability.ai/) and support from [LAION](https://laion.ai/), we were able to train a Latent Diffusion Model on 512x512 images from a subset of the [LAION-5B](https://laion.ai/blog/laion-5b/) database.
-Similar to Google's [Imagen](https://arxiv.org/abs/2205.11487),
-this model uses a frozen CLIP ViT-L/14 text encoder to condition the model on text prompts.
-With its 860M UNet and 123M text encoder, the model is relatively lightweight and runs on a GPU with at least 10GB VRAM.
-See [this section](#stable-diffusion-v1) below and the [model card](https://huggingface.co/CompVis/stable-diffusion).
-
-
-## Requirements
-A suitable [conda](https://conda.io/) environment named `ldm` can be created
-and activated with:
-
-```
-conda env create -f environment.yaml
-conda activate ldm
-```
-
-You can also update an existing [latent diffusion](https://github.com/CompVis/latent-diffusion) environment by running
-
-```
-conda install pytorch torchvision -c pytorch
-pip install transformers==4.19.2
-pip install -e .
-```
-
-
-## Stable Diffusion v1
-
-Stable Diffusion v1 refers to a specific configuration of the model
-architecture that uses a downsampling-factor 8 autoencoder with an 860M UNet
-and CLIP ViT-L/14 text encoder for the diffusion model. The model was pretrained on 256x256 images and
-then finetuned on 512x512 images.
-
-*Note: Stable Diffusion v1 is a general text-to-image diffusion model and therefore mirrors biases and (mis-)conceptions that are present
-in its training data.
-Details on the training procedure and data, as well as the intended use of the model can be found in the corresponding [model card](https://huggingface.co/CompVis/stable-diffusion).
-Research into the safe deployment of general text-to-image models is an ongoing effort. To prevent misuse and harm, we currently provide access to the checkpoints only for [academic research purposes upon request](https://stability.ai/academia-access-form).
-**This is an experiment in safe and community-driven publication of a capable and general text-to-image model. We are working on a public release with a more permissive license that also incorporates ethical considerations.***
-
-[Request access to Stable Diffusion v1 checkpoints for academic research](https://stability.ai/academia-access-form)
-
-### Weights
-
-We currently provide three checkpoints, `sd-v1-1.ckpt`, `sd-v1-2.ckpt` and `sd-v1-3.ckpt`,
-which were trained as follows,
-
-- `sd-v1-1.ckpt`: 237k steps at resolution `256x256` on [laion2B-en](https://huggingface.co/datasets/laion/laion2B-en).
- 194k steps at resolution `512x512` on [laion-high-resolution](https://huggingface.co/datasets/laion/laion-high-resolution) (170M examples from LAION-5B with resolution `>= 1024x1024`).
-- `sd-v1-2.ckpt`: Resumed from `sd-v1-1.ckpt`.
- 515k steps at resolution `512x512` on "laion-improved-aesthetics" (a subset of laion2B-en,
-filtered to images with an original size `>= 512x512`, estimated aesthetics score `> 5.0`, and an estimated watermark probability `< 0.5`. The watermark estimate is from the LAION-5B metadata, the aesthetics score is estimated using an [improved aesthetics estimator](https://github.com/christophschuhmann/improved-aesthetic-predictor)).
-- `sd-v1-3.ckpt`: Resumed from `sd-v1-2.ckpt`. 195k steps at resolution `512x512` on "laion-improved-aesthetics" and 10\% dropping of the text-conditioning to improve [classifier-free guidance sampling](https://arxiv.org/abs/2207.12598).
-
-Evaluations with different classifier-free guidance scales (1.5, 2.0, 3.0, 4.0,
-5.0, 6.0, 7.0, 8.0) and 50 PLMS sampling
-steps show the relative improvements of the checkpoints:
-
-
-
-
-### Text-to-Image with Stable Diffusion
-
-
-
-Stable Diffusion is a latent diffusion model conditioned on the (non-pooled) text embeddings of a CLIP ViT-L/14 text encoder.
-
-After [obtaining the weights](#weights), link them
-```
-mkdir -p models/ldm/stable-diffusion-v1/
-ln -s models/ldm/stable-diffusion-v1/model.ckpt
-```
-and sample with
-```
-python scripts/txt2img.py --prompt "a photograph of an astronaut riding a horse" --plms
-```
-By default, this uses a guidance scale of `--scale 7.5`, [Katherine Crowson's implementation](https://github.com/CompVis/latent-diffusion/pull/51) of the [PLMS](https://arxiv.org/abs/2202.09778) sampler,
-and renders images of size 512x512 (which it was trained on) in 50 steps. All supported arguments are listed below (type `python scripts/txt2img.py --help`).
-
-```commandline
-usage: txt2img.py [-h] [--prompt [PROMPT]] [--outdir [OUTDIR]] [--skip_grid] [--skip_save] [--ddim_steps DDIM_STEPS] [--plms] [--laion400m] [--fixed_code] [--ddim_eta DDIM_ETA] [--n_iter N_ITER] [--H H] [--W W] [--C C] [--f F] [--n_samples N_SAMPLES] [--n_rows N_ROWS]
- [--scale SCALE] [--from-file FROM_FILE] [--config CONFIG] [--ckpt CKPT] [--seed SEED] [--precision {full,autocast}]
-
-optional arguments:
- -h, --help show this help message and exit
- --prompt [PROMPT] the prompt to render
- --outdir [OUTDIR] dir to write results to
- --skip_grid do not save a grid, only individual samples. Helpful when evaluating lots of samples
- --skip_save do not save individual samples. For speed measurements.
- --ddim_steps DDIM_STEPS
- number of ddim sampling steps
- --plms use plms sampling
- --laion400m uses the LAION400M model
- --fixed_code if enabled, uses the same starting code across samples
- --ddim_eta DDIM_ETA ddim eta (eta=0.0 corresponds to deterministic sampling
- --n_iter N_ITER sample this often
- --H H image height, in pixel space
- --W W image width, in pixel space
- --C C latent channels
- --f F downsampling factor
- --n_samples N_SAMPLES
- how many samples to produce for each given prompt. A.k.a. batch size
- --n_rows N_ROWS rows in the grid (default: n_samples)
- --scale SCALE unconditional guidance scale: eps = eps(x, empty) + scale * (eps(x, cond) - eps(x, empty))
- --from-file FROM_FILE
- if specified, load prompts from this file
- --config CONFIG path to config which constructs model
- --ckpt CKPT path to checkpoint of model
- --seed SEED the seed (for reproducible sampling)
- --precision {full,autocast}
- evaluate at this precision
+# Optimized Stable Diffusion
+
+
+
+
+
-```
-Note: The inference config for all v1 versions is designed to be used with EMA-only checkpoints.
-For this reason `use_ema=False` is set in the configuration, otherwise the code will try to switch from
-non-EMA to EMA weights. If you want to examine the effect of EMA vs no EMA, we provide "full" checkpoints
-which contain both types of weights. For these, `use_ema=False` will load and use the non-EMA weights.
+This repo is a modified version of the Stable Diffusion repo, optimized to use less VRAM than the original by sacrificing inference speed.
+To reduce the VRAM usage, the following optimizations are used:
-### Image Modification with Stable Diffusion
-
-By using a diffusion-denoising mechanism as first proposed by [SDEdit](https://arxiv.org/abs/2108.01073), the model can be used for different
-tasks such as text-guided image-to-image translation and upscaling. Similar to the txt2img sampling script,
-we provide a script to perform image modification with Stable Diffusion.
+- The Stable Diffusion model is split into four parts which are sent to the GPU only when needed. Once a part finishes its computation, it is moved back to the CPU (a rough sketch of this idea is shown below).
+- The attention calculation is done in parts.
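+
+As a rough sketch (not the exact code in `optimizedSD/ddpm.py`), the offloading idea looks like this: each fragment is moved to the GPU just before its forward pass and back to the CPU afterwards, so only the active part occupies VRAM.
+
+```python
+import torch
+
+def run_fragment(fragment: torch.nn.Module, *inputs):
+    """Illustrative only: run one model fragment on the GPU, then free its VRAM."""
+    fragment.to("cuda")            # load this part onto the GPU
+    with torch.no_grad():
+        out = fragment(*inputs)    # do the actual computation
+    fragment.to("cpu")             # move the weights back to system RAM
+    return out
+```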
-The following describes an example where a rough sketch made in [Pinta](https://www.pinta-project.com/) is converted into a detailed artwork.
-```
-python scripts/img2img.py --prompt "A fantasy landscape, trending on artstation" --init-img --strength 0.8
-```
-Here, strength is a value between 0.0 and 1.0, that controls the amount of noise that is added to the input image.
-Values that approach 1.0 allow for lots of variations but will also produce images that are not semantically consistent with the input. See the following example.
-
-**Input**
+# Installation
-
-
-**Outputs**
+All the modified files are in the [optimizedSD](optimizedSD) folder, so if you have already cloned the original repository, you can simply download this folder and copy it into the original instead of cloning the entire repo. You can also clone this repo and follow the same installation steps as the original (mainly creating the conda environment and placing the weights at the specified location).
-
-
-
-This procedure can, for example, also be used to upscale samples from the base model.
-
-
-## Comments
-
-- Our codebase for the diffusion models builds heavily on [OpenAI's ADM codebase](https://github.com/openai/guided-diffusion)
-and [https://github.com/lucidrains/denoising-diffusion-pytorch](https://github.com/lucidrains/denoising-diffusion-pytorch).
-Thanks for open-sourcing!
-
-- The implementation of the transformer encoder is from [x-transformers](https://github.com/lucidrains/x-transformers) by [lucidrains](https://github.com/lucidrains?tab=repositories).
-
-
-## BibTeX
-
-```
-@misc{rombach2021highresolution,
- title={High-Resolution Image Synthesis with Latent Diffusion Models},
- author={Robin Rombach and Andreas Blattmann and Dominik Lorenz and Patrick Esser and Björn Ommer},
- year={2021},
- eprint={2112.10752},
- archivePrefix={arXiv},
- primaryClass={cs.CV}
-}
-
-```
+Alternatively, if you prefer to use Docker, you can do the following:
+1. Install [Docker](https://docs.docker.com/engine/install/), [Docker Compose plugin](https://docs.docker.com/compose/install/), and [NVIDIA Container Toolkit](https://docs.nvidia.com/datacenter/cloud-native/container-toolkit/install-guide.html#docker)
+2. Clone this repo to, e.g., `~/stable-diffusion`
+3. Put your downloaded `model.ckpt` file into `~/sd-data` (the compose file mounts this directory via the relative path `../sd-data`; you can change it in `docker-compose.yml`)
+4. `cd` into `~/stable-diffusion` and execute `docker compose up --build`
+This will launch the Gradio txt2img interface on port 7860. You can also use `docker compose run` to execute other Python scripts, as shown below.
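+
+For example, to run the command-line txt2img script inside the same container (the service is named `sd` in `docker-compose.yml`; the flags are illustrative):
+
+`docker compose run --rm sd python optimizedSD/optimized_txt2img.py --prompt "Cyberpunk style image of a Tesla car reflection in rain" --H 512 --W 512 --n_samples 1`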
+
+# Usage
+
+## img2img
+
+- `img2img` can generate _512x512 images from a prior image and prompt using under 2.4GB VRAM in under 20 seconds per image_ on an RTX 2060.
+
+- The maximum size that fits on a 6GB GPU (RTX 2060) is around 1152x1088.
+
+- For example, the following command will generate 10 512x512 images:
+
+`python optimizedSD/optimized_img2img.py --prompt "Austrian alps" --init-img ~/sketch-mountains-input.jpg --strength 0.8 --n_iter 2 --n_samples 5 --H 512 --W 512`
+
+## txt2img
+
+- `txt2img` can generate _512x512 images from a prompt using under 2.4GB GPU VRAM in under 24 seconds per image_ on an RTX 2060.
+
+- For example, the following command will generate 10 512x512 images:
+
+`python optimizedSD/optimized_txt2img.py --prompt "Cyberpunk style image of a Tesla car reflection in rain" --H 512 --W 512 --seed 27 --n_iter 2 --n_samples 5 --ddim_steps 50`
+
+## inpainting
+
+- `inpaint_gradio.py` can fill masked parts of an image based on a given prompt. It can inpaint 512x512 images while using under 2.5GB of VRAM.
+
+- To launch the gradio interface for inpainting, run `python optimizedSD/inpaint_gradio.py`. The mask for the image can be drawn on the selected image using the brush tool.
+
+- The results are not yet perfect but can be improved by using a combination of prompt weighting, prompt engineering and testing out multiple values of the `--strength` argument.
+
+- _Suggestions to improve the inpainting algorithm are most welcome_.
+
+# Using the Gradio GUI
+
+- You can also use the built-in Gradio interface for `img2img`, `txt2img` & `inpainting` instead of the command-line interface. Activate the conda environment and install the latest version of Gradio with `pip install gradio`.
+
+- Run img2img using `python optimizedSD/img2img_gradio.py`, txt2img using `python optimizedSD/txt2img_gradio.py` and inpainting using `python optimizedSD/inpaint_gradio.py`.
+
+- `img2img_gradio.py` has a feature to crop input images. Look for the pen symbol in the image box after selecting an image.
+
+# Arguments
+
+## `--seed`
+
+**Seed for image generation.** Can be used to reproduce previously generated images; defaults to a random seed if unspecified.
+
+- The code reports the seed number for each generated image. To generate the same image again, just specify that seed with the `--seed` argument. By default, images are saved with their seed number in the filename.
+
+- For example, if the seed for an image is `1234` and it is the 55th image in the folder, it will be named `seed_1234_00055.png`.
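+
+- For example, to reproduce an image whose reported seed was `1234`, re-run the same prompt and settings with that seed (the prompt here is illustrative): `python optimizedSD/optimized_txt2img.py --prompt "a photograph of an astronaut riding a horse" --seed 1234 --n_samples 1 --n_iter 1`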
+
+## `--n_samples`
+
+**Batch size: the number of images to generate at once.**
+
+- To get the lowest inference time per image, use the largest batch size (`--n_samples`) that fits on the GPU. Inference time per image decreases as the batch size increases, but the required VRAM increases as well.
+
+- If you get a CUDA out-of-memory error, try reducing the batch size `--n_samples`. If that doesn't help, reduce the image width `--W`, the height `--H`, or both.
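+
+- For example, the following renders 8 images as two batches of 4 (if a batch of 4 fits in VRAM): `python optimizedSD/optimized_txt2img.py --prompt "Cyberpunk style image of a Tesla car reflection in rain" --n_samples 4 --n_iter 2 --H 512 --W 512`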
+
+## `--n_iter`
+
+**Run the sampling _x_ times.**
+
+- Equivalent to running the script `n_iter` times, except that the model is loaded only once for all iterations. Unlike `n_samples`, changing it does not affect the required VRAM or the inference time per image.
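+
+- For example, the following produces 8 images in four batches of 2, loading the model only once: `python optimizedSD/optimized_txt2img.py --prompt "Cyberpunk style image of a Tesla car reflection in rain" --n_samples 2 --n_iter 4`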
+
+## `--H` & `--W`
+
+**Height & width of the generated image.**
+
+- Both height and width should be multiples of 64.
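+
+- For example, a landscape-format render (both dimensions are multiples of 64): `python optimizedSD/optimized_txt2img.py --prompt "a photograph of an astronaut riding a horse" --H 512 --W 768`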
+
+## `--turbo`
+
+**Increases inference speed at the cost of extra VRAM usage.**
+
+- Using this argument increases inference speed at the cost of around 700MB of extra GPU VRAM. It is especially effective when generating a small batch of images (~1 to 4). It takes under 20 seconds for txt2img and under 15 seconds for img2img (on an RTX 2060, excluding the time to load the model). Use it with larger batch sizes if enough GPU VRAM is available.
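+
+- For example: `python optimizedSD/optimized_txt2img.py --prompt "a photograph of an astronaut riding a horse" --turbo --n_samples 2`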
+
+## `--precision autocast` or `--precision full`
+
+**Whether to use `full` or `mixed` precision**
+
+- Mixed precision is enabled by default. If your GPU does not have tensor cores (e.g., any GTX 10-series card), you may not be able to use mixed precision. Use the `--precision full` argument to disable it.
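+
+- For example, on a GTX 10-series card: `python optimizedSD/optimized_txt2img.py --prompt "a photograph of an astronaut riding a horse" --precision full`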
+
+## `--format png` or `--format jpg`
+
+**Output image format**
+
+- The default output format is `png`. While `png` is lossless, it takes up a lot of space (unless large portions of the image happen to be a single colour). Use lossy `jpg` to get smaller image file sizes.
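+
+- For example, to save disk space: `python optimizedSD/optimized_txt2img.py --prompt "Cyberpunk style image of a Tesla car reflection in rain" --format jpg`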
+
+## `--unet_bs`
+
+**Batch size for the unet model**
+
+- Takes up a lot of extra RAM for **very little improvement** in inference time. `unet_bs` > 1 is not recommended!
+
+- Should generally be a multiple of 2x(n_samples)
+
+# Weighted Prompts
+
+- Prompts can also be weighted to put relative emphasis on certain words,
+ e.g. `--prompt "tabby cat:0.25 white duck:0.75 hybrid"` (quote the prompt so the shell passes it as a single argument).
+
+- The number after the colon represents the weight given to the words before it. The weights can be either fractions or integers.
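+
+- For example, a full command (assuming the command-line scripts accept the same weighted syntax): `python optimizedSD/optimized_txt2img.py --prompt "tabby cat:0.25 white duck:0.75 hybrid" --n_samples 1`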
+
+# Troubleshooting
+
+## Green colored output images
+
+- If you have an Nvidia GTX-series GPU, the output images may be entirely green. This is because GTX-series cards do not support half-precision calculation, which is the default mode of calculation in this repository. To overcome the issue, use the `--precision full` argument. The downside is that it will lead to higher GPU VRAM usage.
+
+# Changelog
+
+- v1.0: Added support for multiple samplers for txt2img, based on [crowsonkb's k-diffusion](https://github.com/crowsonkb/k-diffusion).
+- v0.9: Added support for calculating attention in parts. (Thanks to @neonsecret, @Doggettx and @ryudrigo)
+- v0.8: Added a Gradio interface for inpainting.
+- v0.7: Added support for logging and the jpg file format.
+- v0.6: Added support for weighted prompts (based on @lstein's [repo](https://github.com/lstein/stable-diffusion)).
+- v0.5: Added support for using the Gradio interface.
+- v0.4: Added support for specifying the image seed.
+- v0.3: Added support for using mixed precision.
+- v0.2: Added support for generating images in batches.
+- v0.1: Split the model into multiple parts to run it on lower VRAM.
diff --git a/docker-bootstrap.sh b/docker-bootstrap.sh
new file mode 100755
index 000000000..d908d130d
--- /dev/null
+++ b/docker-bootstrap.sh
@@ -0,0 +1,6 @@
+#!/bin/bash
+set -e
+source /venv/bin/activate
+update-ca-certificates --fresh
+export SSL_CERT_DIR=/etc/ssl/certs
+exec "$@"
\ No newline at end of file
diff --git a/docker-compose.yml b/docker-compose.yml
new file mode 100644
index 000000000..66d76a264
--- /dev/null
+++ b/docker-compose.yml
@@ -0,0 +1,21 @@
+version: "3.9"
+services:
+ sd:
+ build: .
+ ports:
+ - "7860:7860"
+ volumes:
+ - ../sd-data:/data
+ - ../sd-output:/output
+ - sd-cache:/root/.cache
+ environment:
+ - APP_MAIN_FILE=txt2img_gradio.py
+ deploy:
+ resources:
+ reservations:
+ devices:
+ - driver: nvidia
+ count: 1
+ capabilities: [gpu]
+volumes:
+ sd-cache:
diff --git a/environment.yaml b/environment.yaml
index 7f25da800..f41c3cada 100644
--- a/environment.yaml
+++ b/environment.yaml
@@ -8,7 +8,7 @@ dependencies:
- cudatoolkit=11.3
- pytorch=1.11.0
- torchvision=0.12.0
- - numpy=1.19.2
+ - numpy=1.20.3
- pip:
- albumentations==0.4.3
- opencv-python==4.1.2.30
diff --git a/ldm/modules/diffusionmodules/model.py b/ldm/modules/diffusionmodules/model.py
index 533e589a2..cdddea776 100644
--- a/ldm/modules/diffusionmodules/model.py
+++ b/ldm/modules/diffusionmodules/model.py
@@ -2,6 +2,7 @@
import math
import torch
import torch.nn as nn
+from torch.nn.functional import silu
import numpy as np
from einops import rearrange
@@ -30,11 +31,6 @@ def get_timestep_embedding(timesteps, embedding_dim):
return emb
-def nonlinearity(x):
- # swish
- return x*torch.sigmoid(x)
-
-
def Normalize(in_channels, num_groups=32):
return torch.nn.GroupNorm(num_groups=num_groups, num_channels=in_channels, eps=1e-6, affine=True)
@@ -121,14 +117,14 @@ def __init__(self, *, in_channels, out_channels=None, conv_shortcut=False,
def forward(self, x, temb):
h = x
h = self.norm1(h)
- h = nonlinearity(h)
+ h = silu(h)
h = self.conv1(h)
if temb is not None:
- h = h + self.temb_proj(nonlinearity(temb))[:,:,None,None]
+ h = h + self.temb_proj(silu(temb))[:,:,None,None]
h = self.norm2(h)
- h = nonlinearity(h)
+ h = silu(h)
h = self.dropout(h)
h = self.conv2(h)
@@ -323,7 +319,7 @@ def forward(self, x, t=None, context=None):
assert t is not None
temb = get_timestep_embedding(t, self.ch)
temb = self.temb.dense[0](temb)
- temb = nonlinearity(temb)
+ temb = silu(temb)
temb = self.temb.dense[1](temb)
else:
temb = None
@@ -357,7 +353,7 @@ def forward(self, x, t=None, context=None):
# end
h = self.norm_out(h)
- h = nonlinearity(h)
+ h = silu(h)
h = self.conv_out(h)
return h
@@ -454,7 +450,7 @@ def forward(self, x):
# end
h = self.norm_out(h)
- h = nonlinearity(h)
+ h = silu(h)
h = self.conv_out(h)
return h
@@ -561,7 +557,7 @@ def forward(self, z):
return h
h = self.norm_out(h)
- h = nonlinearity(h)
+ h = silu(h)
h = self.conv_out(h)
if self.tanh_out:
h = torch.tanh(h)
@@ -599,7 +595,7 @@ def forward(self, x):
x = layer(x)
h = self.norm_out(x)
- h = nonlinearity(h)
+ h = silu(h)
x = self.conv_out(h)
return x
@@ -647,7 +643,7 @@ def forward(self, x):
if i_level != self.num_resolutions - 1:
h = self.upsample_blocks[k](h)
h = self.norm_out(h)
- h = nonlinearity(h)
+ h = silu(h)
h = self.conv_out(h)
return h
@@ -823,7 +819,7 @@ def forward(self,x):
z_fs = self.encode_with_pretrained(x)
z = self.proj_norm(z_fs)
z = self.proj(z)
- z = nonlinearity(z)
+ z = silu(z)
for submodel, downmodel in zip(self.model,self.downsampler):
z = submodel(z,temb=None)
diff --git a/optimizedSD/LICENSE b/optimizedSD/LICENSE
new file mode 100644
index 000000000..a91cdd652
--- /dev/null
+++ b/optimizedSD/LICENSE
@@ -0,0 +1,80 @@
+Copyright (c) 2022 Robin Rombach and Patrick Esser and contributors
+
+CreativeML Open RAIL-M
+dated August 22, 2022
+
+Section I: PREAMBLE
+
+Multimodal generative models are being widely adopted and used, and have the potential to transform the way artists, among other individuals, conceive and benefit from AI or ML technologies as a tool for content creation.
+
+Notwithstanding the current and potential benefits that these artifacts can bring to society at large, there are also concerns about potential misuses of them, either due to their technical limitations or ethical considerations.
+
+In short, this license strives for both the open and responsible downstream use of the accompanying model. When it comes to the open character, we took inspiration from open source permissive licenses regarding the grant of IP rights. Referring to the downstream responsible use, we added use-based restrictions not permitting the use of the Model in very specific scenarios, in order for the licensor to be able to enforce the license in case potential misuses of the Model may occur. At the same time, we strive to promote open and responsible research on generative models for art and content generation.
+
+Even though downstream derivative versions of the model could be released under different licensing terms, the latter will always have to include - at minimum - the same use-based restrictions as the ones in the original license (this license). We believe in the intersection between open and responsible AI development; thus, this License aims to strike a balance between both in order to enable responsible open-science in the field of AI.
+
+This License governs the use of the model (and its derivatives) and is informed by the model card associated with the model.
+
+NOW THEREFORE, You and Licensor agree as follows:
+
+1. Definitions
+
+- "License" means the terms and conditions for use, reproduction, and Distribution as defined in this document.
+- "Data" means a collection of information and/or content extracted from the dataset used with the Model, including to train, pretrain, or otherwise evaluate the Model. The Data is not licensed under this License.
+- "Output" means the results of operating a Model as embodied in informational content resulting therefrom.
+- "Model" means any accompanying machine-learning based assemblies (including checkpoints), consisting of learnt weights, parameters (including optimizer states), corresponding to the model architecture as embodied in the Complementary Material, that have been trained or tuned, in whole or in part on the Data, using the Complementary Material.
+- "Derivatives of the Model" means all modifications to the Model, works based on the Model, or any other model which is created or initialized by transfer of patterns of the weights, parameters, activations or output of the Model, to the other model, in order to cause the other model to perform similarly to the Model, including - but not limited to - distillation methods entailing the use of intermediate data representations or methods based on the generation of synthetic data by the Model for training the other model.
+- "Complementary Material" means the accompanying source code and scripts used to define, run, load, benchmark or evaluate the Model, and used to prepare data for training or evaluation, if any. This includes any accompanying documentation, tutorials, examples, etc, if any.
+- "Distribution" means any transmission, reproduction, publication or other sharing of the Model or Derivatives of the Model to a third party, including providing the Model as a hosted service made available by electronic or other remote means - e.g. API-based or web access.
+- "Licensor" means the copyright owner or entity authorized by the copyright owner that is granting the License, including the persons or entities that may have rights in the Model and/or distributing the Model.
+- "You" (or "Your") means an individual or Legal Entity exercising permissions granted by this License and/or making use of the Model for whichever purpose and in any field of use, including usage of the Model in an end-use application - e.g. chatbot, translator, image generator.
+- "Third Parties" means individuals or legal entities that are not under common control with Licensor or You.
+- "Contribution" means any work of authorship, including the original version of the Model and any modifications or additions to that Model or Derivatives of the Model thereof, that is intentionally submitted to Licensor for inclusion in the Model by the copyright owner or by an individual or Legal Entity authorized to submit on behalf of the copyright owner. For the purposes of this definition, "submitted" means any form of electronic, verbal, or written communication sent to the Licensor or its representatives, including but not limited to communication on electronic mailing lists, source code control systems, and issue tracking systems that are managed by, or on behalf of, the Licensor for the purpose of discussing and improving the Model, but excluding communication that is conspicuously marked or otherwise designated in writing by the copyright owner as "Not a Contribution."
+- "Contributor" means Licensor and any individual or Legal Entity on behalf of whom a Contribution has been received by Licensor and subsequently incorporated within the Model.
+
+Section II: INTELLECTUAL PROPERTY RIGHTS
+
+Both copyright and patent grants apply to the Model, Derivatives of the Model and Complementary Material. The Model and Derivatives of the Model are subject to additional terms as described in Section III.
+
+2. Grant of Copyright License. Subject to the terms and conditions of this License, each Contributor hereby grants to You a perpetual, worldwide, non-exclusive, no-charge, royalty-free, irrevocable copyright license to reproduce, prepare, publicly display, publicly perform, sublicense, and distribute the Complementary Material, the Model, and Derivatives of the Model.
+3. Grant of Patent License. Subject to the terms and conditions of this License and where and as applicable, each Contributor hereby grants to You a perpetual, worldwide, non-exclusive, no-charge, royalty-free, irrevocable (except as stated in this paragraph) patent license to make, have made, use, offer to sell, sell, import, and otherwise transfer the Model and the Complementary Material, where such license applies only to those patent claims licensable by such Contributor that are necessarily infringed by their Contribution(s) alone or by combination of their Contribution(s) with the Model to which such Contribution(s) was submitted. If You institute patent litigation against any entity (including a cross-claim or counterclaim in a lawsuit) alleging that the Model and/or Complementary Material or a Contribution incorporated within the Model and/or Complementary Material constitutes direct or contributory patent infringement, then any patent licenses granted to You under this License for the Model and/or Work shall terminate as of the date such litigation is asserted or filed.
+
+Section III: CONDITIONS OF USAGE, DISTRIBUTION AND REDISTRIBUTION
+
+4. Distribution and Redistribution. You may host for Third Party remote access purposes (e.g. software-as-a-service), reproduce and distribute copies of the Model or Derivatives of the Model thereof in any medium, with or without modifications, provided that You meet the following conditions:
+ Use-based restrictions as referenced in paragraph 5 MUST be included as an enforceable provision by You in any type of legal agreement (e.g. a license) governing the use and/or distribution of the Model or Derivatives of the Model, and You shall give notice to subsequent users You Distribute to, that the Model or Derivatives of the Model are subject to paragraph 5. This provision does not apply to the use of Complementary Material.
+ You must give any Third Party recipients of the Model or Derivatives of the Model a copy of this License;
+ You must cause any modified files to carry prominent notices stating that You changed the files;
+ You must retain all copyright, patent, trademark, and attribution notices excluding those notices that do not pertain to any part of the Model, Derivatives of the Model.
+ You may add Your own copyright statement to Your modifications and may provide additional or different license terms and conditions - respecting paragraph 4.a. - for use, reproduction, or Distribution of Your modifications, or for any such Derivatives of the Model as a whole, provided Your use, reproduction, and Distribution of the Model otherwise complies with the conditions stated in this License.
+5. Use-based restrictions. The restrictions set forth in Attachment A are considered Use-based restrictions. Therefore You cannot use the Model and the Derivatives of the Model for the specified restricted uses. You may use the Model subject to this License, including only for lawful purposes and in accordance with the License. Use may include creating any content with, finetuning, updating, running, training, evaluating and/or reparametrizing the Model. You shall require all of Your users who use the Model or a Derivative of the Model to comply with the terms of this paragraph (paragraph 5).
+6. The Output You Generate. Except as set forth herein, Licensor claims no rights in the Output You generate using the Model. You are accountable for the Output you generate and its subsequent uses. No use of the output can contravene any provision as stated in the License.
+
+Section IV: OTHER PROVISIONS
+
+7. Updates and Runtime Restrictions. To the maximum extent permitted by law, Licensor reserves the right to restrict (remotely or otherwise) usage of the Model in violation of this License, update the Model through electronic means, or modify the Output of the Model based on updates. You shall undertake reasonable efforts to use the latest version of the Model.
+8. Trademarks and related. Nothing in this License permits You to make use of Licensors’ trademarks, trade names, logos or to otherwise suggest endorsement or misrepresent the relationship between the parties; and any rights not expressly granted herein are reserved by the Licensors.
+9. Disclaimer of Warranty. Unless required by applicable law or agreed to in writing, Licensor provides the Model and the Complementary Material (and each Contributor provides its Contributions) on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied, including, without limitation, any warranties or conditions of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A PARTICULAR PURPOSE. You are solely responsible for determining the appropriateness of using or redistributing the Model, Derivatives of the Model, and the Complementary Material and assume any risks associated with Your exercise of permissions under this License.
+10. Limitation of Liability. In no event and under no legal theory, whether in tort (including negligence), contract, or otherwise, unless required by applicable law (such as deliberate and grossly negligent acts) or agreed to in writing, shall any Contributor be liable to You for damages, including any direct, indirect, special, incidental, or consequential damages of any character arising as a result of this License or out of the use or inability to use the Model and the Complementary Material (including but not limited to damages for loss of goodwill, work stoppage, computer failure or malfunction, or any and all other commercial damages or losses), even if such Contributor has been advised of the possibility of such damages.
+11. Accepting Warranty or Additional Liability. While redistributing the Model, Derivatives of the Model and the Complementary Material thereof, You may choose to offer, and charge a fee for, acceptance of support, warranty, indemnity, or other liability obligations and/or rights consistent with this License. However, in accepting such obligations, You may act only on Your own behalf and on Your sole responsibility, not on behalf of any other Contributor, and only if You agree to indemnify, defend, and hold each Contributor harmless for any liability incurred by, or claims asserted against, such Contributor by reason of your accepting any such warranty or additional liability.
+12. If any provision of this License is held to be invalid, illegal or unenforceable, the remaining provisions shall be unaffected thereby and remain valid as if such provision had not been set forth herein.
+
+END OF TERMS AND CONDITIONS
+
+Attachment A
+
+Use Restrictions
+
+You agree not to use the Model or Derivatives of the Model:
+
+- In any way that violates any applicable national, federal, state, local or international law or regulation;
+- For the purpose of exploiting, harming or attempting to exploit or harm minors in any way;
+- To generate or disseminate verifiably false information and/or content with the purpose of harming others;
+- To generate or disseminate personal identifiable information that can be used to harm an individual;
+- To defame, disparage or otherwise harass others;
+- For fully automated decision making that adversely impacts an individual’s legal rights or otherwise creates or modifies a binding, enforceable obligation;
+- For any use intended to or which has the effect of discriminating against or harming individuals or groups based on online or offline social behavior or known or predicted personal or personality characteristics;
+- To exploit any of the vulnerabilities of a specific group of persons based on their age, social, physical or mental characteristics, in order to materially distort the behavior of a person pertaining to that group in a manner that causes or is likely to cause that person or another person physical or psychological harm;
+- For any use intended to or which has the effect of discriminating against individuals or groups based on legally protected characteristics or categories;
+- To provide medical advice and medical results interpretation;
+- To generate or disseminate information for the purpose to be used for administration of justice, law enforcement, immigration or asylum processes, such as predicting an individual will commit fraud/crime commitment (e.g. by text profiling, drawing causal relationships between assertions made in documents, indiscriminate and arbitrarily-targeted use).
diff --git a/optimizedSD/ddpm.py b/optimizedSD/ddpm.py
new file mode 100644
index 000000000..b967b55ee
--- /dev/null
+++ b/optimizedSD/ddpm.py
@@ -0,0 +1,1030 @@
+"""
+wild mixture of
+https://github.com/lucidrains/denoising-diffusion-pytorch/blob/7706bdfc6f527f58d33f84b7b522e61e6e3164b3/denoising_diffusion_pytorch/denoising_diffusion_pytorch.py
+https://github.com/openai/improved-diffusion/blob/e94489283bb876ac1477d5dd7709bbbd2d9902ce/improved_diffusion/gaussian_diffusion.py
+https://github.com/CompVis/taming-transformers
+-- merci
+"""
+
+import time, math
+from tqdm.auto import trange, tqdm
+import torch
+from einops import rearrange
+from tqdm import tqdm
+from ldm.modules.distributions.distributions import DiagonalGaussianDistribution
+from ldm.models.autoencoder import VQModelInterface
+import torch.nn as nn
+import numpy as np
+import pytorch_lightning as pl
+from functools import partial
+from pytorch_lightning.utilities.distributed import rank_zero_only
+from ldm.util import exists, default, instantiate_from_config
+from ldm.modules.diffusionmodules.util import make_beta_schedule, extract_into_tensor, noise_like
+from ldm.modules.diffusionmodules.util import make_ddim_sampling_parameters, make_ddim_timesteps
+from samplers import CompVisDenoiser, get_ancestral_step, to_d, append_dims, linear_multistep_coeff
+
+def disabled_train(self):
+ """Overwrite model.train with this function to make sure train/eval mode
+ does not change anymore."""
+ return self
+
+
+class DDPM(pl.LightningModule):
+ # classic DDPM with Gaussian diffusion, in image space
+ def __init__(self,
+ timesteps=1000,
+ beta_schedule="linear",
+ ckpt_path=None,
+ ignore_keys=[],
+ load_only_unet=False,
+ monitor="val/loss",
+ use_ema=True,
+ first_stage_key="image",
+ image_size=256,
+ channels=3,
+ log_every_t=100,
+ clip_denoised=True,
+ linear_start=1e-4,
+ linear_end=2e-2,
+ cosine_s=8e-3,
+ given_betas=None,
+ original_elbo_weight=0.,
+ v_posterior=0., # weight for choosing posterior variance as sigma = (1-v) * beta_tilde + v * beta
+ l_simple_weight=1.,
+ conditioning_key=None,
+ parameterization="eps", # all assuming fixed variance schedules
+ scheduler_config=None,
+ use_positional_encodings=False,
+ ):
+ super().__init__()
+ assert parameterization in ["eps", "x0"], 'currently only supporting "eps" and "x0"'
+ self.parameterization = parameterization
+ print(f"{self.__class__.__name__}: Running in {self.parameterization}-prediction mode")
+ self.cond_stage_model = None
+ self.clip_denoised = clip_denoised
+ self.log_every_t = log_every_t
+ self.first_stage_key = first_stage_key
+ self.image_size = image_size # try conv?
+ self.channels = channels
+ self.use_positional_encodings = use_positional_encodings
+ self.use_scheduler = scheduler_config is not None
+ if self.use_scheduler:
+ self.scheduler_config = scheduler_config
+
+ self.v_posterior = v_posterior
+ self.original_elbo_weight = original_elbo_weight
+ self.l_simple_weight = l_simple_weight
+
+ if monitor is not None:
+ self.monitor = monitor
+ if ckpt_path is not None:
+ self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys, only_model=load_only_unet)
+ self.register_schedule(given_betas=given_betas, beta_schedule=beta_schedule, timesteps=timesteps,
+ linear_start=linear_start, linear_end=linear_end, cosine_s=cosine_s)
+
+
+ def register_schedule(self, given_betas=None, beta_schedule="linear", timesteps=1000,
+ linear_start=1e-4, linear_end=2e-2, cosine_s=8e-3):
+ if exists(given_betas):
+ betas = given_betas
+ else:
+ betas = make_beta_schedule(beta_schedule, timesteps, linear_start=linear_start, linear_end=linear_end,
+ cosine_s=cosine_s)
+ alphas = 1. - betas
+ alphas_cumprod = np.cumprod(alphas, axis=0)
+
+ timesteps, = betas.shape
+ self.num_timesteps = int(timesteps)
+ self.linear_start = linear_start
+ self.linear_end = linear_end
+ assert alphas_cumprod.shape[0] == self.num_timesteps, 'alphas have to be defined for each timestep'
+
+ to_torch = partial(torch.tensor, dtype=torch.float32)
+
+ self.register_buffer('betas', to_torch(betas))
+ self.register_buffer('alphas_cumprod', to_torch(alphas_cumprod))
+
+
+class FirstStage(DDPM):
+ """main class"""
+ def __init__(self,
+ first_stage_config,
+ num_timesteps_cond=None,
+ cond_stage_key="image",
+ cond_stage_trainable=False,
+ concat_mode=True,
+ cond_stage_forward=None,
+ conditioning_key=None,
+ scale_factor=1.0,
+ scale_by_std=False,
+ *args, **kwargs):
+ self.num_timesteps_cond = default(num_timesteps_cond, 1)
+ self.scale_by_std = scale_by_std
+ assert self.num_timesteps_cond <= kwargs['timesteps']
+ # for backwards compatibility after implementation of DiffusionWrapper
+ if conditioning_key is None:
+ conditioning_key = 'concat' if concat_mode else 'crossattn'
+ ckpt_path = kwargs.pop("ckpt_path", None)
+ ignore_keys = kwargs.pop("ignore_keys", [])
+ super().__init__()
+ self.concat_mode = concat_mode
+ self.cond_stage_trainable = cond_stage_trainable
+ self.cond_stage_key = cond_stage_key
+ try:
+ self.num_downs = len(first_stage_config.params.ddconfig.ch_mult) - 1
+ except:
+ self.num_downs = 0
+ if not scale_by_std:
+ self.scale_factor = scale_factor
+ self.instantiate_first_stage(first_stage_config)
+ self.cond_stage_forward = cond_stage_forward
+ self.clip_denoised = False
+ self.bbox_tokenizer = None
+
+ self.restarted_from_ckpt = False
+ if ckpt_path is not None:
+ self.init_from_ckpt(ckpt_path, ignore_keys)
+ self.restarted_from_ckpt = True
+
+
+ def instantiate_first_stage(self, config):
+ model = instantiate_from_config(config)
+ self.first_stage_model = model.eval()
+ self.first_stage_model.train = disabled_train
+ for param in self.first_stage_model.parameters():
+ param.requires_grad = False
+
+ def get_first_stage_encoding(self, encoder_posterior):
+ if isinstance(encoder_posterior, DiagonalGaussianDistribution):
+ z = encoder_posterior.sample()
+ elif isinstance(encoder_posterior, torch.Tensor):
+ z = encoder_posterior
+ else:
+ raise NotImplementedError(f"encoder_posterior of type '{type(encoder_posterior)}' not yet implemented")
+ return self.scale_factor * z
+
+
+ @torch.no_grad()
+ def decode_first_stage(self, z, predict_cids=False, force_not_quantize=False):
+ if predict_cids:
+ if z.dim() == 4:
+ z = torch.argmax(z.exp(), dim=1).long()
+ z = self.first_stage_model.quantize.get_codebook_entry(z, shape=None)
+ z = rearrange(z, 'b h w c -> b c h w').contiguous()
+
+ z = 1. / self.scale_factor * z
+
+ if hasattr(self, "split_input_params"):
+ if isinstance(self.first_stage_model, VQModelInterface):
+ return self.first_stage_model.decode(z, force_not_quantize=predict_cids or force_not_quantize)
+ else:
+ return self.first_stage_model.decode(z)
+
+ else:
+ if isinstance(self.first_stage_model, VQModelInterface):
+ return self.first_stage_model.decode(z, force_not_quantize=predict_cids or force_not_quantize)
+ else:
+ return self.first_stage_model.decode(z)
+
+
+ @torch.no_grad()
+ def encode_first_stage(self, x):
+ if hasattr(self, "split_input_params"):
+ if self.split_input_params["patch_distributed_vq"]:
+ ks = self.split_input_params["ks"] # eg. (128, 128)
+ stride = self.split_input_params["stride"] # eg. (64, 64)
+ df = self.split_input_params["vqf"]
+ self.split_input_params['original_image_size'] = x.shape[-2:]
+ bs, nc, h, w = x.shape
+ if ks[0] > h or ks[1] > w:
+ ks = (min(ks[0], h), min(ks[1], w))
+ print("reducing Kernel")
+
+ if stride[0] > h or stride[1] > w:
+ stride = (min(stride[0], h), min(stride[1], w))
+ print("reducing stride")
+
+ fold, unfold, normalization, weighting = self.get_fold_unfold(x, ks, stride, df=df)
+ z = unfold(x) # (bn, nc * prod(**ks), L)
+ # Reshape to img shape
+ z = z.view((z.shape[0], -1, ks[0], ks[1], z.shape[-1])) # (bn, nc, ks[0], ks[1], L )
+
+ output_list = [self.first_stage_model.encode(z[:, :, :, :, i])
+ for i in range(z.shape[-1])]
+
+ o = torch.stack(output_list, axis=-1)
+ o = o * weighting
+
+ # Reverse reshape to img shape
+ o = o.view((o.shape[0], -1, o.shape[-1])) # (bn, nc * ks[0] * ks[1], L)
+ # stitch crops together
+ decoded = fold(o)
+ decoded = decoded / normalization
+ return decoded
+
+ else:
+ return self.first_stage_model.encode(x)
+ else:
+ return self.first_stage_model.encode(x)
+
+
+class CondStage(DDPM):
+ """main class"""
+ def __init__(self,
+ cond_stage_config,
+ num_timesteps_cond=None,
+ cond_stage_key="image",
+ cond_stage_trainable=False,
+ concat_mode=True,
+ cond_stage_forward=None,
+ conditioning_key=None,
+ scale_factor=1.0,
+ scale_by_std=False,
+ *args, **kwargs):
+ self.num_timesteps_cond = default(num_timesteps_cond, 1)
+ self.scale_by_std = scale_by_std
+ assert self.num_timesteps_cond <= kwargs['timesteps']
+ # for backwards compatibility after implementation of DiffusionWrapper
+ if conditioning_key is None:
+ conditioning_key = 'concat' if concat_mode else 'crossattn'
+ if cond_stage_config == '__is_unconditional__':
+ conditioning_key = None
+ ckpt_path = kwargs.pop("ckpt_path", None)
+ ignore_keys = kwargs.pop("ignore_keys", [])
+ super().__init__()
+ self.concat_mode = concat_mode
+ self.cond_stage_trainable = cond_stage_trainable
+ self.cond_stage_key = cond_stage_key
+ self.num_downs = 0
+ if not scale_by_std:
+ self.scale_factor = scale_factor
+ self.instantiate_cond_stage(cond_stage_config)
+ self.cond_stage_forward = cond_stage_forward
+ self.clip_denoised = False
+ self.bbox_tokenizer = None
+
+ self.restarted_from_ckpt = False
+ if ckpt_path is not None:
+ self.init_from_ckpt(ckpt_path, ignore_keys)
+ self.restarted_from_ckpt = True
+
+ def instantiate_cond_stage(self, config):
+ if not self.cond_stage_trainable:
+ if config == "__is_first_stage__":
+ print("Using first stage also as cond stage.")
+ self.cond_stage_model = self.first_stage_model
+ elif config == "__is_unconditional__":
+ print(f"Training {self.__class__.__name__} as an unconditional model.")
+ self.cond_stage_model = None
+ # self.be_unconditional = True
+ else:
+ model = instantiate_from_config(config)
+ self.cond_stage_model = model.eval()
+ self.cond_stage_model.train = disabled_train
+ for param in self.cond_stage_model.parameters():
+ param.requires_grad = False
+ else:
+ assert config != '__is_first_stage__'
+ assert config != '__is_unconditional__'
+ model = instantiate_from_config(config)
+ self.cond_stage_model = model
+
+ def get_learned_conditioning(self, c):
+ if self.cond_stage_forward is None:
+ if hasattr(self.cond_stage_model, 'encode') and callable(self.cond_stage_model.encode):
+ c = self.cond_stage_model.encode(c)
+ if isinstance(c, DiagonalGaussianDistribution):
+ c = c.mode()
+ else:
+ c = self.cond_stage_model(c)
+ else:
+ assert hasattr(self.cond_stage_model, self.cond_stage_forward)
+ c = getattr(self.cond_stage_model, self.cond_stage_forward)(c)
+ return c
+
+class DiffusionWrapper(pl.LightningModule):
+ def __init__(self, diff_model_config):
+ super().__init__()
+ self.diffusion_model = instantiate_from_config(diff_model_config)
+
+ def forward(self, x, t, cc):
+ out = self.diffusion_model(x, t, context=cc)
+ return out
+
+class DiffusionWrapperOut(pl.LightningModule):
+ def __init__(self, diff_model_config):
+ super().__init__()
+ self.diffusion_model = instantiate_from_config(diff_model_config)
+
+ def forward(self, h,emb,tp,hs, cc):
+ return self.diffusion_model(h,emb,tp,hs, context=cc)
+
+
+class UNet(DDPM):
+ """main class"""
+ def __init__(self,
+ unetConfigEncode,
+ unetConfigDecode,
+ num_timesteps_cond=None,
+ cond_stage_key="image",
+ cond_stage_trainable=False,
+ concat_mode=True,
+ cond_stage_forward=None,
+ conditioning_key=None,
+ scale_factor=1.0,
+ unet_bs = 1,
+ scale_by_std=False,
+ *args, **kwargs):
+ self.num_timesteps_cond = default(num_timesteps_cond, 1)
+ self.scale_by_std = scale_by_std
+ assert self.num_timesteps_cond <= kwargs['timesteps']
+ # for backwards compatibility after implementation of DiffusionWrapper
+ if conditioning_key is None:
+ conditioning_key = 'concat' if concat_mode else 'crossattn'
+ ckpt_path = kwargs.pop("ckpt_path", None)
+ ignore_keys = kwargs.pop("ignore_keys", [])
+ super().__init__(conditioning_key=conditioning_key, *args, **kwargs)
+ self.concat_mode = concat_mode
+ self.cond_stage_trainable = cond_stage_trainable
+ self.cond_stage_key = cond_stage_key
+ self.num_downs = 0
+ self.cdevice = "cuda"
+ self.unetConfigEncode = unetConfigEncode
+ self.unetConfigDecode = unetConfigDecode
+ if not scale_by_std:
+ self.scale_factor = scale_factor
+ else:
+ self.register_buffer('scale_factor', torch.tensor(scale_factor))
+ self.cond_stage_forward = cond_stage_forward
+ self.clip_denoised = False
+ self.bbox_tokenizer = None
+ self.model1 = DiffusionWrapper(self.unetConfigEncode)
+ self.model2 = DiffusionWrapperOut(self.unetConfigDecode)
+ self.model1.eval()
+ self.model2.eval()
+ self.turbo = False
+ self.unet_bs = unet_bs
+ self.restarted_from_ckpt = False
+ if ckpt_path is not None:
+ self.init_from_ckpt(ckpt_path, ignore_keys)
+ self.restarted_from_ckpt = True
+
+ def make_cond_schedule(self, ):
+ self.cond_ids = torch.full(size=(self.num_timesteps,), fill_value=self.num_timesteps - 1, dtype=torch.long)
+ ids = torch.round(torch.linspace(0, self.num_timesteps - 1, self.num_timesteps_cond)).long()
+ self.cond_ids[:self.num_timesteps_cond] = ids
+
+ @rank_zero_only
+ @torch.no_grad()
+ def on_train_batch_start(self, batch, batch_idx):
+ # only for very first batch
+ if self.scale_by_std and self.current_epoch == 0 and self.global_step == 0 and batch_idx == 0 and not self.restarted_from_ckpt:
+ assert self.scale_factor == 1., 'rather not use custom rescaling and std-rescaling simultaneously'
+ # set rescale weight to 1./std of encodings
+ print("### USING STD-RESCALING ###")
+ x = super().get_input(batch, self.first_stage_key)
+ x = x.to(self.cdevice)
+ encoder_posterior = self.encode_first_stage(x)
+ z = self.get_first_stage_encoding(encoder_posterior).detach()
+ del self.scale_factor
+ self.register_buffer('scale_factor', 1. / z.flatten().std())
+ print(f"setting self.scale_factor to {self.scale_factor}")
+ print("### USING STD-RESCALING ###")
+
+
+ def apply_model(self, x_noisy, t, cond, return_ids=False):
+
+ if(not self.turbo):
+ self.model1.to(self.cdevice)
+
+ step = self.unet_bs
+ h,emb,hs = self.model1(x_noisy[0:step], t[:step], cond[:step])
+ bs = cond.shape[0]
+
+ # assert bs%2 == 0
+ lenhs = len(hs)
+
+ for i in range(step,bs,step):
+ h_temp,emb_temp,hs_temp = self.model1(x_noisy[i:i+step], t[i:i+step], cond[i:i+step])
+ h = torch.cat((h,h_temp))
+ emb = torch.cat((emb,emb_temp))
+ for j in range(lenhs):
+ hs[j] = torch.cat((hs[j], hs_temp[j]))
+
+
+ if(not self.turbo):
+ self.model1.to("cpu")
+ self.model2.to(self.cdevice)
+
+ hs_temp = [hs[j][:step] for j in range(lenhs)]
+ x_recon = self.model2(h[:step],emb[:step],x_noisy.dtype,hs_temp,cond[:step])
+
+ for i in range(step,bs,step):
+
+ hs_temp = [hs[j][i:i+step] for j in range(lenhs)]
+ x_recon1 = self.model2(h[i:i+step],emb[i:i+step],x_noisy.dtype,hs_temp,cond[i:i+step])
+ x_recon = torch.cat((x_recon, x_recon1))
+
+ if(not self.turbo):
+ self.model2.to("cpu")
+
+ if isinstance(x_recon, tuple) and not return_ids:
+ return x_recon[0]
+ else:
+ return x_recon
+
+ def register_buffer1(self, name, attr):
+ if type(attr) == torch.Tensor:
+ if attr.device != torch.device(self.cdevice):
+ attr = attr.to(torch.device(self.cdevice))
+ setattr(self, name, attr)
+
+ def make_schedule(self, ddim_num_steps, ddim_discretize="uniform", ddim_eta=0., verbose=True):
+
+
+ self.ddim_timesteps = make_ddim_timesteps(ddim_discr_method=ddim_discretize, num_ddim_timesteps=ddim_num_steps,
+ num_ddpm_timesteps=self.num_timesteps,verbose=verbose)
+
+
+ assert self.alphas_cumprod.shape[0] == self.num_timesteps, 'alphas have to be defined for each timestep'
+
+
+ to_torch = lambda x: x.to(self.cdevice)
+ self.register_buffer1('betas', to_torch(self.betas))
+ self.register_buffer1('alphas_cumprod', to_torch(self.alphas_cumprod))
+ # ddim sampling parameters
+ ddim_sigmas, ddim_alphas, ddim_alphas_prev = make_ddim_sampling_parameters(alphacums=self.alphas_cumprod.cpu(),
+ ddim_timesteps=self.ddim_timesteps,
+ eta=ddim_eta,verbose=verbose)
+ self.register_buffer1('ddim_sigmas', ddim_sigmas)
+ self.register_buffer1('ddim_alphas', ddim_alphas)
+ self.register_buffer1('ddim_alphas_prev', ddim_alphas_prev)
+ self.register_buffer1('ddim_sqrt_one_minus_alphas', np.sqrt(1. - ddim_alphas))
+
+
+ @torch.no_grad()
+ def sample(self,
+ S,
+ conditioning,
+ x0=None,
+ shape = None,
+ seed=1234,
+ callback=None,
+ img_callback=None,
+ quantize_x0=False,
+ eta=0.,
+ mask=None,
+ sampler = "plms",
+ temperature=1.,
+ noise_dropout=0.,
+ score_corrector=None,
+ corrector_kwargs=None,
+ verbose=True,
+ x_T=None,
+ log_every_t=100,
+ unconditional_guidance_scale=1.,
+ unconditional_conditioning=None,
+ ):
+
+
+ if(self.turbo):
+ self.model1.to(self.cdevice)
+ self.model2.to(self.cdevice)
+
+ if x0 is None:
+ batch_size, b1, b2, b3 = shape
+ img_shape = (1, b1, b2, b3)
+ tens = []
+ print("seeds used = ", [seed+s for s in range(batch_size)])
+ for _ in range(batch_size):
+ torch.manual_seed(seed)
+ tens.append(torch.randn(img_shape, device=self.cdevice))
+ seed+=1
+ noise = torch.cat(tens)
+ del tens
+
+ x_latent = noise if x0 is None else x0
+ # sampling
+
+ if sampler == "plms":
+ self.make_schedule(ddim_num_steps=S, ddim_eta=eta, verbose=False)
+ print(f'Data shape for PLMS sampling is {shape}')
+ samples = self.plms_sampling(conditioning, batch_size, x_latent,
+ callback=callback,
+ img_callback=img_callback,
+ quantize_denoised=quantize_x0,
+ mask=mask, x0=x0,
+ ddim_use_original_steps=False,
+ noise_dropout=noise_dropout,
+ temperature=temperature,
+ score_corrector=score_corrector,
+ corrector_kwargs=corrector_kwargs,
+ log_every_t=log_every_t,
+ unconditional_guidance_scale=unconditional_guidance_scale,
+ unconditional_conditioning=unconditional_conditioning,
+ )
+
+ elif sampler == "ddim":
+ samples = self.ddim_sampling(x_latent, conditioning, S, unconditional_guidance_scale=unconditional_guidance_scale,
+ unconditional_conditioning=unconditional_conditioning,
+ mask = mask,init_latent=x_T,use_original_steps=False)
+
+ elif sampler == "euler":
+ self.make_schedule(ddim_num_steps=S, ddim_eta=eta, verbose=False)
+ samples = self.euler_sampling(self.alphas_cumprod,x_latent, S, conditioning, unconditional_conditioning=unconditional_conditioning,
+ unconditional_guidance_scale=unconditional_guidance_scale)
+ elif sampler == "euler_a":
+ self.make_schedule(ddim_num_steps=S, ddim_eta=eta, verbose=False)
+ samples = self.euler_ancestral_sampling(self.alphas_cumprod,x_latent, S, conditioning, unconditional_conditioning=unconditional_conditioning,
+ unconditional_guidance_scale=unconditional_guidance_scale)
+
+ elif sampler == "dpm2":
+ samples = self.dpm_2_sampling(self.alphas_cumprod,x_latent, S, conditioning, unconditional_conditioning=unconditional_conditioning,
+ unconditional_guidance_scale=unconditional_guidance_scale)
+ elif sampler == "heun":
+ samples = self.heun_sampling(self.alphas_cumprod,x_latent, S, conditioning, unconditional_conditioning=unconditional_conditioning,
+ unconditional_guidance_scale=unconditional_guidance_scale)
+
+ elif sampler == "dpm2_a":
+ samples = self.dpm_2_ancestral_sampling(self.alphas_cumprod,x_latent, S, conditioning, unconditional_conditioning=unconditional_conditioning,
+ unconditional_guidance_scale=unconditional_guidance_scale)
+
+
+ elif sampler == "lms":
+ samples = self.lms_sampling(self.alphas_cumprod,x_latent, S, conditioning, unconditional_conditioning=unconditional_conditioning,
+ unconditional_guidance_scale=unconditional_guidance_scale)
+
+        if self.turbo:
+ self.model1.to("cpu")
+ self.model2.to("cpu")
+
+ return samples
+
+ @torch.no_grad()
+    def plms_sampling(self, cond, b, img,
+ ddim_use_original_steps=False,
+ callback=None, quantize_denoised=False,
+ mask=None, x0=None, img_callback=None, log_every_t=100,
+ temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None,
+ unconditional_guidance_scale=1., unconditional_conditioning=None,):
+
+ device = self.betas.device
+ timesteps = self.ddim_timesteps
+ time_range = np.flip(timesteps)
+ total_steps = timesteps.shape[0]
+ print(f"Running PLMS Sampling with {total_steps} timesteps")
+
+ iterator = tqdm(time_range, desc='PLMS Sampler', total=total_steps)
+ old_eps = []
+
+ for i, step in enumerate(iterator):
+ index = total_steps - i - 1
+ ts = torch.full((b,), step, device=device, dtype=torch.long)
+ ts_next = torch.full((b,), time_range[min(i + 1, len(time_range) - 1)], device=device, dtype=torch.long)
+
+ if mask is not None:
+ assert x0 is not None
+ img_orig = self.q_sample(x0, ts) # TODO: deterministic forward pass?
+ img = img_orig * mask + (1. - mask) * img
+
+ outs = self.p_sample_plms(img, cond, ts, index=index, use_original_steps=ddim_use_original_steps,
+ quantize_denoised=quantize_denoised, temperature=temperature,
+ noise_dropout=noise_dropout, score_corrector=score_corrector,
+ corrector_kwargs=corrector_kwargs,
+ unconditional_guidance_scale=unconditional_guidance_scale,
+ unconditional_conditioning=unconditional_conditioning,
+ old_eps=old_eps, t_next=ts_next)
+ img, pred_x0, e_t = outs
+ old_eps.append(e_t)
+ if len(old_eps) >= 4:
+ old_eps.pop(0)
+ if callback: callback(i)
+ if img_callback: img_callback(pred_x0, i)
+
+ return img
+
+ @torch.no_grad()
+ def p_sample_plms(self, x, c, t, index, repeat_noise=False, use_original_steps=False, quantize_denoised=False,
+ temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None,
+ unconditional_guidance_scale=1., unconditional_conditioning=None, old_eps=None, t_next=None):
+ b, *_, device = *x.shape, x.device
+
+ def get_model_output(x, t):
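+            # Classifier-free guidance: run the conditional and unconditional
+            # prompts as one doubled batch, then blend the two eps predictions:
+            # e_t = e_uncond + scale * (e_cond - e_uncond).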
+ if unconditional_conditioning is None or unconditional_guidance_scale == 1.:
+ e_t = self.apply_model(x, t, c)
+ else:
+ x_in = torch.cat([x] * 2)
+ t_in = torch.cat([t] * 2)
+ c_in = torch.cat([unconditional_conditioning, c])
+ e_t_uncond, e_t = self.apply_model(x_in, t_in, c_in).chunk(2)
+ e_t = e_t_uncond + unconditional_guidance_scale * (e_t - e_t_uncond)
+
+ if score_corrector is not None:
+ assert self.parameterization == "eps"
+ e_t = score_corrector.modify_score(self.model, e_t, x, t, c, **corrector_kwargs)
+
+ return e_t
+
+ alphas = self.ddim_alphas
+ alphas_prev = self.ddim_alphas_prev
+ sqrt_one_minus_alphas = self.ddim_sqrt_one_minus_alphas
+ sigmas = self.ddim_sigmas
+
+ def get_x_prev_and_pred_x0(e_t, index):
+ # select parameters corresponding to the currently considered timestep
+ a_t = torch.full((b, 1, 1, 1), alphas[index], device=device)
+ a_prev = torch.full((b, 1, 1, 1), alphas_prev[index], device=device)
+ sigma_t = torch.full((b, 1, 1, 1), sigmas[index], device=device)
+ sqrt_one_minus_at = torch.full((b, 1, 1, 1), sqrt_one_minus_alphas[index],device=device)
+
+ # current prediction for x_0
+ pred_x0 = (x - sqrt_one_minus_at * e_t) / a_t.sqrt()
+ if quantize_denoised:
+ pred_x0, _, *_ = self.first_stage_model.quantize(pred_x0)
+ # direction pointing to x_t
+ dir_xt = (1. - a_prev - sigma_t**2).sqrt() * e_t
+ noise = sigma_t * noise_like(x.shape, device, repeat_noise) * temperature
+ if noise_dropout > 0.:
+ noise = torch.nn.functional.dropout(noise, p=noise_dropout)
+ x_prev = a_prev.sqrt() * pred_x0 + dir_xt + noise
+ return x_prev, pred_x0
+
+ e_t = get_model_output(x, t)
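+        # PLMS ramps up the multistep order as the eps history fills: the first
+        # step uses a Heun-style average, later steps use Adams-Bashforth
+        # combinations of the last 2-4 eps predictions, and the blended e_t_prime
+        # then drives a single DDIM-style update.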
+ if len(old_eps) == 0:
+ # Pseudo Improved Euler (2nd order)
+ x_prev, pred_x0 = get_x_prev_and_pred_x0(e_t, index)
+ e_t_next = get_model_output(x_prev, t_next)
+ e_t_prime = (e_t + e_t_next) / 2
+ elif len(old_eps) == 1:
+ # 2nd order Pseudo Linear Multistep (Adams-Bashforth)
+ e_t_prime = (3 * e_t - old_eps[-1]) / 2
+ elif len(old_eps) == 2:
+            # 3rd order Pseudo Linear Multistep (Adams-Bashforth)
+ e_t_prime = (23 * e_t - 16 * old_eps[-1] + 5 * old_eps[-2]) / 12
+ elif len(old_eps) >= 3:
+            # 4th order Pseudo Linear Multistep (Adams-Bashforth)
+ e_t_prime = (55 * e_t - 59 * old_eps[-1] + 37 * old_eps[-2] - 9 * old_eps[-3]) / 24
+
+ x_prev, pred_x0 = get_x_prev_and_pred_x0(e_t_prime, index)
+
+ return x_prev, pred_x0, e_t
+
+
+ @torch.no_grad()
+ def stochastic_encode(self, x0, t, seed, ddim_eta,ddim_steps,use_original_steps=False, noise=None):
+ # fast, but does not allow for exact reconstruction
+ # t serves as an index to gather the correct alphas
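+        # Applies the closed-form forward process at step t,
+        # x_t = sqrt(alpha_cumprod_t) * x0 + sqrt(1 - alpha_cumprod_t) * noise,
+        # with one consecutive seed per batch item for reproducibility.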
+ self.make_schedule(ddim_num_steps=ddim_steps, ddim_eta=ddim_eta, verbose=False)
+ sqrt_alphas_cumprod = torch.sqrt(self.ddim_alphas)
+
+ if noise is None:
+ b0, b1, b2, b3 = x0.shape
+ img_shape = (1, b1, b2, b3)
+ tens = []
+ print("seeds used = ", [seed+s for s in range(b0)])
+ for _ in range(b0):
+ torch.manual_seed(seed)
+ tens.append(torch.randn(img_shape, device=x0.device))
+ seed+=1
+ noise = torch.cat(tens)
+ del tens
+ return (extract_into_tensor(sqrt_alphas_cumprod, t, x0.shape) * x0 +
+ extract_into_tensor(self.ddim_sqrt_one_minus_alphas, t, x0.shape) * noise)
+
+ @torch.no_grad()
+ def add_noise(self, x0, t):
+
+ sqrt_alphas_cumprod = torch.sqrt(self.ddim_alphas)
+ noise = torch.randn(x0.shape, device=x0.device)
+
+ # print(extract_into_tensor(sqrt_alphas_cumprod, t, x0.shape),
+ # extract_into_tensor(self.ddim_sqrt_one_minus_alphas, t, x0.shape))
+ return (extract_into_tensor(sqrt_alphas_cumprod, t, x0.shape) * x0 +
+ extract_into_tensor(self.ddim_sqrt_one_minus_alphas, t, x0.shape) * noise)
+
+
+ @torch.no_grad()
+ def ddim_sampling(self, x_latent, cond, t_start, unconditional_guidance_scale=1.0, unconditional_conditioning=None,
+ mask = None,init_latent=None,use_original_steps=False):
+
+ timesteps = self.ddim_timesteps
+ timesteps = timesteps[:t_start]
+ time_range = np.flip(timesteps)
+ total_steps = timesteps.shape[0]
+ print(f"Running DDIM Sampling with {total_steps} timesteps")
+
+ iterator = tqdm(time_range, desc='Decoding image', total=total_steps)
+ x_dec = x_latent
+ x0 = init_latent
+ for i, step in enumerate(iterator):
+ index = total_steps - i - 1
+ ts = torch.full((x_latent.shape[0],), step, device=x_latent.device, dtype=torch.long)
+
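+            # Inpainting: re-impose the known region of the initial latent at every
+            # step so only the masked-out area (mask == 0) is denoised freely.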
+ if mask is not None:
+ # x0_noisy = self.add_noise(mask, torch.tensor([index] * x0.shape[0]).to(self.cdevice))
+ x0_noisy = x0
+ x_dec = x0_noisy* mask + (1. - mask) * x_dec
+
+ x_dec = self.p_sample_ddim(x_dec, cond, ts, index=index, use_original_steps=use_original_steps,
+ unconditional_guidance_scale=unconditional_guidance_scale,
+ unconditional_conditioning=unconditional_conditioning)
+
+ if mask is not None:
+ return x0 * mask + (1. - mask) * x_dec
+
+ return x_dec
+
+
+ @torch.no_grad()
+ def p_sample_ddim(self, x, c, t, index, repeat_noise=False, use_original_steps=False, quantize_denoised=False,
+ temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None,
+ unconditional_guidance_scale=1., unconditional_conditioning=None):
+ b, *_, device = *x.shape, x.device
+
+ if unconditional_conditioning is None or unconditional_guidance_scale == 1.:
+ e_t = self.apply_model(x, t, c)
+ else:
+ x_in = torch.cat([x] * 2)
+ t_in = torch.cat([t] * 2)
+ c_in = torch.cat([unconditional_conditioning, c])
+ e_t_uncond, e_t = self.apply_model(x_in, t_in, c_in).chunk(2)
+ e_t = e_t_uncond + unconditional_guidance_scale * (e_t - e_t_uncond)
+
+ if score_corrector is not None:
+ assert self.model.parameterization == "eps"
+ e_t = score_corrector.modify_score(self.model, e_t, x, t, c, **corrector_kwargs)
+
+ alphas = self.ddim_alphas
+ alphas_prev = self.ddim_alphas_prev
+ sqrt_one_minus_alphas = self.ddim_sqrt_one_minus_alphas
+ sigmas = self.ddim_sigmas
+ # select parameters corresponding to the currently considered timestep
+ a_t = torch.full((b, 1, 1, 1), alphas[index], device=device)
+ a_prev = torch.full((b, 1, 1, 1), alphas_prev[index], device=device)
+ sigma_t = torch.full((b, 1, 1, 1), sigmas[index], device=device)
+ sqrt_one_minus_at = torch.full((b, 1, 1, 1), sqrt_one_minus_alphas[index],device=device)
+
+ # current prediction for x_0
+ pred_x0 = (x - sqrt_one_minus_at * e_t) / a_t.sqrt()
+ if quantize_denoised:
+ pred_x0, _, *_ = self.first_stage_model.quantize(pred_x0)
+ # direction pointing to x_t
+ dir_xt = (1. - a_prev - sigma_t**2).sqrt() * e_t
+ noise = sigma_t * noise_like(x.shape, device, repeat_noise) * temperature
+ if noise_dropout > 0.:
+ noise = torch.nn.functional.dropout(noise, p=noise_dropout)
+ x_prev = a_prev.sqrt() * pred_x0 + dir_xt + noise
+ return x_prev
+
+
+ @torch.no_grad()
+ def euler_sampling(self, ac, x, S, cond, unconditional_conditioning = None, unconditional_guidance_scale = 1,extra_args=None,callback=None, disable=None, s_churn=0., s_tmin=0., s_tmax=float('inf'), s_noise=1.):
+ """Implements Algorithm 2 (Euler steps) from Karras et al. (2022)."""
+ extra_args = {} if extra_args is None else extra_args
+ cvd = CompVisDenoiser(ac)
+ sigmas = cvd.get_sigmas(S)
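+        # x is expected to be unit-variance noise; scaling by the largest sigma puts
+        # it at the initial noise level of the Karras (k-diffusion) formulation.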
+ x = x*sigmas[0]
+
+ s_in = x.new_ones([x.shape[0]]).half()
+ for i in trange(len(sigmas) - 1, disable=disable):
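+            # "Churn" step: optionally add fresh noise to raise sigmas[i] to sigma_hat
+            # (controlled by s_churn; a no-op with the default s_churn=0).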
+ gamma = min(s_churn / (len(sigmas) - 1), 2 ** 0.5 - 1) if s_tmin <= sigmas[i] <= s_tmax else 0.
+ eps = torch.randn_like(x) * s_noise
+ sigma_hat = (sigmas[i] * (gamma + 1)).half()
+ if gamma > 0:
+ x = x + eps * (sigma_hat ** 2 - sigmas[i] ** 2) ** 0.5
+
+ s_i = sigma_hat * s_in
+ x_in = torch.cat([x] * 2)
+ t_in = torch.cat([s_i] * 2)
+ cond_in = torch.cat([unconditional_conditioning, cond])
+ c_out, c_in = [append_dims(tmp, x_in.ndim) for tmp in cvd.get_scalings(t_in)]
+ eps = self.apply_model(x_in * c_in, cvd.sigma_to_t(t_in), cond_in)
+ e_t_uncond, e_t = (x_in + eps * c_out).chunk(2)
+ denoised = e_t_uncond + unconditional_guidance_scale * (e_t - e_t_uncond)
+
+
+ d = to_d(x, sigma_hat, denoised)
+ if callback is not None:
+ callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigma_hat, 'denoised': denoised})
+ dt = sigmas[i + 1] - sigma_hat
+ # Euler method
+ x = x + d * dt
+ return x
+
+ @torch.no_grad()
+ def euler_ancestral_sampling(self,ac,x, S, cond, unconditional_conditioning = None, unconditional_guidance_scale = 1,extra_args=None, callback=None, disable=None):
+ """Ancestral sampling with Euler method steps."""
+ extra_args = {} if extra_args is None else extra_args
+
+
+ cvd = CompVisDenoiser(ac)
+ sigmas = cvd.get_sigmas(S)
+ x = x*sigmas[0]
+
+ s_in = x.new_ones([x.shape[0]]).half()
+ for i in trange(len(sigmas) - 1, disable=disable):
+
+ s_i = sigmas[i] * s_in
+ x_in = torch.cat([x] * 2)
+ t_in = torch.cat([s_i] * 2)
+ cond_in = torch.cat([unconditional_conditioning, cond])
+ c_out, c_in = [append_dims(tmp, x_in.ndim) for tmp in cvd.get_scalings(t_in)]
+ eps = self.apply_model(x_in * c_in, cvd.sigma_to_t(t_in), cond_in)
+ e_t_uncond, e_t = (x_in + eps * c_out).chunk(2)
+ denoised = e_t_uncond + unconditional_guidance_scale * (e_t - e_t_uncond)
+
+ sigma_down, sigma_up = get_ancestral_step(sigmas[i], sigmas[i + 1])
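+            # The ancestral step is split into a deterministic move down to sigma_down
+            # plus freshly injected noise with standard deviation sigma_up.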
+ if callback is not None:
+ callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigmas[i], 'denoised': denoised})
+ d = to_d(x, sigmas[i], denoised)
+ # Euler method
+ dt = sigma_down - sigmas[i]
+ x = x + d * dt
+ x = x + torch.randn_like(x) * sigma_up
+ return x
+
+
+
+ @torch.no_grad()
+ def heun_sampling(self, ac, x, S, cond, unconditional_conditioning = None, unconditional_guidance_scale = 1, extra_args=None, callback=None, disable=None, s_churn=0., s_tmin=0., s_tmax=float('inf'), s_noise=1.):
+ """Implements Algorithm 2 (Heun steps) from Karras et al. (2022)."""
+ extra_args = {} if extra_args is None else extra_args
+
+ cvd = CompVisDenoiser(alphas_cumprod=ac)
+ sigmas = cvd.get_sigmas(S)
+ x = x*sigmas[0]
+
+
+ s_in = x.new_ones([x.shape[0]]).half()
+ for i in trange(len(sigmas) - 1, disable=disable):
+ gamma = min(s_churn / (len(sigmas) - 1), 2 ** 0.5 - 1) if s_tmin <= sigmas[i] <= s_tmax else 0.
+ eps = torch.randn_like(x) * s_noise
+ sigma_hat = (sigmas[i] * (gamma + 1)).half()
+ if gamma > 0:
+ x = x + eps * (sigma_hat ** 2 - sigmas[i] ** 2) ** 0.5
+
+ s_i = sigma_hat * s_in
+ x_in = torch.cat([x] * 2)
+ t_in = torch.cat([s_i] * 2)
+ cond_in = torch.cat([unconditional_conditioning, cond])
+ c_out, c_in = [append_dims(tmp, x_in.ndim) for tmp in cvd.get_scalings(t_in)]
+ eps = self.apply_model(x_in * c_in, cvd.sigma_to_t(t_in), cond_in)
+ e_t_uncond, e_t = (x_in + eps * c_out).chunk(2)
+ denoised = e_t_uncond + unconditional_guidance_scale * (e_t - e_t_uncond)
+
+ d = to_d(x, sigma_hat, denoised)
+ if callback is not None:
+ callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigma_hat, 'denoised': denoised})
+ dt = sigmas[i + 1] - sigma_hat
+ if sigmas[i + 1] == 0:
+ # Euler method
+ x = x + d * dt
+ else:
+ # Heun's method
+ x_2 = x + d * dt
+ s_i = sigmas[i + 1] * s_in
+ x_in = torch.cat([x_2] * 2)
+ t_in = torch.cat([s_i] * 2)
+ cond_in = torch.cat([unconditional_conditioning, cond])
+ c_out, c_in = [append_dims(tmp, x_in.ndim) for tmp in cvd.get_scalings(t_in)]
+ eps = self.apply_model(x_in * c_in, cvd.sigma_to_t(t_in), cond_in)
+ e_t_uncond, e_t = (x_in + eps * c_out).chunk(2)
+ denoised_2 = e_t_uncond + unconditional_guidance_scale * (e_t - e_t_uncond)
+
+ d_2 = to_d(x_2, sigmas[i + 1], denoised_2)
+ d_prime = (d + d_2) / 2
+ x = x + d_prime * dt
+ return x
+
+
+ @torch.no_grad()
+ def dpm_2_sampling(self,ac,x, S, cond, unconditional_conditioning = None, unconditional_guidance_scale = 1,extra_args=None, callback=None, disable=None, s_churn=0., s_tmin=0., s_tmax=float('inf'), s_noise=1.):
+ """A sampler inspired by DPM-Solver-2 and Algorithm 2 from Karras et al. (2022)."""
+ extra_args = {} if extra_args is None else extra_args
+
+ cvd = CompVisDenoiser(ac)
+ sigmas = cvd.get_sigmas(S)
+ x = x*sigmas[0]
+
+ s_in = x.new_ones([x.shape[0]]).half()
+ for i in trange(len(sigmas) - 1, disable=disable):
+ gamma = min(s_churn / (len(sigmas) - 1), 2 ** 0.5 - 1) if s_tmin <= sigmas[i] <= s_tmax else 0.
+ eps = torch.randn_like(x) * s_noise
+ sigma_hat = sigmas[i] * (gamma + 1)
+ if gamma > 0:
+ x = x + eps * (sigma_hat ** 2 - sigmas[i] ** 2) ** 0.5
+
+ s_i = sigma_hat * s_in
+ x_in = torch.cat([x] * 2)
+ t_in = torch.cat([s_i] * 2)
+ cond_in = torch.cat([unconditional_conditioning, cond])
+ c_out, c_in = [append_dims(tmp, x_in.ndim) for tmp in cvd.get_scalings(t_in)]
+ eps = self.apply_model(x_in * c_in, cvd.sigma_to_t(t_in), cond_in)
+ e_t_uncond, e_t = (x_in + eps * c_out).chunk(2)
+ denoised = e_t_uncond + unconditional_guidance_scale * (e_t - e_t_uncond)
+
+
+
+ d = to_d(x, sigma_hat, denoised)
+ # Midpoint method, where the midpoint is chosen according to a rho=3 Karras schedule
+ sigma_mid = ((sigma_hat ** (1 / 3) + sigmas[i + 1] ** (1 / 3)) / 2) ** 3
+ dt_1 = sigma_mid - sigma_hat
+ dt_2 = sigmas[i + 1] - sigma_hat
+ x_2 = x + d * dt_1
+
+ s_i = sigma_mid * s_in
+ x_in = torch.cat([x_2] * 2)
+ t_in = torch.cat([s_i] * 2)
+ cond_in = torch.cat([unconditional_conditioning, cond])
+ c_out, c_in = [append_dims(tmp, x_in.ndim) for tmp in cvd.get_scalings(t_in)]
+ eps = self.apply_model(x_in * c_in, cvd.sigma_to_t(t_in), cond_in)
+ e_t_uncond, e_t = (x_in + eps * c_out).chunk(2)
+ denoised_2 = e_t_uncond + unconditional_guidance_scale * (e_t - e_t_uncond)
+
+
+ d_2 = to_d(x_2, sigma_mid, denoised_2)
+ x = x + d_2 * dt_2
+ return x
+
+
+ @torch.no_grad()
+ def dpm_2_ancestral_sampling(self,ac,x, S, cond, unconditional_conditioning = None, unconditional_guidance_scale = 1, extra_args=None, callback=None, disable=None):
+ """Ancestral sampling with DPM-Solver inspired second-order steps."""
+ extra_args = {} if extra_args is None else extra_args
+
+ cvd = CompVisDenoiser(ac)
+ sigmas = cvd.get_sigmas(S)
+ x = x*sigmas[0]
+
+ s_in = x.new_ones([x.shape[0]]).half()
+ for i in trange(len(sigmas) - 1, disable=disable):
+
+ s_i = sigmas[i] * s_in
+ x_in = torch.cat([x] * 2)
+ t_in = torch.cat([s_i] * 2)
+ cond_in = torch.cat([unconditional_conditioning, cond])
+ c_out, c_in = [append_dims(tmp, x_in.ndim) for tmp in cvd.get_scalings(t_in)]
+ eps = self.apply_model(x_in * c_in, cvd.sigma_to_t(t_in), cond_in)
+ e_t_uncond, e_t = (x_in + eps * c_out).chunk(2)
+ denoised = e_t_uncond + unconditional_guidance_scale * (e_t - e_t_uncond)
+
+
+ sigma_down, sigma_up = get_ancestral_step(sigmas[i], sigmas[i + 1])
+ if callback is not None:
+ callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigmas[i], 'denoised': denoised})
+ d = to_d(x, sigmas[i], denoised)
+ # Midpoint method, where the midpoint is chosen according to a rho=3 Karras schedule
+ sigma_mid = ((sigmas[i] ** (1 / 3) + sigma_down ** (1 / 3)) / 2) ** 3
+ dt_1 = sigma_mid - sigmas[i]
+ dt_2 = sigma_down - sigmas[i]
+ x_2 = x + d * dt_1
+
+ s_i = sigma_mid * s_in
+ x_in = torch.cat([x_2] * 2)
+ t_in = torch.cat([s_i] * 2)
+ cond_in = torch.cat([unconditional_conditioning, cond])
+ c_out, c_in = [append_dims(tmp, x_in.ndim) for tmp in cvd.get_scalings(t_in)]
+ eps = self.apply_model(x_in * c_in, cvd.sigma_to_t(t_in), cond_in)
+ e_t_uncond, e_t = (x_in + eps * c_out).chunk(2)
+ denoised_2 = e_t_uncond + unconditional_guidance_scale * (e_t - e_t_uncond)
+
+
+ d_2 = to_d(x_2, sigma_mid, denoised_2)
+ x = x + d_2 * dt_2
+ x = x + torch.randn_like(x) * sigma_up
+ return x
+
+
+ @torch.no_grad()
+ def lms_sampling(self,ac,x, S, cond, unconditional_conditioning = None, unconditional_guidance_scale = 1, extra_args=None, callback=None, disable=None, order=4):
+ extra_args = {} if extra_args is None else extra_args
+ s_in = x.new_ones([x.shape[0]])
+
+ cvd = CompVisDenoiser(ac)
+ sigmas = cvd.get_sigmas(S)
+ x = x*sigmas[0]
+
+ ds = []
+ for i in trange(len(sigmas) - 1, disable=disable):
+
+ s_i = sigmas[i] * s_in
+ x_in = torch.cat([x] * 2)
+ t_in = torch.cat([s_i] * 2)
+ cond_in = torch.cat([unconditional_conditioning, cond])
+ c_out, c_in = [append_dims(tmp, x_in.ndim) for tmp in cvd.get_scalings(t_in)]
+ eps = self.apply_model(x_in * c_in, cvd.sigma_to_t(t_in), cond_in)
+ e_t_uncond, e_t = (x_in + eps * c_out).chunk(2)
+ denoised = e_t_uncond + unconditional_guidance_scale * (e_t - e_t_uncond)
+
+
+ d = to_d(x, sigmas[i], denoised)
+ ds.append(d)
+ if len(ds) > order:
+ ds.pop(0)
+ if callback is not None:
+ callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigmas[i], 'denoised': denoised})
+ cur_order = min(i + 1, order)
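+            # The multistep coefficients come from integrating the polynomial that
+            # interpolates the last cur_order stored derivatives over [sigma_i, sigma_i+1].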
+ coeffs = [linear_multistep_coeff(cur_order, sigmas.cpu(), i, j) for j in range(cur_order)]
+ x = x + sum(coeff * d for coeff, d in zip(coeffs, reversed(ds)))
+ return x
diff --git a/optimizedSD/diffusers_txt2img.py b/optimizedSD/diffusers_txt2img.py
new file mode 100644
index 000000000..80fbb9723
--- /dev/null
+++ b/optimizedSD/diffusers_txt2img.py
@@ -0,0 +1,13 @@
+import torch
+from diffusers import LDMTextToImagePipeline
+
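+# Note: use_auth_token=True expects a logged-in Hugging Face account
+# (e.g. via `huggingface-cli login`) with access to the gated model weights.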
+pipe = LDMTextToImagePipeline.from_pretrained("CompVis/stable-diffusion-v1-3-diffusers", use_auth_token=True)
+
+prompt = "19th Century wooden engraving of Elon musk"
+
+seed = torch.manual_seed(1024)
+images = pipe([prompt], batch_size=1, num_inference_steps=50, guidance_scale=7, generator=seed,torch_device="cpu" )["sample"]
+
+# save images
+for idx, image in enumerate(images):
+ image.save(f"image-{idx}.png")
diff --git a/optimizedSD/img2img_gradio.py b/optimizedSD/img2img_gradio.py
new file mode 100644
index 000000000..65d844d3b
--- /dev/null
+++ b/optimizedSD/img2img_gradio.py
@@ -0,0 +1,283 @@
+import os
+import re
+import time
+from contextlib import nullcontext
+from itertools import islice
+from random import randint
+
+import gradio as gr
+import numpy as np
+import torch
+from PIL import Image
+from einops import rearrange, repeat
+from omegaconf import OmegaConf
+from pytorch_lightning import seed_everything
+from torch import autocast
+from torchvision.utils import make_grid
+from tqdm import tqdm, trange
+from transformers import logging
+
+from ldm.util import instantiate_from_config
+from optimUtils import split_weighted_subprompts, logger
+
+logging.set_verbosity_error()
+import mimetypes
+
+mimetypes.init()
+mimetypes.add_type("application/javascript", ".js")
+
+
+def chunk(it, size):
+ it = iter(it)
+ return iter(lambda: tuple(islice(it, size)), ())
+
+
+def load_model_from_config(ckpt, verbose=False):
+ print(f"Loading model from {ckpt}")
+ pl_sd = torch.load(ckpt, map_location="cpu")
+ if "global_step" in pl_sd:
+ print(f"Global Step: {pl_sd['global_step']}")
+ sd = pl_sd["state_dict"]
+ return sd
+
+
+def load_img(image, h0, w0):
+
+ image = image.convert("RGB")
+ w, h = image.size
+ print(f"loaded input image of size ({w}, {h})")
+ if h0 is not None and w0 is not None:
+ h, w = h0, w0
+
+    w, h = map(lambda x: x - x % 64, (w, h))  # resize to integer multiple of 64
+
+ print(f"New image size ({w}, {h})")
+ image = image.resize((w, h), resample=Image.LANCZOS)
+ image = np.array(image).astype(np.float32) / 255.0
+ image = image[None].transpose(0, 3, 1, 2)
+ image = torch.from_numpy(image)
+ return 2.0 * image - 1.0
+
+config = "optimizedSD/v1-inference.yaml"
+ckpt = "models/ldm/stable-diffusion-v1/model.ckpt"
+sd = load_model_from_config(f"{ckpt}")
+li, lo = [], []
+for key, v_ in sd.items():
+ sp = key.split(".")
+ if (sp[0]) == "model":
+ if "input_blocks" in sp:
+ li.append(key)
+ elif "middle_block" in sp:
+ li.append(key)
+ elif "time_embed" in sp:
+ li.append(key)
+ else:
+ lo.append(key)
+for key in li:
+ sd["model1." + key[6:]] = sd.pop(key)
+for key in lo:
+ sd["model2." + key[6:]] = sd.pop(key)
+
+config = OmegaConf.load(f"{config}")
+
+model = instantiate_from_config(config.modelUNet)
+_, _ = model.load_state_dict(sd, strict=False)
+model.eval()
+
+modelCS = instantiate_from_config(config.modelCondStage)
+_, _ = modelCS.load_state_dict(sd, strict=False)
+modelCS.eval()
+
+modelFS = instantiate_from_config(config.modelFirstStage)
+_, _ = modelFS.load_state_dict(sd, strict=False)
+modelFS.eval()
+del sd
+
+def generate(
+ image,
+ prompt,
+ strength,
+ ddim_steps,
+ n_iter,
+ batch_size,
+ Height,
+ Width,
+ scale,
+ ddim_eta,
+ unet_bs,
+ device,
+ seed,
+ outdir,
+ img_format,
+ turbo,
+ full_precision,
+):
+
+ if seed == "":
+ seed = randint(0, 1000000)
+ seed = int(seed)
+ seed_everything(seed)
+
+ # Logging
+ sampler = "ddim"
+ logger(locals(), log_csv = "logs/img2img_gradio_logs.csv")
+
+ init_image = load_img(image, Height, Width).to(device)
+ model.unet_bs = unet_bs
+ model.turbo = turbo
+ model.cdevice = device
+ modelCS.cond_stage_model.device = device
+
+ if device != "cpu" and full_precision == False:
+ model.half()
+ modelCS.half()
+ modelFS.half()
+ init_image = init_image.half()
+
+ tic = time.time()
+ os.makedirs(outdir, exist_ok=True)
+ outpath = outdir
+ sample_path = os.path.join(outpath, "_".join(re.split(":| ", prompt)))[:150]
+ os.makedirs(sample_path, exist_ok=True)
+ base_count = len(os.listdir(sample_path))
+
+ # n_rows = opt.n_rows if opt.n_rows > 0 else batch_size
+ assert prompt is not None
+ data = [batch_size * [prompt]]
+
+ modelFS.to(device)
+
+ init_image = repeat(init_image, "1 ... -> b ...", b=batch_size)
+ init_latent = modelFS.get_first_stage_encoding(modelFS.encode_first_stage(init_image)) # move to latent space
+
+ if device != "cpu":
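+        # Offload the first-stage model and poll until its VRAM has actually been
+        # released before continuing, to keep peak memory low.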
+ mem = torch.cuda.memory_allocated() / 1e6
+ modelFS.to("cpu")
+ while torch.cuda.memory_allocated() / 1e6 >= mem:
+ time.sleep(1)
+
+ assert 0.0 <= strength <= 1.0, "can only work with strength in [0.0, 1.0]"
+ t_enc = int(strength * ddim_steps)
+ print(f"target t_enc is {t_enc} steps")
+
+ if full_precision == False and device != "cpu":
+ precision_scope = autocast
+ else:
+ precision_scope = nullcontext
+
+    all_samples = []
+    seeds = ""
+    with torch.no_grad():
+ for _ in trange(n_iter, desc="Sampling"):
+ for prompts in tqdm(data, desc="data"):
+ with precision_scope("cuda"):
+ modelCS.to(device)
+ uc = None
+ if scale != 1.0:
+ uc = modelCS.get_learned_conditioning(batch_size * [""])
+ if isinstance(prompts, tuple):
+ prompts = list(prompts)
+
+ subprompts, weights = split_weighted_subprompts(prompts[0])
+ if len(subprompts) > 1:
+ c = torch.zeros_like(uc)
+ totalWeight = sum(weights)
+ # normalize each "sub prompt" and add it
+ for i in range(len(subprompts)):
+ weight = weights[i]
+ # if not skip_normalize:
+ weight = weight / totalWeight
+ c = torch.add(c, modelCS.get_learned_conditioning(subprompts[i]), alpha=weight)
+ else:
+ c = modelCS.get_learned_conditioning(prompts)
+
+ if device != "cpu":
+ mem = torch.cuda.memory_allocated() / 1e6
+ modelCS.to("cpu")
+ while torch.cuda.memory_allocated() / 1e6 >= mem:
+ time.sleep(1)
+
+ # encode (scaled latent)
+ z_enc = model.stochastic_encode(
+ init_latent, torch.tensor([t_enc] * batch_size).to(device), seed, ddim_eta, ddim_steps
+ )
+ # decode it
+ samples_ddim = model.sample(
+ t_enc,
+ c,
+ z_enc,
+ unconditional_guidance_scale=scale,
+ unconditional_conditioning=uc,
+ sampler = sampler
+ )
+
+ modelFS.to(device)
+ print("saving images")
+ for i in range(batch_size):
+
+ x_samples_ddim = modelFS.decode_first_stage(samples_ddim[i].unsqueeze(0))
+ x_sample = torch.clamp((x_samples_ddim + 1.0) / 2.0, min=0.0, max=1.0)
+ all_samples.append(x_sample.to("cpu"))
+ x_sample = 255.0 * rearrange(x_sample[0].cpu().numpy(), "c h w -> h w c")
+ Image.fromarray(x_sample.astype(np.uint8)).save(
+ os.path.join(sample_path, "seed_" + str(seed) + "_" + f"{base_count:05}.{img_format}")
+ )
+ seeds += str(seed) + ","
+ seed += 1
+ base_count += 1
+
+ if device != "cpu":
+ mem = torch.cuda.memory_allocated() / 1e6
+ modelFS.to("cpu")
+ while torch.cuda.memory_allocated() / 1e6 >= mem:
+ time.sleep(1)
+
+ del samples_ddim
+ del x_sample
+ del x_samples_ddim
+ print("memory_final = ", torch.cuda.memory_allocated() / 1e6)
+
+ toc = time.time()
+
+ time_taken = (toc - tic) / 60.0
+ grid = torch.cat(all_samples, 0)
+ grid = make_grid(grid, nrow=n_iter)
+ grid = 255.0 * rearrange(grid, "c h w -> h w c").cpu().numpy()
+
+ txt = (
+ "Samples finished in "
+ + str(round(time_taken, 3))
+ + " minutes and exported to \n"
+ + sample_path
+ + "\nSeeds used = "
+ + seeds[:-1]
+ )
+ return Image.fromarray(grid.astype(np.uint8)), txt
+
+
+demo = gr.Interface(
+ fn=generate,
+ inputs=[
+ gr.Image(tool="editor", type="pil"),
+ "text",
+ gr.Slider(0, 1, value=0.75),
+ gr.Slider(1, 1000, value=50),
+ gr.Slider(1, 100, step=1),
+ gr.Slider(1, 100, step=1),
+ gr.Slider(64, 4096, value=512, step=64),
+ gr.Slider(64, 4096, value=512, step=64),
+ gr.Slider(0, 50, value=7.5, step=0.1),
+ gr.Slider(0, 1, step=0.01),
+ gr.Slider(1, 2, value=1, step=1),
+ gr.Text(value="cuda"),
+ "text",
+ gr.Text(value="outputs/img2img-samples"),
+ gr.Radio(["png", "jpg"], value='png'),
+ "checkbox",
+ "checkbox",
+ ],
+ outputs=["image", "text"],
+)
+demo.launch()
diff --git a/optimizedSD/inpaint_gradio.py b/optimizedSD/inpaint_gradio.py
new file mode 100644
index 000000000..4af37c23f
--- /dev/null
+++ b/optimizedSD/inpaint_gradio.py
@@ -0,0 +1,328 @@
+import argparse
+import os
+import re
+import time
+from contextlib import nullcontext
+from itertools import islice
+from random import randint
+
+import gradio as gr
+import numpy as np
+import torch
+from PIL import Image
+from einops import rearrange, repeat
+from omegaconf import OmegaConf
+from pytorch_lightning import seed_everything
+from torch import autocast
+from torchvision.utils import make_grid
+from tqdm import tqdm, trange
+from transformers import logging
+
+from ldm.util import instantiate_from_config
+from optimUtils import split_weighted_subprompts, logger
+
+logging.set_verbosity_error()
+import mimetypes
+
+mimetypes.init()
+mimetypes.add_type("application/javascript", ".js")
+
+
+def chunk(it, size):
+ it = iter(it)
+ return iter(lambda: tuple(islice(it, size)), ())
+
+
+def load_model_from_config(ckpt, verbose=False):
+ print(f"Loading model from {ckpt}")
+ pl_sd = torch.load(ckpt, map_location="cpu")
+ if "global_step" in pl_sd:
+ print(f"Global Step: {pl_sd['global_step']}")
+ sd = pl_sd["state_dict"]
+ return sd
+
+
+def load_img(image, h0, w0):
+ image = image.convert("RGB")
+ w, h = image.size
+ print(f"loaded input image of size ({w}, {h})")
+ if h0 is not None and w0 is not None:
+ h, w = h0, w0
+
+    w, h = map(lambda x: x - x % 64, (w, h))  # resize to integer multiple of 64
+
+ print(f"New image size ({w}, {h})")
+ image = image.resize((w, h), resample=Image.LANCZOS)
+ image = np.array(image).astype(np.float32) / 255.0
+ image = image[None].transpose(0, 3, 1, 2)
+ image = torch.from_numpy(image)
+ return 2.0 * image - 1.0
+
+
+def load_mask(mask, h0, w0, newH, newW, invert=False):
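+    # The mask is resized to the latent resolution (newW x newH); with invert=True,
+    # 0 and 255 are swapped so that, once applied during sampling, 1 keeps the
+    # original latent and 0 marks the region to be regenerated.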
+ image = mask.convert("RGB")
+ w, h = image.size
+ print(f"loaded input mask of size ({w}, {h})")
+ if h0 is not None and w0 is not None:
+ h, w = h0, w0
+
+    w, h = map(lambda x: x - x % 64, (w, h))  # resize to integer multiple of 64
+
+ print(f"New mask size ({w}, {h})")
+ image = image.resize((newW, newH), resample=Image.LANCZOS)
+ # image = image.resize((64, 64), resample=Image.LANCZOS)
+ image = np.array(image)
+
+ if invert:
+ print("inverted")
+ where_0, where_1 = np.where(image == 0), np.where(image == 255)
+ image[where_0], image[where_1] = 255, 0
+ image = image.astype(np.float32) / 255.0
+ image = image[None].transpose(0, 3, 1, 2)
+ image = torch.from_numpy(image)
+ return image
+
+
+def generate(
+ image,
+ mask_image,
+ prompt,
+ strength,
+ ddim_steps,
+ n_iter,
+ batch_size,
+ Height,
+ Width,
+ scale,
+ ddim_eta,
+ unet_bs,
+ device,
+ seed,
+ outdir,
+ img_format,
+ turbo,
+ full_precision,
+):
+ if seed == "":
+ seed = randint(0, 1000000)
+ seed = int(seed)
+ seed_everything(seed)
+ sampler = "ddim"
+
+ # Logging
+ logger(locals(), log_csv="logs/inpaint_gradio_logs.csv")
+
+ init_image = load_img(image['image'], Height, Width).to(device)
+
+ model.unet_bs = unet_bs
+ model.turbo = turbo
+ model.cdevice = device
+ modelCS.cond_stage_model.device = device
+
+ if device != "cpu" and full_precision == False:
+ model.half()
+ modelCS.half()
+ modelFS.half()
+ init_image = init_image.half()
+ # mask.half()
+
+ tic = time.time()
+ os.makedirs(outdir, exist_ok=True)
+ outpath = outdir
+ sample_path = os.path.join(outpath, "_".join(re.split(":| ", prompt)))[:150]
+ os.makedirs(sample_path, exist_ok=True)
+ base_count = len(os.listdir(sample_path))
+
+ # n_rows = opt.n_rows if opt.n_rows > 0 else batch_size
+ assert prompt is not None
+ data = [batch_size * [prompt]]
+
+ modelFS.to(device)
+
+ init_latent = modelFS.get_first_stage_encoding(modelFS.encode_first_stage(init_image)) # move to latent space
+ init_latent = repeat(init_latent, "1 ... -> b ...", b=batch_size)
+ if mask_image is None:
+ mask = load_mask(image['mask'], Height, Width, init_latent.shape[2], init_latent.shape[3], True).to(device)
+ else:
+ image['mask']=mask_image
+ mask = load_mask(mask_image, Height, Width, init_latent.shape[2], init_latent.shape[3], True).to(device)
+
+ mask = mask[0][0].unsqueeze(0).repeat(4, 1, 1).unsqueeze(0)
+ mask = repeat(mask, '1 ... -> b ...', b=batch_size)
+
+ if device != "cpu":
+ mem = torch.cuda.memory_allocated() / 1e6
+ modelFS.to("cpu")
+ while torch.cuda.memory_allocated() / 1e6 >= mem:
+ time.sleep(1)
+
+ if strength == 1:
+ print("strength should be less than 1, setting it to 0.999")
+ strength = 0.999
+    assert 0.0 <= strength < 1.0, "can only work with strength in [0.0, 1.0)"
+ t_enc = int(strength * ddim_steps)
+ print(f"target t_enc is {t_enc} steps")
+
+ if full_precision == False and device != "cpu":
+ precision_scope = autocast
+ else:
+ precision_scope = nullcontext
+
+    all_samples = []
+    seeds = ""
+    with torch.no_grad():
+ for _ in trange(n_iter, desc="Sampling"):
+ for prompts in tqdm(data, desc="data"):
+ with precision_scope("cuda"):
+ modelCS.to(device)
+ uc = None
+ if scale != 1.0:
+ uc = modelCS.get_learned_conditioning(batch_size * [""])
+ if isinstance(prompts, tuple):
+ prompts = list(prompts)
+
+ subprompts, weights = split_weighted_subprompts(prompts[0])
+ if len(subprompts) > 1:
+ c = torch.zeros_like(uc)
+ totalWeight = sum(weights)
+ # normalize each "sub prompt" and add it
+ for i in range(len(subprompts)):
+ weight = weights[i]
+ # if not skip_normalize:
+ weight = weight / totalWeight
+ c = torch.add(c, modelCS.get_learned_conditioning(subprompts[i]), alpha=weight)
+ else:
+ c = modelCS.get_learned_conditioning(prompts)
+
+ if device != "cpu":
+ mem = torch.cuda.memory_allocated() / 1e6
+ modelCS.to("cpu")
+ while torch.cuda.memory_allocated() / 1e6 >= mem:
+ time.sleep(1)
+
+ # encode (scaled latent)
+ z_enc = model.stochastic_encode(
+ init_latent, torch.tensor([t_enc] * batch_size).to(device),
+ seed, ddim_eta, ddim_steps)
+
+ # decode it
+ samples_ddim = model.sample(
+ t_enc,
+ c,
+ z_enc,
+ unconditional_guidance_scale=scale,
+ unconditional_conditioning=uc,
+ mask=mask,
+ x_T=init_latent,
+ sampler=sampler,
+ )
+
+ modelFS.to(device)
+ print("saving images")
+ for i in range(batch_size):
+ x_samples_ddim = modelFS.decode_first_stage(samples_ddim[i].unsqueeze(0))
+ x_sample = torch.clamp((x_samples_ddim + 1.0) / 2.0, min=0.0, max=1.0)
+ all_samples.append(x_sample.to("cpu"))
+ x_sample = 255.0 * rearrange(x_sample[0].cpu().numpy(), "c h w -> h w c")
+ Image.fromarray(x_sample.astype(np.uint8)).save(
+ os.path.join(sample_path, "seed_" + str(seed) + "_" + f"{base_count:05}.{img_format}")
+ )
+ seeds += str(seed) + ","
+ seed += 1
+ base_count += 1
+
+ if device != "cpu":
+ mem = torch.cuda.memory_allocated() / 1e6
+ modelFS.to("cpu")
+ while torch.cuda.memory_allocated() / 1e6 >= mem:
+ time.sleep(1)
+
+ del samples_ddim
+ del x_sample
+ del x_samples_ddim
+ print("memory_final = ", torch.cuda.memory_allocated() / 1e6)
+
+ toc = time.time()
+
+ time_taken = (toc - tic) / 60.0
+ grid = torch.cat(all_samples, 0)
+ grid = make_grid(grid, nrow=n_iter)
+ grid = 255.0 * rearrange(grid, "c h w -> h w c").cpu().numpy()
+
+ txt = (
+ "Samples finished in "
+ + str(round(time_taken, 3))
+ + " minutes and exported to \n"
+ + sample_path
+ + "\nSeeds used = "
+ + seeds[:-1]
+ )
+ return Image.fromarray(grid.astype(np.uint8)), image['mask'], txt
+
+
+if __name__ == '__main__':
+ parser = argparse.ArgumentParser(description='txt2img using gradio')
+ parser.add_argument('--config_path', default="optimizedSD/v1-inference.yaml", type=str, help='config path')
+ parser.add_argument('--ckpt_path', default="models/ldm/stable-diffusion-v1/model.ckpt", type=str, help='ckpt path')
+ args = parser.parse_args()
+ config = args.config_path
+ ckpt = args.ckpt_path
+ sd = load_model_from_config(f"{ckpt}")
+ li, lo = [], []
+ for key, v_ in sd.items():
+ sp = key.split(".")
+ if (sp[0]) == "model":
+ if "input_blocks" in sp:
+ li.append(key)
+ elif "middle_block" in sp:
+ li.append(key)
+ elif "time_embed" in sp:
+ li.append(key)
+ else:
+ lo.append(key)
+ for key in li:
+ sd["model1." + key[6:]] = sd.pop(key)
+ for key in lo:
+ sd["model2." + key[6:]] = sd.pop(key)
+
+ config = OmegaConf.load(f"{config}")
+
+ model = instantiate_from_config(config.modelUNet)
+ _, _ = model.load_state_dict(sd, strict=False)
+ model.eval()
+
+ modelCS = instantiate_from_config(config.modelCondStage)
+ _, _ = modelCS.load_state_dict(sd, strict=False)
+ modelCS.eval()
+
+ modelFS = instantiate_from_config(config.modelFirstStage)
+ _, _ = modelFS.load_state_dict(sd, strict=False)
+ modelFS.eval()
+ del sd
+
+ demo = gr.Interface(
+ fn=generate,
+ inputs=[
+ gr.Image(tool="sketch", type="pil"),
+ gr.Image(tool="editor", type="pil"),
+ "text",
+ gr.Slider(0, 0.99, value=0.99, step=0.01),
+ gr.Slider(1, 1000, value=50),
+ gr.Slider(1, 100, step=1),
+ gr.Slider(1, 100, step=1),
+ gr.Slider(64, 4096, value=512, step=64),
+ gr.Slider(64, 4096, value=512, step=64),
+ gr.Slider(0, 50, value=7.5, step=0.1),
+ gr.Slider(0, 1, step=0.01),
+ gr.Slider(1, 2, value=1, step=1),
+ gr.Text(value="cuda"),
+ "text",
+ gr.Text(value="outputs/inpaint-samples"),
+ gr.Radio(["png", "jpg"], value='png'),
+ "checkbox",
+ "checkbox",
+ ],
+ outputs=["image", "image", "text"],
+ )
+ demo.launch()
diff --git a/optimizedSD/openaimodelSplit.py b/optimizedSD/openaimodelSplit.py
new file mode 100644
index 000000000..abc3098b4
--- /dev/null
+++ b/optimizedSD/openaimodelSplit.py
@@ -0,0 +1,807 @@
+from abc import abstractmethod
+import math
+import numpy as np
+import torch as th
+import torch.nn as nn
+import torch.nn.functional as F
+from ldm.modules.diffusionmodules.util import (
+ checkpoint,
+ conv_nd,
+ linear,
+ avg_pool_nd,
+ zero_module,
+ normalization,
+ timestep_embedding,
+)
+from splitAttention import SpatialTransformer
+
+
+class AttentionPool2d(nn.Module):
+ """
+ Adapted from CLIP: https://github.com/openai/CLIP/blob/main/clip/model.py
+ """
+
+ def __init__(
+ self,
+ spacial_dim: int,
+ embed_dim: int,
+ num_heads_channels: int,
+ output_dim: int = None,
+ ):
+ super().__init__()
+ self.positional_embedding = nn.Parameter(th.randn(embed_dim, spacial_dim ** 2 + 1) / embed_dim ** 0.5)
+ self.qkv_proj = conv_nd(1, embed_dim, 3 * embed_dim, 1)
+ self.c_proj = conv_nd(1, embed_dim, output_dim or embed_dim, 1)
+ self.num_heads = embed_dim // num_heads_channels
+ self.attention = QKVAttention(self.num_heads)
+
+ def forward(self, x):
+ b, c, *_spatial = x.shape
+ x = x.reshape(b, c, -1) # NC(HW)
+ x = th.cat([x.mean(dim=-1, keepdim=True), x], dim=-1) # NC(HW+1)
+ x = x + self.positional_embedding[None, :, :].to(x.dtype) # NC(HW+1)
+ x = self.qkv_proj(x)
+ x = self.attention(x)
+ x = self.c_proj(x)
+ return x[:, :, 0]
+
+
+class TimestepBlock(nn.Module):
+ """
+ Any module where forward() takes timestep embeddings as a second argument.
+ """
+
+ @abstractmethod
+ def forward(self, x, emb):
+ """
+ Apply the module to `x` given `emb` timestep embeddings.
+ """
+
+
+class TimestepEmbedSequential(nn.Sequential, TimestepBlock):
+ """
+ A sequential module that passes timestep embeddings to the children that
+ support it as an extra input.
+ """
+
+ def forward(self, x, emb, context=None):
+ for layer in self:
+ if isinstance(layer, TimestepBlock):
+ x = layer(x, emb)
+ elif isinstance(layer, SpatialTransformer):
+ x = layer(x, context)
+ else:
+ x = layer(x)
+ return x
+
+
+class Upsample(nn.Module):
+ """
+ An upsampling layer with an optional convolution.
+ :param channels: channels in the inputs and outputs.
+ :param use_conv: a bool determining if a convolution is applied.
+ :param dims: determines if the signal is 1D, 2D, or 3D. If 3D, then
+ upsampling occurs in the inner-two dimensions.
+ """
+
+ def __init__(self, channels, use_conv, dims=2, out_channels=None, padding=1):
+ super().__init__()
+ self.channels = channels
+ self.out_channels = out_channels or channels
+ self.use_conv = use_conv
+ self.dims = dims
+ if use_conv:
+ self.conv = conv_nd(dims, self.channels, self.out_channels, 3, padding=padding)
+
+ def forward(self, x):
+ assert x.shape[1] == self.channels
+ if self.dims == 3:
+ x = F.interpolate(
+ x, (x.shape[2], x.shape[3] * 2, x.shape[4] * 2), mode="nearest"
+ )
+ else:
+ x = F.interpolate(x, scale_factor=2, mode="nearest")
+ if self.use_conv:
+ x = self.conv(x)
+ return x
+
+class TransposedUpsample(nn.Module):
+ 'Learned 2x upsampling without padding'
+ def __init__(self, channels, out_channels=None, ks=5):
+ super().__init__()
+ self.channels = channels
+ self.out_channels = out_channels or channels
+
+ self.up = nn.ConvTranspose2d(self.channels,self.out_channels,kernel_size=ks,stride=2)
+
+ def forward(self,x):
+ return self.up(x)
+
+
+class Downsample(nn.Module):
+ """
+ A downsampling layer with an optional convolution.
+ :param channels: channels in the inputs and outputs.
+ :param use_conv: a bool determining if a convolution is applied.
+ :param dims: determines if the signal is 1D, 2D, or 3D. If 3D, then
+ downsampling occurs in the inner-two dimensions.
+ """
+
+ def __init__(self, channels, use_conv, dims=2, out_channels=None,padding=1):
+ super().__init__()
+ self.channels = channels
+ self.out_channels = out_channels or channels
+ self.use_conv = use_conv
+ self.dims = dims
+ stride = 2 if dims != 3 else (1, 2, 2)
+ if use_conv:
+ self.op = conv_nd(
+ dims, self.channels, self.out_channels, 3, stride=stride, padding=padding
+ )
+ else:
+ assert self.channels == self.out_channels
+ self.op = avg_pool_nd(dims, kernel_size=stride, stride=stride)
+
+ def forward(self, x):
+ assert x.shape[1] == self.channels
+ return self.op(x)
+
+
+class ResBlock(TimestepBlock):
+ """
+ A residual block that can optionally change the number of channels.
+ :param channels: the number of input channels.
+ :param emb_channels: the number of timestep embedding channels.
+ :param dropout: the rate of dropout.
+ :param out_channels: if specified, the number of out channels.
+ :param use_conv: if True and out_channels is specified, use a spatial
+ convolution instead of a smaller 1x1 convolution to change the
+ channels in the skip connection.
+ :param dims: determines if the signal is 1D, 2D, or 3D.
+ :param use_checkpoint: if True, use gradient checkpointing on this module.
+ :param up: if True, use this block for upsampling.
+ :param down: if True, use this block for downsampling.
+ """
+
+ def __init__(
+ self,
+ channels,
+ emb_channels,
+ dropout,
+ out_channels=None,
+ use_conv=False,
+ use_scale_shift_norm=False,
+ dims=2,
+ use_checkpoint=False,
+ up=False,
+ down=False,
+ ):
+ super().__init__()
+ self.channels = channels
+ self.emb_channels = emb_channels
+ self.dropout = dropout
+ self.out_channels = out_channels or channels
+ self.use_conv = use_conv
+ self.use_checkpoint = use_checkpoint
+ self.use_scale_shift_norm = use_scale_shift_norm
+
+ self.in_layers = nn.Sequential(
+ normalization(channels),
+ nn.SiLU(),
+ conv_nd(dims, channels, self.out_channels, 3, padding=1),
+ )
+
+ self.updown = up or down
+
+ if up:
+ self.h_upd = Upsample(channels, False, dims)
+ self.x_upd = Upsample(channels, False, dims)
+ elif down:
+ self.h_upd = Downsample(channels, False, dims)
+ self.x_upd = Downsample(channels, False, dims)
+ else:
+ self.h_upd = self.x_upd = nn.Identity()
+
+ self.emb_layers = nn.Sequential(
+ nn.SiLU(),
+ linear(
+ emb_channels,
+ 2 * self.out_channels if use_scale_shift_norm else self.out_channels,
+ ),
+ )
+ self.out_layers = nn.Sequential(
+ normalization(self.out_channels),
+ nn.SiLU(),
+ nn.Dropout(p=dropout),
+ zero_module(
+ conv_nd(dims, self.out_channels, self.out_channels, 3, padding=1)
+ ),
+ )
+
+ if self.out_channels == channels:
+ self.skip_connection = nn.Identity()
+ elif use_conv:
+ self.skip_connection = conv_nd(
+ dims, channels, self.out_channels, 3, padding=1
+ )
+ else:
+ self.skip_connection = conv_nd(dims, channels, self.out_channels, 1)
+
+ def forward(self, x, emb):
+ """
+ Apply the block to a Tensor, conditioned on a timestep embedding.
+ :param x: an [N x C x ...] Tensor of features.
+ :param emb: an [N x emb_channels] Tensor of timestep embeddings.
+ :return: an [N x C x ...] Tensor of outputs.
+ """
+ return checkpoint(
+ self._forward, (x, emb), self.parameters(), self.use_checkpoint
+ )
+
+
+ def _forward(self, x, emb):
+ if self.updown:
+ in_rest, in_conv = self.in_layers[:-1], self.in_layers[-1]
+ h = in_rest(x)
+ h = self.h_upd(h)
+ x = self.x_upd(x)
+ h = in_conv(h)
+ else:
+ h = self.in_layers(x)
+ emb_out = self.emb_layers(emb).type(h.dtype)
+ while len(emb_out.shape) < len(h.shape):
+ emb_out = emb_out[..., None]
+ if self.use_scale_shift_norm:
+ out_norm, out_rest = self.out_layers[0], self.out_layers[1:]
+ scale, shift = th.chunk(emb_out, 2, dim=1)
+ h = out_norm(h) * (1 + scale) + shift
+ h = out_rest(h)
+ else:
+ h = h + emb_out
+ h = self.out_layers(h)
+ return self.skip_connection(x) + h
+
+
+class AttentionBlock(nn.Module):
+ """
+ An attention block that allows spatial positions to attend to each other.
+ Originally ported from here, but adapted to the N-d case.
+ https://github.com/hojonathanho/diffusion/blob/1e0dceb3b3495bbe19116a5e1b3596cd0706c543/diffusion_tf/models/unet.py#L66.
+ """
+
+ def __init__(
+ self,
+ channels,
+ num_heads=1,
+ num_head_channels=-1,
+ use_checkpoint=False,
+ use_new_attention_order=False,
+ ):
+ super().__init__()
+ self.channels = channels
+ if num_head_channels == -1:
+ self.num_heads = num_heads
+ else:
+ assert (
+ channels % num_head_channels == 0
+ ), f"q,k,v channels {channels} is not divisible by num_head_channels {num_head_channels}"
+ self.num_heads = channels // num_head_channels
+ self.use_checkpoint = use_checkpoint
+ self.norm = normalization(channels)
+ self.qkv = conv_nd(1, channels, channels * 3, 1)
+ if use_new_attention_order:
+ # split qkv before split heads
+ self.attention = QKVAttention(self.num_heads)
+ else:
+ # split heads before split qkv
+ self.attention = QKVAttentionLegacy(self.num_heads)
+
+ self.proj_out = zero_module(conv_nd(1, channels, channels, 1))
+
+ def forward(self, x):
+ return checkpoint(self._forward, (x,), self.parameters(), True) # TODO: check checkpoint usage, is True # TODO: fix the .half call!!!
+ #return pt_checkpoint(self._forward, x) # pytorch
+
+ def _forward(self, x):
+ b, c, *spatial = x.shape
+ x = x.reshape(b, c, -1)
+ qkv = self.qkv(self.norm(x))
+ h = self.attention(qkv)
+ h = self.proj_out(h)
+ return (x + h).reshape(b, c, *spatial)
+
+
+def count_flops_attn(model, _x, y):
+ """
+ A counter for the `thop` package to count the operations in an
+ attention operation.
+ Meant to be used like:
+ macs, params = thop.profile(
+ model,
+ inputs=(inputs, timestamps),
+ custom_ops={QKVAttention: QKVAttention.count_flops},
+ )
+ """
+ b, c, *spatial = y[0].shape
+ num_spatial = int(np.prod(spatial))
+ # We perform two matmuls with the same number of ops.
+ # The first computes the weight matrix, the second computes
+ # the combination of the value vectors.
+ matmul_ops = 2 * b * (num_spatial ** 2) * c
+ model.total_ops += th.DoubleTensor([matmul_ops])
+
+
+class QKVAttentionLegacy(nn.Module):
+ """
+    A module which performs QKV attention. Matches legacy QKVAttention + input/output heads shaping
+ """
+
+ def __init__(self, n_heads):
+ super().__init__()
+ self.n_heads = n_heads
+
+ def forward(self, qkv):
+ """
+ Apply QKV attention.
+ :param qkv: an [N x (H * 3 * C) x T] tensor of Qs, Ks, and Vs.
+ :return: an [N x (H * C) x T] tensor after attention.
+ """
+ bs, width, length = qkv.shape
+ assert width % (3 * self.n_heads) == 0
+ ch = width // (3 * self.n_heads)
+ q, k, v = qkv.reshape(bs * self.n_heads, ch * 3, length).split(ch, dim=1)
+ scale = 1 / math.sqrt(math.sqrt(ch))
+ weight = th.einsum(
+ "bct,bcs->bts", q * scale, k * scale
+ ) # More stable with f16 than dividing afterwards
+ weight = th.softmax(weight.float(), dim=-1).type(weight.dtype)
+ a = th.einsum("bts,bcs->bct", weight, v)
+ return a.reshape(bs, -1, length)
+
+ @staticmethod
+ def count_flops(model, _x, y):
+ return count_flops_attn(model, _x, y)
+
+
+class QKVAttention(nn.Module):
+ """
+ A module which performs QKV attention and splits in a different order.
+ """
+
+ def __init__(self, n_heads):
+ super().__init__()
+ self.n_heads = n_heads
+
+ def forward(self, qkv):
+ """
+ Apply QKV attention.
+ :param qkv: an [N x (3 * H * C) x T] tensor of Qs, Ks, and Vs.
+ :return: an [N x (H * C) x T] tensor after attention.
+ """
+ bs, width, length = qkv.shape
+ assert width % (3 * self.n_heads) == 0
+ ch = width // (3 * self.n_heads)
+ q, k, v = qkv.chunk(3, dim=1)
+ scale = 1 / math.sqrt(math.sqrt(ch))
+ weight = th.einsum(
+ "bct,bcs->bts",
+ (q * scale).view(bs * self.n_heads, ch, length),
+ (k * scale).view(bs * self.n_heads, ch, length),
+ ) # More stable with f16 than dividing afterwards
+ weight = th.softmax(weight.float(), dim=-1).type(weight.dtype)
+ a = th.einsum("bts,bcs->bct", weight, v.reshape(bs * self.n_heads, ch, length))
+ return a.reshape(bs, -1, length)
+
+ @staticmethod
+ def count_flops(model, _x, y):
+ return count_flops_attn(model, _x, y)
+
+
+class UNetModelEncode(nn.Module):
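+    """
+    First half of the split UNet: time embedding, input blocks and middle block.
+    Returns (h, emb, hs) for UNetModelDecode to finish the pass; splitting the
+    network in two is what lets the sampler move each half between CPU and GPU
+    independently (see the turbo path in sample()).
+    """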
+
+
+ def __init__(
+ self,
+ image_size,
+ in_channels,
+ model_channels,
+ out_channels,
+ num_res_blocks,
+ attention_resolutions,
+ dropout=0,
+ channel_mult=(1, 2, 4, 8),
+ conv_resample=True,
+ dims=2,
+ num_classes=None,
+ use_checkpoint=False,
+ use_fp16=False,
+ num_heads=-1,
+ num_head_channels=-1,
+ num_heads_upsample=-1,
+ use_scale_shift_norm=False,
+ resblock_updown=False,
+ use_new_attention_order=False,
+ use_spatial_transformer=False, # custom transformer support
+ transformer_depth=1, # custom transformer support
+ context_dim=None, # custom transformer support
+ n_embed=None, # custom support for prediction of discrete ids into codebook of first stage vq model
+ legacy=True,
+ ):
+ super().__init__()
+ if use_spatial_transformer:
+ assert context_dim is not None, 'Fool!! You forgot to include the dimension of your cross-attention conditioning...'
+
+ if context_dim is not None:
+ assert use_spatial_transformer, 'Fool!! You forgot to use the spatial transformer for your cross-attention conditioning...'
+ from omegaconf.listconfig import ListConfig
+ if type(context_dim) == ListConfig:
+ context_dim = list(context_dim)
+
+ if num_heads_upsample == -1:
+ num_heads_upsample = num_heads
+
+ if num_heads == -1:
+ assert num_head_channels != -1, 'Either num_heads or num_head_channels has to be set'
+
+ if num_head_channels == -1:
+ assert num_heads != -1, 'Either num_heads or num_head_channels has to be set'
+
+ self.image_size = image_size
+ self.in_channels = in_channels
+ self.model_channels = model_channels
+ self.out_channels = out_channels
+ self.num_res_blocks = num_res_blocks
+ self.attention_resolutions = attention_resolutions
+ self.dropout = dropout
+ self.channel_mult = channel_mult
+ self.conv_resample = conv_resample
+ self.num_classes = num_classes
+ self.use_checkpoint = use_checkpoint
+ self.dtype = th.float16 if use_fp16 else th.float32
+ self.num_heads = num_heads
+ self.num_head_channels = num_head_channels
+ self.num_heads_upsample = num_heads_upsample
+ self.predict_codebook_ids = n_embed is not None
+
+ time_embed_dim = model_channels * 4
+ self.time_embed = nn.Sequential(
+ linear(model_channels, time_embed_dim),
+ nn.SiLU(),
+ linear(time_embed_dim, time_embed_dim),
+ )
+
+ if self.num_classes is not None:
+ self.label_emb = nn.Embedding(num_classes, time_embed_dim)
+
+ self.input_blocks = nn.ModuleList(
+ [
+ TimestepEmbedSequential(
+ conv_nd(dims, in_channels, model_channels, 3, padding=1)
+ )
+ ]
+ )
+ self._feature_size = model_channels
+ input_block_chans = [model_channels]
+ ch = model_channels
+ ds = 1
+ for level, mult in enumerate(channel_mult):
+ for _ in range(num_res_blocks):
+ layers = [
+ ResBlock(
+ ch,
+ time_embed_dim,
+ dropout,
+ out_channels=mult * model_channels,
+ dims=dims,
+ use_checkpoint=use_checkpoint,
+ use_scale_shift_norm=use_scale_shift_norm,
+ )
+ ]
+ ch = mult * model_channels
+ if ds in attention_resolutions:
+ if num_head_channels == -1:
+ dim_head = ch // num_heads
+ else:
+ num_heads = ch // num_head_channels
+ dim_head = num_head_channels
+ if legacy:
+ #num_heads = 1
+ dim_head = ch // num_heads if use_spatial_transformer else num_head_channels
+ layers.append(
+ AttentionBlock(
+ ch,
+ use_checkpoint=use_checkpoint,
+ num_heads=num_heads,
+ num_head_channels=dim_head,
+ use_new_attention_order=use_new_attention_order,
+ ) if not use_spatial_transformer else SpatialTransformer(
+ ch, num_heads, dim_head, depth=transformer_depth, context_dim=context_dim
+ )
+ )
+ self.input_blocks.append(TimestepEmbedSequential(*layers))
+ self._feature_size += ch
+ input_block_chans.append(ch)
+ if level != len(channel_mult) - 1:
+ out_ch = ch
+ self.input_blocks.append(
+ TimestepEmbedSequential(
+ ResBlock(
+ ch,
+ time_embed_dim,
+ dropout,
+ out_channels=out_ch,
+ dims=dims,
+ use_checkpoint=use_checkpoint,
+ use_scale_shift_norm=use_scale_shift_norm,
+ down=True,
+ )
+ if resblock_updown
+ else Downsample(
+ ch, conv_resample, dims=dims, out_channels=out_ch
+ )
+ )
+ )
+ ch = out_ch
+ input_block_chans.append(ch)
+ ds *= 2
+ self._feature_size += ch
+
+ if num_head_channels == -1:
+ dim_head = ch // num_heads
+ else:
+ num_heads = ch // num_head_channels
+ dim_head = num_head_channels
+ if legacy:
+ #num_heads = 1
+ dim_head = ch // num_heads if use_spatial_transformer else num_head_channels
+ self.middle_block = TimestepEmbedSequential(
+ ResBlock(
+ ch,
+ time_embed_dim,
+ dropout,
+ dims=dims,
+ use_checkpoint=use_checkpoint,
+ use_scale_shift_norm=use_scale_shift_norm,
+ ),
+ AttentionBlock(
+ ch,
+ use_checkpoint=use_checkpoint,
+ num_heads=num_heads,
+ num_head_channels=dim_head,
+ use_new_attention_order=use_new_attention_order,
+ ) if not use_spatial_transformer else SpatialTransformer(
+ ch, num_heads, dim_head, depth=transformer_depth, context_dim=context_dim
+ ),
+ ResBlock(
+ ch,
+ time_embed_dim,
+ dropout,
+ dims=dims,
+ use_checkpoint=use_checkpoint,
+ use_scale_shift_norm=use_scale_shift_norm,
+ ),
+ )
+ self._feature_size += ch
+
+ def forward(self, x, timesteps=None, context=None, y=None):
+ """
+ Apply the model to an input batch.
+ :param x: an [N x C x ...] Tensor of inputs.
+ :param timesteps: a 1-D batch of timesteps.
+ :param context: conditioning plugged in via crossattn
+ :param y: an [N] Tensor of labels, if class-conditional.
+ :return: an [N x C x ...] Tensor of outputs.
+ """
+ assert (y is not None) == (
+ self.num_classes is not None
+ ), "must specify y if and only if the model is class-conditional"
+ hs = []
+ t_emb = timestep_embedding(timesteps, self.model_channels, repeat_only=False)
+ emb = self.time_embed(t_emb)
+
+ if self.num_classes is not None:
+ assert y.shape == (x.shape[0],)
+ emb = emb + self.label_emb(y)
+
+ h = x.type(self.dtype)
+ for module in self.input_blocks:
+ h = module(h, emb, context)
+ hs.append(h)
+ h = self.middle_block(h, emb, context)
+
+ return h, emb, hs
+
+
+class UNetModelDecode(nn.Module):
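+    """
+    Second half of the split UNet: the output (upsampling) blocks and the final
+    projection, consuming the (h, emb, hs) produced by UNetModelEncode.
+    """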
+
+
+ def __init__(
+ self,
+ image_size,
+ in_channels,
+ model_channels,
+ out_channels,
+ num_res_blocks,
+ attention_resolutions,
+ dropout=0,
+ channel_mult=(1, 2, 4, 8),
+ conv_resample=True,
+ dims=2,
+ num_classes=None,
+ use_checkpoint=False,
+ use_fp16=False,
+ num_heads=-1,
+ num_head_channels=-1,
+ num_heads_upsample=-1,
+ use_scale_shift_norm=False,
+ resblock_updown=False,
+ use_new_attention_order=False,
+ use_spatial_transformer=False, # custom transformer support
+ transformer_depth=1, # custom transformer support
+ context_dim=None, # custom transformer support
+ n_embed=None, # custom support for prediction of discrete ids into codebook of first stage vq model
+ legacy=True,
+ ):
+ super().__init__()
+ if use_spatial_transformer:
+ assert context_dim is not None, 'Fool!! You forgot to include the dimension of your cross-attention conditioning...'
+
+ if context_dim is not None:
+ assert use_spatial_transformer, 'Fool!! You forgot to use the spatial transformer for your cross-attention conditioning...'
+ from omegaconf.listconfig import ListConfig
+ if type(context_dim) == ListConfig:
+ context_dim = list(context_dim)
+
+ if num_heads_upsample == -1:
+ num_heads_upsample = num_heads
+
+ if num_heads == -1:
+ assert num_head_channels != -1, 'Either num_heads or num_head_channels has to be set'
+
+ if num_head_channels == -1:
+ assert num_heads != -1, 'Either num_heads or num_head_channels has to be set'
+
+ self.image_size = image_size
+ self.in_channels = in_channels
+ self.model_channels = model_channels
+ self.out_channels = out_channels
+ self.num_res_blocks = num_res_blocks
+ self.attention_resolutions = attention_resolutions
+ self.dropout = dropout
+ self.channel_mult = channel_mult
+ self.conv_resample = conv_resample
+ self.num_classes = num_classes
+ self.use_checkpoint = use_checkpoint
+ self.dtype = th.float16 if use_fp16 else th.float32
+ self.num_heads = num_heads
+ self.num_head_channels = num_head_channels
+ self.num_heads_upsample = num_heads_upsample
+ self.predict_codebook_ids = n_embed is not None
+
+ time_embed_dim = model_channels * 4
+
+ self._feature_size = model_channels
+ input_block_chans = [model_channels]
+ ch = model_channels
+ ds = 1
+ for level, mult in enumerate(channel_mult):
+ for _ in range(num_res_blocks):
+
+ ch = mult * model_channels
+ if ds in attention_resolutions:
+ if num_head_channels == -1:
+ dim_head = ch // num_heads
+ else:
+ num_heads = ch // num_head_channels
+ dim_head = num_head_channels
+ if legacy:
+ #num_heads = 1
+ dim_head = ch // num_heads if use_spatial_transformer else num_head_channels
+
+ self._feature_size += ch
+ input_block_chans.append(ch)
+ if level != len(channel_mult) - 1:
+ out_ch = ch
+
+ ch = out_ch
+ input_block_chans.append(ch)
+ ds *= 2
+ self._feature_size += ch
+
+ if num_head_channels == -1:
+ dim_head = ch // num_heads
+ else:
+ num_heads = ch // num_head_channels
+ dim_head = num_head_channels
+ if legacy:
+ #num_heads = 1
+ dim_head = ch // num_heads if use_spatial_transformer else num_head_channels
+
+ self._feature_size += ch
+
+ self.output_blocks = nn.ModuleList([])
+ for level, mult in list(enumerate(channel_mult))[::-1]:
+ for i in range(num_res_blocks + 1):
+ ich = input_block_chans.pop()
+ layers = [
+ ResBlock(
+ ch + ich,
+ time_embed_dim,
+ dropout,
+ out_channels=model_channels * mult,
+ dims=dims,
+ use_checkpoint=use_checkpoint,
+ use_scale_shift_norm=use_scale_shift_norm,
+ )
+ ]
+ ch = model_channels * mult
+ if ds in attention_resolutions:
+ if num_head_channels == -1:
+ dim_head = ch // num_heads
+ else:
+ num_heads = ch // num_head_channels
+ dim_head = num_head_channels
+ if legacy:
+ #num_heads = 1
+ dim_head = ch // num_heads if use_spatial_transformer else num_head_channels
+ layers.append(
+ AttentionBlock(
+ ch,
+ use_checkpoint=use_checkpoint,
+ num_heads=num_heads_upsample,
+ num_head_channels=dim_head,
+ use_new_attention_order=use_new_attention_order,
+ ) if not use_spatial_transformer else SpatialTransformer(
+ ch, num_heads, dim_head, depth=transformer_depth, context_dim=context_dim
+ )
+ )
+ if level and i == num_res_blocks:
+ out_ch = ch
+ layers.append(
+ ResBlock(
+ ch,
+ time_embed_dim,
+ dropout,
+ out_channels=out_ch,
+ dims=dims,
+ use_checkpoint=use_checkpoint,
+ use_scale_shift_norm=use_scale_shift_norm,
+ up=True,
+ )
+ if resblock_updown
+ else Upsample(ch, conv_resample, dims=dims, out_channels=out_ch)
+ )
+ ds //= 2
+ self.output_blocks.append(TimestepEmbedSequential(*layers))
+ self._feature_size += ch
+
+ self.out = nn.Sequential(
+ normalization(ch),
+ nn.SiLU(),
+ zero_module(conv_nd(dims, model_channels, out_channels, 3, padding=1)),
+ )
+ if self.predict_codebook_ids:
+ self.id_predictor = nn.Sequential(
+ normalization(ch),
+ conv_nd(dims, model_channels, n_embed, 1),
+ #nn.LogSoftmax(dim=1) # change to cross_entropy and produce non-normalized logits
+ )
+
+ def forward(self, h, emb, tp, hs, context=None, y=None):
+ """
+ Apply the decoder half to the encoder outputs.
+ :param h: the middle-block features returned by UNetModelEncode.
+ :param emb: the timestep embedding returned by UNetModelEncode.
+ :param tp: the dtype to cast the result to.
+ :param hs: the list of skip activations from the encoder's input blocks.
+ :param context: conditioning plugged in via crossattn
+ :param y: an [N] Tensor of labels, if class-conditional (unused here).
+ :return: an [N x C x ...] Tensor of outputs.
+ """
+
+ for module in self.output_blocks:
+ h = th.cat([h, hs.pop()], dim=1)
+ h = module(h, emb, context)
+ h = h.type(tp)
+ if self.predict_codebook_ids:
+ return self.id_predictor(h)
+ else:
+ return self.out(h)
\ No newline at end of file
diff --git a/optimizedSD/optimUtils.py b/optimizedSD/optimUtils.py
new file mode 100644
index 000000000..18b996792
--- /dev/null
+++ b/optimizedSD/optimUtils.py
@@ -0,0 +1,73 @@
+import os
+import pandas as pd
+
+
+def split_weighted_subprompts(text):
+ """
+ grabs all text up to the first occurrence of ':'
+ uses the grabbed text as a sub-prompt, and takes the value following ':' as weight
+ if ':' has no value defined, defaults to 1.0
+ repeats until no text remaining
+ """
+ remaining = len(text)
+ prompts = []
+ weights = []
+ while remaining > 0:
+ if ":" in text:
+ idx = text.index(":") # first occurrence from start
+ # grab up to index as sub-prompt
+ prompt = text[:idx]
+ remaining -= idx
+ # remove from main text
+ text = text[idx+1:]
+ # find value for weight
+ if " " in text:
+ idx = text.index(" ") # first occurence
+ else: # no space, read to end
+ idx = len(text)
+ if idx != 0:
+ try:
+ weight = float(text[:idx])
+ except ValueError: # couldn't treat as float
+ print(f"Warning: '{text[:idx]}' is not a value, are you missing a space?")
+ weight = 1.0
+ else: # no value found
+ weight = 1.0
+ # remove from main text
+ remaining -= idx
+ text = text[idx+1:]
+ # append the sub-prompt and its weight
+ prompts.append(prompt)
+ weights.append(weight)
+ else: # no : found
+ if len(text) > 0: # there is still text though
+ # take remainder as weight 1
+ prompts.append(text)
+ weights.append(1.0)
+ remaining = 0
+ return prompts, weights
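+
+# Worked example (illustrative): split_weighted_subprompts("a red fox:2 in the snow:1")
+# returns (["a red fox", "in the snow"], [2.0, 1.0]); text without any ':' comes
+# back as a single sub-prompt with the default weight 1.0.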
+
+def logger(params, log_csv):
+ os.makedirs('logs', exist_ok=True)
+ cols = [arg for arg, _ in params.items()]
+ if not os.path.exists(log_csv):
+ df = pd.DataFrame(columns=cols)
+ df.to_csv(log_csv, index=False)
+
+ df = pd.read_csv(log_csv)
+ for arg in cols:
+ if arg not in df.columns:
+ df[arg] = ""
+ df.to_csv(log_csv, index=False)
+
+ row = {}
+ cols = list(df.columns)
+ data = dict(params)
+ for col in cols:
+ if col in data:
+ row[col] = data[col]
+ else:
+ row[col] = ''
+
+ df = pd.DataFrame(row, index=[0])
+ df.to_csv(log_csv, index=False, mode='a', header=False)
\ No newline at end of file
diff --git a/optimizedSD/optimized_img2img.py b/optimizedSD/optimized_img2img.py
new file mode 100644
index 000000000..24f3338f0
--- /dev/null
+++ b/optimizedSD/optimized_img2img.py
@@ -0,0 +1,362 @@
+import argparse, os, re
+import torch
+import numpy as np
+from random import randint
+from omegaconf import OmegaConf
+from PIL import Image
+from tqdm import tqdm, trange
+from itertools import islice
+from einops import rearrange
+from torchvision.utils import make_grid
+import time
+from pytorch_lightning import seed_everything
+from torch import autocast
+from contextlib import contextmanager, nullcontext
+from einops import rearrange, repeat
+from ldm.util import instantiate_from_config
+from optimUtils import split_weighted_subprompts, logger
+from transformers import logging
+import pandas as pd
+logging.set_verbosity_error()
+
+
+def chunk(it, size):
+ it = iter(it)
+ return iter(lambda: tuple(islice(it, size)), ())
+
+
+def load_model_from_config(ckpt, verbose=False):
+ print(f"Loading model from {ckpt}")
+ pl_sd = torch.load(ckpt, map_location="cpu")
+ if "global_step" in pl_sd:
+ print(f"Global Step: {pl_sd['global_step']}")
+ sd = pl_sd["state_dict"]
+ return sd
+
+
+def load_img(path, h0, w0):
+
+ image = Image.open(path).convert("RGB")
+ w, h = image.size
+
+ print(f"loaded input image of size ({w}, {h}) from {path}")
+ if h0 is not None and w0 is not None:
+ h, w = h0, w0
+
+ w, h = map(lambda x: x - x % 64, (w, h)) # resize to integer multiple of 64
+
+ print(f"New image size ({w}, {h})")
+ image = image.resize((w, h), resample=Image.LANCZOS)
+ image = np.array(image).astype(np.float32) / 255.0
+ image = image[None].transpose(0, 3, 1, 2)
+ image = torch.from_numpy(image)
+ return 2.0 * image - 1.0
+
+
+config = "optimizedSD/v1-inference.yaml"
+ckpt = "models/ldm/stable-diffusion-v1/model.ckpt"
+
+parser = argparse.ArgumentParser()
+
+parser.add_argument(
+ "--prompt", type=str, nargs="?", default="a painting of a virus monster playing guitar", help="the prompt to render"
+)
+parser.add_argument("--outdir", type=str, nargs="?", help="dir to write results to", default="outputs/img2img-samples")
+parser.add_argument("--init-img", type=str, nargs="?", help="path to the input image")
+
+parser.add_argument(
+ "--skip_grid",
+ action="store_true",
+ help="do not save a grid, only individual samples. Helpful when evaluating lots of samples",
+)
+parser.add_argument(
+ "--skip_save",
+ action="store_true",
+ help="do not save individual samples. For speed measurements.",
+)
+parser.add_argument(
+ "--ddim_steps",
+ type=int,
+ default=50,
+ help="number of ddim sampling steps",
+)
+
+parser.add_argument(
+ "--ddim_eta",
+ type=float,
+ default=0.0,
+ help="ddim eta (eta=0.0 corresponds to deterministic sampling",
+)
+parser.add_argument(
+ "--n_iter",
+ type=int,
+ default=1,
+ help="sample this often",
+)
+parser.add_argument(
+ "--H",
+ type=int,
+ default=None,
+ help="image height, in pixel space",
+)
+parser.add_argument(
+ "--W",
+ type=int,
+ default=None,
+ help="image width, in pixel space",
+)
+parser.add_argument(
+ "--strength",
+ type=float,
+ default=0.75,
+ help="strength for noising/unnoising. 1.0 corresponds to full destruction of information in init image",
+)
+parser.add_argument(
+ "--n_samples",
+ type=int,
+ default=5,
+ help="how many samples to produce for each given prompt. A.k.a. batch size",
+)
+parser.add_argument(
+ "--n_rows",
+ type=int,
+ default=0,
+ help="rows in the grid (default: n_samples)",
+)
+parser.add_argument(
+ "--scale",
+ type=float,
+ default=7.5,
+ help="unconditional guidance scale: eps = eps(x, empty) + scale * (eps(x, cond) - eps(x, empty))",
+)
+parser.add_argument(
+ "--from-file",
+ type=str,
+ help="if specified, load prompts from this file",
+)
+parser.add_argument(
+ "--seed",
+ type=int,
+ default=None,
+ help="the seed (for reproducible sampling)",
+)
+parser.add_argument(
+ "--device",
+ type=str,
+ default="cuda",
+ help="CPU or GPU (cuda/cuda:0/cuda:1/...)",
+)
+parser.add_argument(
+ "--unet_bs",
+ type=int,
+ default=1,
+ help="Slightly reduces inference time at the expense of high VRAM (value > 1 not recommended )",
+)
+parser.add_argument(
+ "--turbo",
+ action="store_true",
+ help="Reduces inference time on the expense of 1GB VRAM",
+)
+parser.add_argument(
+ "--precision", type=str, help="evaluate at this precision", choices=["full", "autocast"], default="autocast"
+)
+parser.add_argument(
+ "--format",
+ type=str,
+ help="output image format",
+ choices=["jpg", "png"],
+ default="png",
+)
+parser.add_argument(
+ "--sampler",
+ type=str,
+ help="sampler",
+ choices=["ddim"],
+ default="ddim",
+)
+opt = parser.parse_args()
+
+tic = time.time()
+os.makedirs(opt.outdir, exist_ok=True)
+outpath = opt.outdir
+grid_count = len(os.listdir(outpath)) - 1
+
+if opt.seed is None:
+ opt.seed = randint(0, 1000000)
+seed_everything(opt.seed)
+
+# Logging
+logger(vars(opt), log_csv = "logs/img2img_logs.csv")
+
+sd = load_model_from_config(f"{ckpt}")
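+# Split the checkpoint between the two UNet halves: keys for the input blocks,
+# middle block and time embedding go to "model1" (UNetModelEncode), the rest of
+# the "model." keys go to "model2" (UNetModelDecode).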
+li, lo = [], []
+for key, value in sd.items():
+ sp = key.split(".")
+ if (sp[0]) == "model":
+ if "input_blocks" in sp:
+ li.append(key)
+ elif "middle_block" in sp:
+ li.append(key)
+ elif "time_embed" in sp:
+ li.append(key)
+ else:
+ lo.append(key)
+for key in li:
+ sd["model1." + key[6:]] = sd.pop(key)
+for key in lo:
+ sd["model2." + key[6:]] = sd.pop(key)
+
+config = OmegaConf.load(f"{config}")
+
+assert os.path.isfile(opt.init_img)
+init_image = load_img(opt.init_img, opt.H, opt.W).to(opt.device)
+
+model = instantiate_from_config(config.modelUNet)
+_, _ = model.load_state_dict(sd, strict=False)
+model.eval()
+model.cdevice = opt.device
+model.unet_bs = opt.unet_bs
+model.turbo = opt.turbo
+
+modelCS = instantiate_from_config(config.modelCondStage)
+_, _ = modelCS.load_state_dict(sd, strict=False)
+modelCS.eval()
+modelCS.cond_stage_model.device = opt.device
+
+modelFS = instantiate_from_config(config.modelFirstStage)
+_, _ = modelFS.load_state_dict(sd, strict=False)
+modelFS.eval()
+del sd
+if opt.device != "cpu" and opt.precision == "autocast":
+ model.half()
+ modelCS.half()
+ modelFS.half()
+ init_image = init_image.half()
+
+batch_size = opt.n_samples
+n_rows = opt.n_rows if opt.n_rows > 0 else batch_size
+if not opt.from_file:
+ assert opt.prompt is not None
+ prompt = opt.prompt
+ data = [batch_size * [prompt]]
+
+else:
+ print(f"reading prompts from {opt.from_file}")
+ with open(opt.from_file, "r") as f:
+ data = f.read().splitlines()
+ data = batch_size * list(data)
+ data = list(chunk(sorted(data), batch_size))
+
+modelFS.to(opt.device)
+
+init_image = repeat(init_image, "1 ... -> b ...", b=batch_size)
+init_latent = modelFS.get_first_stage_encoding(modelFS.encode_first_stage(init_image)) # move to latent space
+
+if opt.device != "cpu":
+ mem = torch.cuda.memory_allocated(device=opt.device) / 1e6
+ modelFS.to("cpu")
+ while torch.cuda.memory_allocated(device=opt.device) / 1e6 >= mem:
+ time.sleep(1)
+
+
+assert 0.0 <= opt.strength <= 1.0, "can only work with strength in [0.0, 1.0]"
+t_enc = int(opt.strength * opt.ddim_steps)
+print(f"target t_enc is {t_enc} steps")
+
+
+if opt.precision == "autocast" and opt.device != "cpu":
+ precision_scope = autocast
+else:
+ precision_scope = nullcontext
+
+seeds = ""
+with torch.no_grad():
+
+ all_samples = list()
+ for n in trange(opt.n_iter, desc="Sampling"):
+ for prompts in tqdm(data, desc="data"):
+
+ sample_path = os.path.join(outpath, "_".join(re.split(":| ", prompts[0])))[:150]
+ os.makedirs(sample_path, exist_ok=True)
+ base_count = len(os.listdir(sample_path))
+
+ with precision_scope("cuda"):
+ modelCS.to(opt.device)
+ uc = None
+ if opt.scale != 1.0:
+ uc = modelCS.get_learned_conditioning(batch_size * [""])
+ if isinstance(prompts, tuple):
+ prompts = list(prompts)
+
+ subprompts, weights = split_weighted_subprompts(prompts[0])
+ if len(subprompts) > 1:
+ c = torch.zeros_like(uc)
+ totalWeight = sum(weights)
+ # normalize each "sub prompt" and add it
+ for i in range(len(subprompts)):
+ weight = weights[i]
+ # if not skip_normalize:
+ weight = weight / totalWeight
+ c = torch.add(c, modelCS.get_learned_conditioning(subprompts[i]), alpha=weight)
+ else:
+ c = modelCS.get_learned_conditioning(prompts)
+
+ if opt.device != "cpu":
+ mem = torch.cuda.memory_allocated(device=opt.device) / 1e6
+ modelCS.to("cpu")
+ while torch.cuda.memory_allocated(device=opt.device) / 1e6 >= mem:
+ time.sleep(1)
+
+ # encode (scaled latent)
+ z_enc = model.stochastic_encode(
+ init_latent,
+ torch.tensor([t_enc] * batch_size).to(opt.device),
+ opt.seed,
+ opt.ddim_eta,
+ opt.ddim_steps,
+ )
+ # decode it
+ samples_ddim = model.sample(
+ t_enc,
+ c,
+ z_enc,
+ unconditional_guidance_scale=opt.scale,
+ unconditional_conditioning=uc,
+ sampler=opt.sampler
+ )
+
+ modelFS.to(opt.device)
+ print("saving images")
+ for i in range(batch_size):
+
+ x_samples_ddim = modelFS.decode_first_stage(samples_ddim[i].unsqueeze(0))
+ x_sample = torch.clamp((x_samples_ddim + 1.0) / 2.0, min=0.0, max=1.0)
+ x_sample = 255.0 * rearrange(x_sample[0].cpu().numpy(), "c h w -> h w c")
+ Image.fromarray(x_sample.astype(np.uint8)).save(
+ os.path.join(sample_path, "seed_" + str(opt.seed) + "_" + f"{base_count:05}.{opt.format}")
+ )
+ seeds += str(opt.seed) + ","
+ opt.seed += 1
+ base_count += 1
+
+ if opt.device != "cpu":
+ mem = torch.cuda.memory_allocated(device=opt.device) / 1e6
+ modelFS.to("cpu")
+ while torch.cuda.memory_allocated(device=opt.device) / 1e6 >= mem:
+ time.sleep(1)
+
+ del samples_ddim
+ print("memory_final = ", torch.cuda.memory_allocated(device=opt.device) / 1e6)
+
+toc = time.time()
+
+time_taken = (toc - tic) / 60.0
+
+print(
+ "Samples finished in {0:.2f} minutes and exported to {1}\nSeeds used = {2}".format(
+ time_taken, sample_path, seeds[:-1]
+ )
+)
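+
+# Example invocation (illustrative):
+#   python optimizedSD/optimized_img2img.py --prompt "a fantasy landscape, trending on artstation" \
+#       --init-img path/to/input.png --strength 0.6 --n_samples 2 --ddim_steps 50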
diff --git a/optimizedSD/optimized_txt2img.py b/optimizedSD/optimized_txt2img.py
new file mode 100644
index 000000000..c82918240
--- /dev/null
+++ b/optimizedSD/optimized_txt2img.py
@@ -0,0 +1,347 @@
+import argparse, os, re
+import torch
+import numpy as np
+from random import randint
+from omegaconf import OmegaConf
+from PIL import Image
+from tqdm import tqdm, trange
+from itertools import islice
+from einops import rearrange
+from torchvision.utils import make_grid
+import time
+from pytorch_lightning import seed_everything
+from torch import autocast
+from contextlib import contextmanager, nullcontext
+from ldm.util import instantiate_from_config
+from optimUtils import split_weighted_subprompts, logger
+from transformers import logging
+# from samplers import CompVisDenoiser
+logging.set_verbosity_error()
+
+
+def chunk(it, size):
+ it = iter(it)
+ return iter(lambda: tuple(islice(it, size)), ())
+
+
+def load_model_from_config(ckpt, verbose=False):
+ print(f"Loading model from {ckpt}")
+ pl_sd = torch.load(ckpt, map_location="cpu")
+ if "global_step" in pl_sd:
+ print(f"Global Step: {pl_sd['global_step']}")
+ sd = pl_sd["state_dict"]
+ return sd
+
+
+config = "optimizedSD/v1-inference.yaml"
+DEFAULT_CKPT = "models/ldm/stable-diffusion-v1/model.ckpt"
+
+parser = argparse.ArgumentParser()
+
+parser.add_argument(
+ "--prompt", type=str, nargs="?", default="a painting of a virus monster playing guitar", help="the prompt to render"
+)
+parser.add_argument("--outdir", type=str, nargs="?", help="dir to write results to", default="outputs/txt2img-samples")
+parser.add_argument(
+ "--skip_grid",
+ action="store_true",
+ help="do not save a grid, only individual samples. Helpful when evaluating lots of samples",
+)
+parser.add_argument(
+ "--skip_save",
+ action="store_true",
+ help="do not save individual samples. For speed measurements.",
+)
+parser.add_argument(
+ "--ddim_steps",
+ type=int,
+ default=50,
+ help="number of ddim sampling steps",
+)
+
+parser.add_argument(
+ "--fixed_code",
+ action="store_true",
+ help="if enabled, uses the same starting code across samples ",
+)
+parser.add_argument(
+ "--ddim_eta",
+ type=float,
+ default=0.0,
+ help="ddim eta (eta=0.0 corresponds to deterministic sampling",
+)
+parser.add_argument(
+ "--n_iter",
+ type=int,
+ default=1,
+ help="sample this often",
+)
+parser.add_argument(
+ "--H",
+ type=int,
+ default=512,
+ help="image height, in pixel space",
+)
+parser.add_argument(
+ "--W",
+ type=int,
+ default=512,
+ help="image width, in pixel space",
+)
+parser.add_argument(
+ "--C",
+ type=int,
+ default=4,
+ help="latent channels",
+)
+parser.add_argument(
+ "--f",
+ type=int,
+ default=8,
+ help="downsampling factor",
+)
+parser.add_argument(
+ "--n_samples",
+ type=int,
+ default=5,
+ help="how many samples to produce for each given prompt. A.k.a. batch size",
+)
+parser.add_argument(
+ "--n_rows",
+ type=int,
+ default=0,
+ help="rows in the grid (default: n_samples)",
+)
+parser.add_argument(
+ "--scale",
+ type=float,
+ default=7.5,
+ help="unconditional guidance scale: eps = eps(x, empty) + scale * (eps(x, cond) - eps(x, empty))",
+)
+parser.add_argument(
+ "--device",
+ type=str,
+ default="cuda",
+ help="specify GPU (cuda/cuda:0/cuda:1/...)",
+)
+parser.add_argument(
+ "--from-file",
+ type=str,
+ help="if specified, load prompts from this file",
+)
+parser.add_argument(
+ "--seed",
+ type=int,
+ default=None,
+ help="the seed (for reproducible sampling)",
+)
+parser.add_argument(
+ "--unet_bs",
+ type=int,
+ default=1,
+ help="Slightly reduces inference time at the expense of high VRAM (value > 1 not recommended )",
+)
+parser.add_argument(
+ "--turbo",
+ action="store_true",
+ help="Reduces inference time on the expense of 1GB VRAM",
+)
+parser.add_argument(
+ "--precision",
+ type=str,
+ help="evaluate at this precision",
+ choices=["full", "autocast"],
+ default="autocast"
+)
+parser.add_argument(
+ "--format",
+ type=str,
+ help="output image format",
+ choices=["jpg", "png"],
+ default="png",
+)
+parser.add_argument(
+ "--sampler",
+ type=str,
+ help="sampler",
+ choices=["ddim", "plms","heun", "euler", "euler_a", "dpm2", "dpm2_a", "lms"],
+ default="plms",
+)
+parser.add_argument(
+ "--ckpt",
+ type=str,
+ help="path to checkpoint of model",
+ default=DEFAULT_CKPT,
+)
+opt = parser.parse_args()
+
+tic = time.time()
+os.makedirs(opt.outdir, exist_ok=True)
+outpath = opt.outdir
+grid_count = len(os.listdir(outpath)) - 1
+
+if opt.seed is None:
+ opt.seed = randint(0, 1000000)
+seed_everything(opt.seed)
+
+# Logging
+logger(vars(opt), log_csv = "logs/txt2img_logs.csv")
+
+sd = load_model_from_config(f"{opt.ckpt}")
+li, lo = [], []
+for key, value in sd.items():
+ sp = key.split(".")
+ if (sp[0]) == "model":
+ if "input_blocks" in sp:
+ li.append(key)
+ elif "middle_block" in sp:
+ li.append(key)
+ elif "time_embed" in sp:
+ li.append(key)
+ else:
+ lo.append(key)
+for key in li:
+ sd["model1." + key[6:]] = sd.pop(key)
+for key in lo:
+ sd["model2." + key[6:]] = sd.pop(key)
+
+config = OmegaConf.load(f"{config}")
+
+model = instantiate_from_config(config.modelUNet)
+_, _ = model.load_state_dict(sd, strict=False)
+model.eval()
+model.unet_bs = opt.unet_bs
+model.cdevice = opt.device
+model.turbo = opt.turbo
+
+modelCS = instantiate_from_config(config.modelCondStage)
+_, _ = modelCS.load_state_dict(sd, strict=False)
+modelCS.eval()
+modelCS.cond_stage_model.device = opt.device
+
+modelFS = instantiate_from_config(config.modelFirstStage)
+_, _ = modelFS.load_state_dict(sd, strict=False)
+modelFS.eval()
+del sd
+
+if opt.device != "cpu" and opt.precision == "autocast":
+ model.half()
+ modelCS.half()
+
+start_code = None
+if opt.fixed_code:
+ start_code = torch.randn([opt.n_samples, opt.C, opt.H // opt.f, opt.W // opt.f], device=opt.device)
+
+
+batch_size = opt.n_samples
+n_rows = opt.n_rows if opt.n_rows > 0 else batch_size
+if not opt.from_file:
+ assert opt.prompt is not None
+ prompt = opt.prompt
+ print(f"Using prompt: {prompt}")
+ data = [batch_size * [prompt]]
+
+else:
+ print(f"reading prompts from {opt.from_file}")
+ with open(opt.from_file, "r") as f:
+ text = f.read()
+ print(f"Using prompt: {text.strip()}")
+ data = text.splitlines()
+ data = batch_size * list(data)
+ data = list(chunk(sorted(data), batch_size))
+
+
+if opt.precision == "autocast" and opt.device != "cpu":
+ precision_scope = autocast
+else:
+ precision_scope = nullcontext
+
+seeds = ""
+with torch.no_grad():
+
+ all_samples = list()
+ for n in trange(opt.n_iter, desc="Sampling"):
+ for prompts in tqdm(data, desc="data"):
+
+ sample_path = os.path.join(outpath, "_".join(re.split(":| ", prompts[0])))[:150]
+ os.makedirs(sample_path, exist_ok=True)
+ base_count = len(os.listdir(sample_path))
+
+ with precision_scope("cuda"):
+ modelCS.to(opt.device)
+ uc = None
+ if opt.scale != 1.0:
+ uc = modelCS.get_learned_conditioning(batch_size * [""])
+ if isinstance(prompts, tuple):
+ prompts = list(prompts)
+
+ subprompts, weights = split_weighted_subprompts(prompts[0])
+ if len(subprompts) > 1:
+ c = torch.zeros_like(uc)
+ totalWeight = sum(weights)
+ # normalize each "sub prompt" and add it
+ for i in range(len(subprompts)):
+ weight = weights[i]
+ # if not skip_normalize:
+ weight = weight / totalWeight
+ c = torch.add(c, modelCS.get_learned_conditioning(subprompts[i]), alpha=weight)
+ else:
+ c = modelCS.get_learned_conditioning(prompts)
+
+ shape = [opt.n_samples, opt.C, opt.H // opt.f, opt.W // opt.f]
+
+ if opt.device != "cpu":
+ mem = torch.cuda.memory_allocated() / 1e6
+ modelCS.to("cpu")
+ while torch.cuda.memory_allocated() / 1e6 >= mem:
+ time.sleep(1)
+
+ samples_ddim = model.sample(
+ S=opt.ddim_steps,
+ conditioning=c,
+ seed=opt.seed,
+ shape=shape,
+ verbose=False,
+ unconditional_guidance_scale=opt.scale,
+ unconditional_conditioning=uc,
+ eta=opt.ddim_eta,
+ x_T=start_code,
+ sampler=opt.sampler,
+ )
+
+ modelFS.to(opt.device)
+
+ print(samples_ddim.shape)
+ print("saving images")
+ for i in range(batch_size):
+
+ x_samples_ddim = modelFS.decode_first_stage(samples_ddim[i].unsqueeze(0))
+ x_sample = torch.clamp((x_samples_ddim + 1.0) / 2.0, min=0.0, max=1.0)
+ x_sample = 255.0 * rearrange(x_sample[0].cpu().numpy(), "c h w -> h w c")
+ Image.fromarray(x_sample.astype(np.uint8)).save(
+ os.path.join(sample_path, "seed_" + str(opt.seed) + "_" + f"{base_count:05}.{opt.format}")
+ )
+ seeds += str(opt.seed) + ","
+ opt.seed += 1
+ base_count += 1
+
+ if opt.device != "cpu":
+ mem = torch.cuda.memory_allocated() / 1e6
+ modelFS.to("cpu")
+ while torch.cuda.memory_allocated() / 1e6 >= mem:
+ time.sleep(1)
+ del samples_ddim
+ print("memory_final = ", torch.cuda.memory_allocated() / 1e6)
+
+toc = time.time()
+
+time_taken = (toc - tic) / 60.0
+
+print(
+ "Samples finished in {0:.2f} minutes and exported to {1}\nSeeds used = {2}".format(
+ time_taken, sample_path, seeds[:-1]
+ )
+)
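+
+# Example invocation (illustrative):
+#   python optimizedSD/optimized_txt2img.py --prompt "a castle on a hill, oil painting" \
+#       --H 512 --W 512 --n_samples 4 --sampler euler_a --turbo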
diff --git a/optimizedSD/samplers.py b/optimizedSD/samplers.py
new file mode 100644
index 000000000..6a68e8e1a
--- /dev/null
+++ b/optimizedSD/samplers.py
@@ -0,0 +1,252 @@
+from scipy import integrate
+import torch
+from tqdm.auto import trange, tqdm
+import torch.nn as nn
+
+
+def append_zero(x):
+ return torch.cat([x, x.new_zeros([1])])
+
+
+def append_dims(x, target_dims):
+ """Appends dimensions to the end of a tensor until it has target_dims dimensions."""
+ dims_to_append = target_dims - x.ndim
+ if dims_to_append < 0:
+ raise ValueError(f'input has {x.ndim} dims but target_dims is {target_dims}, which is less')
+ return x[(...,) + (None,) * dims_to_append]
+
+def get_ancestral_step(sigma_from, sigma_to):
+ """Calculates the noise level (sigma_down) to step down to and the amount
+ of noise to add (sigma_up) when doing an ancestral sampling step."""
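+ # By construction sigma_down**2 + sigma_up**2 == sigma_to**2, so stepping the
+ # ODE down to sigma_down and then adding sigma_up * N(0, I) of fresh noise
+ # puts the sample back at the target noise level sigma_to.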
+ sigma_up = (sigma_to ** 2 * (sigma_from ** 2 - sigma_to ** 2) / sigma_from ** 2) ** 0.5
+ sigma_down = (sigma_to ** 2 - sigma_up ** 2) ** 0.5
+ return sigma_down, sigma_up
+
+
+class DiscreteSchedule(nn.Module):
+ """A mapping between continuous noise levels (sigmas) and a list of discrete noise
+ levels."""
+
+ def __init__(self, sigmas, quantize):
+ super().__init__()
+ self.register_buffer('sigmas', sigmas)
+ self.quantize = quantize
+
+ def get_sigmas(self, n=None):
+ if n is None:
+ return append_zero(self.sigmas.flip(0))
+ t_max = len(self.sigmas) - 1
+ t = torch.linspace(t_max, 0, n, device=self.sigmas.device)
+ return append_zero(self.t_to_sigma(t))
+
+ def sigma_to_t(self, sigma, quantize=None):
+ quantize = self.quantize if quantize is None else quantize
+ dists = torch.abs(sigma - self.sigmas[:, None])
+ if quantize:
+ return torch.argmin(dists, dim=0).view(sigma.shape)
+ low_idx, high_idx = torch.sort(torch.topk(dists, dim=0, k=2, largest=False).indices, dim=0)[0]
+ low, high = self.sigmas[low_idx], self.sigmas[high_idx]
+ w = (low - sigma) / (low - high)
+ w = w.clamp(0, 1)
+ t = (1 - w) * low_idx + w * high_idx
+ return t.view(sigma.shape)
+
+ def t_to_sigma(self, t):
+ t = t.float()
+ low_idx, high_idx, w = t.floor().long(), t.ceil().long(), t.frac()
+ # print(low_idx, high_idx, w )
+ return (1 - w) * self.sigmas[low_idx] + w * self.sigmas[high_idx]
+
+
+class DiscreteEpsDDPMDenoiser(DiscreteSchedule):
+ """A wrapper for discrete schedule DDPM models that output eps (the predicted
+ noise)."""
+
+ def __init__(self, alphas_cumprod, quantize):
+ super().__init__(((1 - alphas_cumprod) / alphas_cumprod) ** 0.5, quantize)
+ self.sigma_data = 1.
+
+ def get_scalings(self, sigma):
+ c_out = -sigma
+ c_in = 1 / (sigma ** 2 + self.sigma_data ** 2) ** 0.5
+ return c_out, c_in
+
+ def get_eps(self, *args, **kwargs):
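+ # `inner_model` is not assigned anywhere in this file; the caller is expected
+ # to attach the wrapped diffusion model to this attribute before sampling.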
+ return self.inner_model(*args, **kwargs)
+
+ def forward(self, input, sigma, **kwargs):
+ c_out, c_in = [append_dims(x, input.ndim) for x in self.get_scalings(sigma)]
+ eps = self.get_eps(input * c_in, self.sigma_to_t(sigma), **kwargs)
+ return input + eps * c_out
+
+class CompVisDenoiser(DiscreteEpsDDPMDenoiser):
+ """A wrapper for CompVis diffusion models."""
+
+ def __init__(self, alphas_cumprod, quantize=False, device='cpu'):
+ super().__init__(alphas_cumprod, quantize=quantize)
+
+ def get_eps(self, *args, **kwargs):
+ return self.inner_model.apply_model(*args, **kwargs)
+
+
+def to_d(x, sigma, denoised):
+ """Converts a denoiser output to a Karras ODE derivative."""
+ return (x - denoised) / append_dims(sigma, x.ndim)
+
+
+def get_ancestral_step(sigma_from, sigma_to):
+ """Calculates the noise level (sigma_down) to step down to and the amount
+ of noise to add (sigma_up) when doing an ancestral sampling step."""
+ sigma_up = (sigma_to ** 2 * (sigma_from ** 2 - sigma_to ** 2) / sigma_from ** 2) ** 0.5
+ sigma_down = (sigma_to ** 2 - sigma_up ** 2) ** 0.5
+ return sigma_down, sigma_up
+
+
+@torch.no_grad()
+def sample_euler(model, x, sigmas, extra_args=None, callback=None, disable=None, s_churn=0., s_tmin=0., s_tmax=float('inf'), s_noise=1.):
+ """Implements Algorithm 2 (Euler steps) from Karras et al. (2022)."""
+ extra_args = {} if extra_args is None else extra_args
+ s_in = x.new_ones([x.shape[0]])
+ for i in trange(len(sigmas) - 1, disable=disable):
+ gamma = min(s_churn / (len(sigmas) - 1), 2 ** 0.5 - 1) if s_tmin <= sigmas[i] <= s_tmax else 0.
+ eps = torch.randn_like(x) * s_noise
+ sigma_hat = sigmas[i] * (gamma + 1)
+ if gamma > 0:
+ x = x + eps * (sigma_hat ** 2 - sigmas[i] ** 2) ** 0.5
+ denoised = model(x, sigma_hat * s_in, **extra_args)
+ d = to_d(x, sigma_hat, denoised)
+ if callback is not None:
+ callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigma_hat, 'denoised': denoised})
+ dt = sigmas[i + 1] - sigma_hat
+ # Euler method
+ x = x + d * dt
+ return x
+
+
+
+@torch.no_grad()
+def sample_euler_ancestral(model, x, sigmas, extra_args=None, callback=None, disable=None):
+ """Ancestral sampling with Euler method steps."""
+ extra_args = {} if extra_args is None else extra_args
+ s_in = x.new_ones([x.shape[0]])
+ for i in trange(len(sigmas) - 1, disable=disable):
+ denoised = model(x, sigmas[i] * s_in, **extra_args)
+ sigma_down, sigma_up = get_ancestral_step(sigmas[i], sigmas[i + 1])
+ if callback is not None:
+ callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigmas[i], 'denoised': denoised})
+ d = to_d(x, sigmas[i], denoised)
+ # Euler method
+ dt = sigma_down - sigmas[i]
+ x = x + d * dt
+ x = x + torch.randn_like(x) * sigma_up
+ return x
+
+
+@torch.no_grad()
+def sample_heun(model, x, sigmas, extra_args=None, callback=None, disable=None, s_churn=0., s_tmin=0., s_tmax=float('inf'), s_noise=1.):
+ """Implements Algorithm 2 (Heun steps) from Karras et al. (2022)."""
+ extra_args = {} if extra_args is None else extra_args
+ s_in = x.new_ones([x.shape[0]])
+ for i in trange(len(sigmas) - 1, disable=disable):
+ gamma = min(s_churn / (len(sigmas) - 1), 2 ** 0.5 - 1) if s_tmin <= sigmas[i] <= s_tmax else 0.
+ eps = torch.randn_like(x) * s_noise
+ sigma_hat = sigmas[i] * (gamma + 1)
+ if gamma > 0:
+ x = x + eps * (sigma_hat ** 2 - sigmas[i] ** 2) ** 0.5
+ denoised = model(x, sigma_hat * s_in, **extra_args)
+ d = to_d(x, sigma_hat, denoised)
+ if callback is not None:
+ callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigma_hat, 'denoised': denoised})
+ dt = sigmas[i + 1] - sigma_hat
+ if sigmas[i + 1] == 0:
+ # Euler method
+ x = x + d * dt
+ else:
+ # Heun's method
+ x_2 = x + d * dt
+ denoised_2 = model(x_2, sigmas[i + 1] * s_in, **extra_args)
+ d_2 = to_d(x_2, sigmas[i + 1], denoised_2)
+ d_prime = (d + d_2) / 2
+ x = x + d_prime * dt
+ return x
+
+
+@torch.no_grad()
+def sample_dpm_2(model, x, sigmas, extra_args=None, callback=None, disable=None, s_churn=0., s_tmin=0., s_tmax=float('inf'), s_noise=1.):
+ """A sampler inspired by DPM-Solver-2 and Algorithm 2 from Karras et al. (2022)."""
+ extra_args = {} if extra_args is None else extra_args
+ s_in = x.new_ones([x.shape[0]])
+ for i in trange(len(sigmas) - 1, disable=disable):
+ gamma = min(s_churn / (len(sigmas) - 1), 2 ** 0.5 - 1) if s_tmin <= sigmas[i] <= s_tmax else 0.
+ eps = torch.randn_like(x) * s_noise
+ sigma_hat = sigmas[i] * (gamma + 1)
+ if gamma > 0:
+ x = x + eps * (sigma_hat ** 2 - sigmas[i] ** 2) ** 0.5
+ denoised = model(x, sigma_hat * s_in, **extra_args)
+ d = to_d(x, sigma_hat, denoised)
+ if callback is not None:
+ callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigma_hat, 'denoised': denoised})
+ # Midpoint method, where the midpoint is chosen according to a rho=3 Karras schedule
+ sigma_mid = ((sigma_hat ** (1 / 3) + sigmas[i + 1] ** (1 / 3)) / 2) ** 3
+ dt_1 = sigma_mid - sigma_hat
+ dt_2 = sigmas[i + 1] - sigma_hat
+ x_2 = x + d * dt_1
+ denoised_2 = model(x_2, sigma_mid * s_in, **extra_args)
+ d_2 = to_d(x_2, sigma_mid, denoised_2)
+ x = x + d_2 * dt_2
+ return x
+
+
+@torch.no_grad()
+def sample_dpm_2_ancestral(model, x, sigmas, extra_args=None, callback=None, disable=None):
+ """Ancestral sampling with DPM-Solver inspired second-order steps."""
+ extra_args = {} if extra_args is None else extra_args
+ s_in = x.new_ones([x.shape[0]])
+ for i in trange(len(sigmas) - 1, disable=disable):
+ denoised = model(x, sigmas[i] * s_in, **extra_args)
+ sigma_down, sigma_up = get_ancestral_step(sigmas[i], sigmas[i + 1])
+ if callback is not None:
+ callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigmas[i], 'denoised': denoised})
+ d = to_d(x, sigmas[i], denoised)
+ # Midpoint method, where the midpoint is chosen according to a rho=3 Karras schedule
+ sigma_mid = ((sigmas[i] ** (1 / 3) + sigma_down ** (1 / 3)) / 2) ** 3
+ dt_1 = sigma_mid - sigmas[i]
+ dt_2 = sigma_down - sigmas[i]
+ x_2 = x + d * dt_1
+ denoised_2 = model(x_2, sigma_mid * s_in, **extra_args)
+ d_2 = to_d(x_2, sigma_mid, denoised_2)
+ x = x + d_2 * dt_2
+ x = x + torch.randn_like(x) * sigma_up
+ return x
+
+
+def linear_multistep_coeff(order, t, i, j):
+ if order - 1 > i:
+ raise ValueError(f'Order {order} too high for step {i}')
+ def fn(tau):
+ prod = 1.
+ for k in range(order):
+ if j == k:
+ continue
+ prod *= (tau - t[i - k]) / (t[i - j] - t[i - k])
+ return prod
+ return integrate.quad(fn, t[i], t[i + 1], epsrel=1e-4)[0]
+
+
+@torch.no_grad()
+def sample_lms(model, x, sigmas, extra_args=None, callback=None, disable=None, order=4):
+ extra_args = {} if extra_args is None else extra_args
+ s_in = x.new_ones([x.shape[0]])
+ ds = []
+ for i in trange(len(sigmas) - 1, disable=disable):
+ denoised = model(x, sigmas[i] * s_in, **extra_args)
+ d = to_d(x, sigmas[i], denoised)
+ ds.append(d)
+ if len(ds) > order:
+ ds.pop(0)
+ if callback is not None:
+ callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigmas[i], 'denoised': denoised})
+ cur_order = min(i + 1, order)
+ coeffs = [linear_multistep_coeff(cur_order, sigmas.cpu(), i, j) for j in range(cur_order)]
+ x = x + sum(coeff * d for coeff, d in zip(coeffs, reversed(ds)))
+ return x
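+
+# Minimal usage sketch (illustrative, not called anywhere in this file): wrap a
+# CompVis LDM `model`, attach it as `inner_model`, and sample with one of the
+# functions above. `shape`, `device` and the conditioning `c` are placeholders.
+#   denoiser = CompVisDenoiser(model.alphas_cumprod)
+#   denoiser.inner_model = model
+#   sigmas = denoiser.get_sigmas(50)
+#   x = torch.randn(shape, device=device) * sigmas[0]
+#   samples = sample_euler(denoiser, x, sigmas, extra_args={"cond": c})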
diff --git a/optimizedSD/splitAttention.py b/optimizedSD/splitAttention.py
new file mode 100644
index 000000000..dbfd459e4
--- /dev/null
+++ b/optimizedSD/splitAttention.py
@@ -0,0 +1,280 @@
+from inspect import isfunction
+import math
+import torch
+import torch.nn.functional as F
+from torch import nn, einsum
+from einops import rearrange, repeat
+
+from ldm.modules.diffusionmodules.util import checkpoint
+
+
+def exists(val):
+ return val is not None
+
+
+def uniq(arr):
+ return {el: True for el in arr}.keys()
+
+
+def default(val, d):
+ if exists(val):
+ return val
+ return d() if isfunction(d) else d
+
+
+def max_neg_value(t):
+ return -torch.finfo(t.dtype).max
+
+
+def init_(tensor):
+ dim = tensor.shape[-1]
+ std = 1 / math.sqrt(dim)
+ tensor.uniform_(-std, std)
+ return tensor
+
+
+# feedforward
+class GEGLU(nn.Module):
+ def __init__(self, dim_in, dim_out):
+ super().__init__()
+ self.proj = nn.Linear(dim_in, dim_out * 2)
+
+ def forward(self, x):
+ x, gate = self.proj(x).chunk(2, dim=-1)
+ return x * F.gelu(gate)
+
+
+class FeedForward(nn.Module):
+ def __init__(self, dim, dim_out=None, mult=4, glu=False, dropout=0.):
+ super().__init__()
+ inner_dim = int(dim * mult)
+ dim_out = default(dim_out, dim)
+ project_in = nn.Sequential(
+ nn.Linear(dim, inner_dim),
+ nn.GELU()
+ ) if not glu else GEGLU(dim, inner_dim)
+
+ self.net = nn.Sequential(
+ project_in,
+ nn.Dropout(dropout),
+ nn.Linear(inner_dim, dim_out)
+ )
+
+ def forward(self, x):
+ return self.net(x)
+
+
+def zero_module(module):
+ """
+ Zero out the parameters of a module and return it.
+ """
+ for p in module.parameters():
+ p.detach().zero_()
+ return module
+
+
+def Normalize(in_channels):
+ return torch.nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True)
+
+
+class LinearAttention(nn.Module):
+ def __init__(self, dim, heads=4, dim_head=32):
+ super().__init__()
+ self.heads = heads
+ hidden_dim = dim_head * heads
+ self.to_qkv = nn.Conv2d(dim, hidden_dim * 3, 1, bias = False)
+ self.to_out = nn.Conv2d(hidden_dim, dim, 1)
+
+ def forward(self, x):
+ b, c, h, w = x.shape
+ qkv = self.to_qkv(x)
+ q, k, v = rearrange(qkv, 'b (qkv heads c) h w -> qkv b heads c (h w)', heads = self.heads, qkv=3)
+ k = k.softmax(dim=-1)
+ context = torch.einsum('bhdn,bhen->bhde', k, v)
+ out = torch.einsum('bhde,bhdn->bhen', context, q)
+ out = rearrange(out, 'b heads c (h w) -> b (heads c) h w', heads=self.heads, h=h, w=w)
+ return self.to_out(out)
+
+
+class SpatialSelfAttention(nn.Module):
+ def __init__(self, in_channels):
+ super().__init__()
+ self.in_channels = in_channels
+
+ self.norm = Normalize(in_channels)
+ self.q = torch.nn.Conv2d(in_channels,
+ in_channels,
+ kernel_size=1,
+ stride=1,
+ padding=0)
+ self.k = torch.nn.Conv2d(in_channels,
+ in_channels,
+ kernel_size=1,
+ stride=1,
+ padding=0)
+ self.v = torch.nn.Conv2d(in_channels,
+ in_channels,
+ kernel_size=1,
+ stride=1,
+ padding=0)
+ self.proj_out = torch.nn.Conv2d(in_channels,
+ in_channels,
+ kernel_size=1,
+ stride=1,
+ padding=0)
+
+ def forward(self, x):
+ h_ = x
+ h_ = self.norm(h_)
+ q = self.q(h_)
+ k = self.k(h_)
+ v = self.v(h_)
+
+ # compute attention
+ b,c,h,w = q.shape
+ q = rearrange(q, 'b c h w -> b (h w) c')
+ k = rearrange(k, 'b c h w -> b c (h w)')
+ w_ = torch.einsum('bij,bjk->bik', q, k)
+
+ w_ = w_ * (int(c)**(-0.5))
+ w_ = torch.nn.functional.softmax(w_, dim=2)
+
+ # attend to values
+ v = rearrange(v, 'b c h w -> b c (h w)')
+ w_ = rearrange(w_, 'b i j -> b j i')
+ h_ = torch.einsum('bij,bjk->bik', v, w_)
+ h_ = rearrange(h_, 'b c (h w) -> b c h w', h=h)
+ h_ = self.proj_out(h_)
+
+ return x+h_
+
+
+class CrossAttention(nn.Module):
+ def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0., att_step=1):
+ super().__init__()
+ inner_dim = dim_head * heads
+ context_dim = default(context_dim, query_dim)
+
+ self.scale = dim_head ** -0.5
+ self.heads = heads
+ self.att_step = att_step
+
+ self.to_q = nn.Linear(query_dim, inner_dim, bias=False)
+ self.to_k = nn.Linear(context_dim, inner_dim, bias=False)
+ self.to_v = nn.Linear(context_dim, inner_dim, bias=False)
+
+ self.to_out = nn.Sequential(
+ nn.Linear(inner_dim, query_dim),
+ nn.Dropout(dropout)
+ )
+
+ def forward(self, x, context=None, mask=None):
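+ # Attention is computed in chunks along the (batch * heads) dimension: only
+ # `att_step` slices of the (n x n) similarity matrix are materialised at a
+ # time, bounding peak memory at the cost of a Python-level loop. The `mask`
+ # argument is accepted but not used here.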
+ h = self.heads
+
+ q = self.to_q(x)
+ context = default(context, x)
+ k = self.to_k(context)
+ v = self.to_v(context)
+ del context, x
+
+ q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h), (q, k, v))
+
+
+ limit = k.shape[0]
+ att_step = self.att_step
+ q_chunks = list(torch.tensor_split(q, limit//att_step, dim=0))
+ k_chunks = list(torch.tensor_split(k, limit//att_step, dim=0))
+ v_chunks = list(torch.tensor_split(v, limit//att_step, dim=0))
+
+ q_chunks.reverse()
+ k_chunks.reverse()
+ v_chunks.reverse()
+ sim = torch.zeros(q.shape[0], q.shape[1], v.shape[2], device=q.device)
+ del k, q, v
+ for i in range(0, limit, att_step):
+
+ q_buffer = q_chunks.pop()
+ k_buffer = k_chunks.pop()
+ v_buffer = v_chunks.pop()
+ sim_buffer = einsum('b i d, b j d -> b i j', q_buffer, k_buffer) * self.scale
+
+ del k_buffer, q_buffer
+ # attention, what we cannot get enough of, by chunks
+
+ sim_buffer = sim_buffer.softmax(dim=-1)
+
+ sim_buffer = einsum('b i j, b j d -> b i d', sim_buffer, v_buffer)
+ del v_buffer
+ sim[i:i+att_step,:,:] = sim_buffer
+
+ del sim_buffer
+ sim = rearrange(sim, '(b h) n d -> b n (h d)', h=h)
+ return self.to_out(sim)
+
+
+class BasicTransformerBlock(nn.Module):
+ def __init__(self, dim, n_heads, d_head, dropout=0., context_dim=None, gated_ff=True, checkpoint=True):
+ super().__init__()
+ self.attn1 = CrossAttention(query_dim=dim, heads=n_heads, dim_head=d_head, dropout=dropout) # is a self-attention
+ self.ff = FeedForward(dim, dropout=dropout, glu=gated_ff)
+ self.attn2 = CrossAttention(query_dim=dim, context_dim=context_dim,
+ heads=n_heads, dim_head=d_head, dropout=dropout) # is self-attn if context is none
+ self.norm1 = nn.LayerNorm(dim)
+ self.norm2 = nn.LayerNorm(dim)
+ self.norm3 = nn.LayerNorm(dim)
+ self.checkpoint = checkpoint
+
+ def forward(self, x, context=None):
+ return checkpoint(self._forward, (x, context), self.parameters(), self.checkpoint)
+
+ def _forward(self, x, context=None):
+ x = self.attn1(self.norm1(x)) + x
+ x = self.attn2(self.norm2(x), context=context) + x
+ x = self.ff(self.norm3(x)) + x
+ return x
+
+
+class SpatialTransformer(nn.Module):
+ """
+ Transformer block for image-like data.
+ First, project the input (aka embedding)
+ and reshape to b, t, d.
+ Then apply standard transformer action.
+ Finally, reshape to image
+ """
+ def __init__(self, in_channels, n_heads, d_head,
+ depth=1, dropout=0., context_dim=None):
+ super().__init__()
+ self.in_channels = in_channels
+ inner_dim = n_heads * d_head
+ self.norm = Normalize(in_channels)
+
+ self.proj_in = nn.Conv2d(in_channels,
+ inner_dim,
+ kernel_size=1,
+ stride=1,
+ padding=0)
+
+ self.transformer_blocks = nn.ModuleList(
+ [BasicTransformerBlock(inner_dim, n_heads, d_head, dropout=dropout, context_dim=context_dim)
+ for d in range(depth)]
+ )
+
+ self.proj_out = zero_module(nn.Conv2d(inner_dim,
+ in_channels,
+ kernel_size=1,
+ stride=1,
+ padding=0))
+
+ def forward(self, x, context=None):
+ # note: if no context is given, cross-attention defaults to self-attention
+ b, c, h, w = x.shape
+ x_in = x
+ x = self.norm(x)
+ x = self.proj_in(x)
+ x = rearrange(x, 'b c h w -> b (h w) c')
+ for block in self.transformer_blocks:
+ x = block(x, context=context)
+ x = rearrange(x, 'b (h w) c -> b c h w', h=h, w=w)
+ x = self.proj_out(x)
+ return x + x_in
diff --git a/optimizedSD/txt2img_gradio.py b/optimizedSD/txt2img_gradio.py
new file mode 100644
index 000000000..a909351f1
--- /dev/null
+++ b/optimizedSD/txt2img_gradio.py
@@ -0,0 +1,250 @@
+import gradio as gr
+import numpy as np
+import torch
+import os, re
+import time
+import pandas as pd
+import mimetypes
+from random import randint
+from omegaconf import OmegaConf
+from PIL import Image
+from tqdm import tqdm, trange
+from itertools import islice
+from einops import rearrange
+from torchvision.utils import make_grid
+from pytorch_lightning import seed_everything
+from torch import autocast
+from contextlib import nullcontext
+from ldm.util import instantiate_from_config
+from optimUtils import split_weighted_subprompts, logger
+from transformers import logging
+logging.set_verbosity_error()
+mimetypes.init()
+mimetypes.add_type("application/javascript", ".js")
+
+
+def chunk(it, size):
+ it = iter(it)
+ return iter(lambda: tuple(islice(it, size)), ())
+
+
+def load_model_from_config(ckpt, verbose=False):
+ print(f"Loading model from {ckpt}")
+ pl_sd = torch.load(ckpt, map_location="cpu")
+ if "global_step" in pl_sd:
+ print(f"Global Step: {pl_sd['global_step']}")
+ sd = pl_sd["state_dict"]
+ return sd
+
+config = "optimizedSD/v1-inference.yaml"
+ckpt = "models/ldm/stable-diffusion-v1/model.ckpt"
+sd = load_model_from_config(f"{ckpt}")
+li, lo = [], []
+for key, v_ in sd.items():
+ sp = key.split(".")
+ if (sp[0]) == "model":
+ if "input_blocks" in sp:
+ li.append(key)
+ elif "middle_block" in sp:
+ li.append(key)
+ elif "time_embed" in sp:
+ li.append(key)
+ else:
+ lo.append(key)
+for key in li:
+ sd["model1." + key[6:]] = sd.pop(key)
+for key in lo:
+ sd["model2." + key[6:]] = sd.pop(key)
+
+config = OmegaConf.load(f"{config}")
+
+model = instantiate_from_config(config.modelUNet)
+_, _ = model.load_state_dict(sd, strict=False)
+model.eval()
+
+modelCS = instantiate_from_config(config.modelCondStage)
+_, _ = modelCS.load_state_dict(sd, strict=False)
+modelCS.eval()
+
+modelFS = instantiate_from_config(config.modelFirstStage)
+_, _ = modelFS.load_state_dict(sd, strict=False)
+modelFS.eval()
+del sd
+
+
+def generate(
+ prompt,
+ ddim_steps,
+ n_iter,
+ batch_size,
+ Height,
+ Width,
+ scale,
+ ddim_eta,
+ unet_bs,
+ device,
+ seed,
+ outdir,
+ img_format,
+ turbo,
+ full_precision,
+ sampler,
+):
+
+ C = 4
+ f = 8
+ start_code = None
+ model.unet_bs = unet_bs
+ model.turbo = turbo
+ model.cdevice = device
+ modelCS.cond_stage_model.device = device
+
+ if seed == "":
+ seed = randint(0, 1000000)
+ seed = int(seed)
+ seed_everything(seed)
+ # Logging
+ logger(locals(), "logs/txt2img_gradio_logs.csv")
+
+ if device != "cpu" and full_precision == False:
+ model.half()
+ modelFS.half()
+ modelCS.half()
+
+ tic = time.time()
+ os.makedirs(outdir, exist_ok=True)
+ outpath = outdir
+ sample_path = os.path.join(outpath, "_".join(re.split(":| ", prompt)))[:150]
+ os.makedirs(sample_path, exist_ok=True)
+ base_count = len(os.listdir(sample_path))
+
+ # n_rows = opt.n_rows if opt.n_rows > 0 else batch_size
+ assert prompt is not None
+ data = [batch_size * [prompt]]
+
+ if not full_precision and device != "cpu":
+ precision_scope = autocast
+ else:
+ precision_scope = nullcontext
+
+ all_samples = []
+ seeds = ""
+ with torch.no_grad():
+
+ all_samples = list()
+ for _ in trange(n_iter, desc="Sampling"):
+ for prompts in tqdm(data, desc="data"):
+ with precision_scope("cuda"):
+ modelCS.to(device)
+ uc = None
+ if scale != 1.0:
+ uc = modelCS.get_learned_conditioning(batch_size * [""])
+ if isinstance(prompts, tuple):
+ prompts = list(prompts)
+
+ subprompts, weights = split_weighted_subprompts(prompts[0])
+ if len(subprompts) > 1:
+ c = torch.zeros_like(uc)
+ totalWeight = sum(weights)
+ # normalize each "sub prompt" and add it
+ for i in range(len(subprompts)):
+ weight = weights[i]
+ # if not skip_normalize:
+ weight = weight / totalWeight
+ c = torch.add(c, modelCS.get_learned_conditioning(subprompts[i]), alpha=weight)
+ else:
+ c = modelCS.get_learned_conditioning(prompts)
+
+ shape = [batch_size, C, Height // f, Width // f]
+
+ if device != "cpu":
+ mem = torch.cuda.memory_allocated() / 1e6
+ modelCS.to("cpu")
+ while torch.cuda.memory_allocated() / 1e6 >= mem:
+ time.sleep(1)
+
+ samples_ddim = model.sample(
+ S=ddim_steps,
+ conditioning=c,
+ seed=seed,
+ shape=shape,
+ verbose=False,
+ unconditional_guidance_scale=scale,
+ unconditional_conditioning=uc,
+ eta=ddim_eta,
+ x_T=start_code,
+ sampler=sampler,
+ )
+
+ modelFS.to(device)
+ print("saving images")
+ for i in range(batch_size):
+
+ x_samples_ddim = modelFS.decode_first_stage(samples_ddim[i].unsqueeze(0))
+ x_sample = torch.clamp((x_samples_ddim + 1.0) / 2.0, min=0.0, max=1.0)
+ all_samples.append(x_sample.to("cpu"))
+ x_sample = 255.0 * rearrange(x_sample[0].cpu().numpy(), "c h w -> h w c")
+ Image.fromarray(x_sample.astype(np.uint8)).save(
+ os.path.join(sample_path, "seed_" + str(seed) + "_" + f"{base_count:05}.{img_format}")
+ )
+ seeds += str(seed) + ","
+ seed += 1
+ base_count += 1
+
+ if device != "cpu":
+ mem = torch.cuda.memory_allocated() / 1e6
+ modelFS.to("cpu")
+ while torch.cuda.memory_allocated() / 1e6 >= mem:
+ time.sleep(1)
+
+ del samples_ddim
+ del x_sample
+ del x_samples_ddim
+ print("memory_final = ", torch.cuda.memory_allocated() / 1e6)
+
+ toc = time.time()
+
+ time_taken = (toc - tic) / 60.0
+ grid = torch.cat(all_samples, 0)
+ grid = make_grid(grid, nrow=n_iter)
+ grid = 255.0 * rearrange(grid, "c h w -> h w c").cpu().numpy()
+
+ txt = (
+ "Samples finished in "
+ + str(round(time_taken, 3))
+ + " minutes and exported to "
+ + sample_path
+ + "\nSeeds used = "
+ + seeds[:-1]
+ )
+ return Image.fromarray(grid.astype(np.uint8)), txt
+
+
+demo = gr.Interface(
+ fn=generate,
+ inputs=[
+ "text",
+ gr.Slider(1, 1000, value=50),
+ gr.Slider(1, 100, step=1),
+ gr.Slider(1, 100, step=1),
+ gr.Slider(64, 4096, value=512, step=64),
+ gr.Slider(64, 4096, value=512, step=64),
+ gr.Slider(0, 50, value=7.5, step=0.1),
+ gr.Slider(0, 1, step=0.01),
+ gr.Slider(1, 2, value=1, step=1),
+ gr.Text(value="cuda"),
+ "text",
+ gr.Text(value="outputs/txt2img-samples"),
+ gr.Radio(["png", "jpg"], value='png'),
+ "checkbox",
+ "checkbox",
+ gr.Radio(["ddim", "plms","heun", "euler", "euler_a", "dpm2", "dpm2_a", "lms"], value="plms"),
+ ],
+ outputs=["image", "text"],
+)
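+# demo.launch() serves the UI on Gradio's default address (http://127.0.0.1:7860)
+# unless GRADIO_SERVER_NAME / GRADIO_SERVER_PORT are set in the environment.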
+demo.launch()
diff --git a/optimizedSD/v1-inference.yaml b/optimizedSD/v1-inference.yaml
new file mode 100644
index 000000000..2e535fcb4
--- /dev/null
+++ b/optimizedSD/v1-inference.yaml
@@ -0,0 +1,114 @@
+modelUNet:
+ base_learning_rate: 1.0e-04
+ target: optimizedSD.ddpm.UNet
+ params:
+ linear_start: 0.00085
+ linear_end: 0.0120
+ num_timesteps_cond: 1
+ log_every_t: 200
+ timesteps: 1000
+ first_stage_key: "jpg"
+ cond_stage_key: "txt"
+ image_size: 64
+ channels: 4
+ cond_stage_trainable: false # Note: different from the one we trained before
+ conditioning_key: crossattn
+ monitor: val/loss_simple_ema
+ scale_factor: 0.18215
+ use_ema: False
+
+ unetConfigEncode:
+ target: optimizedSD.openaimodelSplit.UNetModelEncode
+ params:
+ image_size: 32 # unused
+ in_channels: 4
+ out_channels: 4
+ model_channels: 320
+ attention_resolutions: [4, 2, 1]
+ num_res_blocks: 2
+ channel_mult: [1, 2, 4, 4]
+ num_heads: 8
+ use_spatial_transformer: True
+ transformer_depth: 1
+ context_dim: 768
+ use_checkpoint: True
+ legacy: False
+
+ unetConfigDecode:
+ target: optimizedSD.openaimodelSplit.UNetModelDecode
+ params:
+ image_size: 32 # unused
+ in_channels: 4
+ out_channels: 4
+ model_channels: 320
+ attention_resolutions: [4, 2, 1]
+ num_res_blocks: 2
+ channel_mult: [1, 2, 4, 4]
+ num_heads: 8
+ use_spatial_transformer: True
+ transformer_depth: 1
+ context_dim: 768
+ use_checkpoint: True
+ legacy: False
+
+modelFirstStage:
+ target: optimizedSD.ddpm.FirstStage
+ params:
+ linear_start: 0.00085
+ linear_end: 0.0120
+ num_timesteps_cond: 1
+ log_every_t: 200
+ timesteps: 1000
+ first_stage_key: "jpg"
+ cond_stage_key: "txt"
+ image_size: 64
+ channels: 4
+ cond_stage_trainable: false # Note: different from the one we trained before
+ conditioning_key: crossattn
+ monitor: val/loss_simple_ema
+ scale_factor: 0.18215
+ use_ema: False
+ first_stage_config:
+ target: ldm.models.autoencoder.AutoencoderKL
+ params:
+ embed_dim: 4
+ monitor: val/rec_loss
+ ddconfig:
+ double_z: true
+ z_channels: 4
+ resolution: 256
+ in_channels: 3
+ out_ch: 3
+ ch: 128
+ ch_mult:
+ - 1
+ - 2
+ - 4
+ - 4
+ num_res_blocks: 2
+ attn_resolutions: []
+ dropout: 0.0
+ lossconfig:
+ target: torch.nn.Identity
+
+modelCondStage:
+ target: optimizedSD.ddpm.CondStage
+ params:
+ linear_start: 0.00085
+ linear_end: 0.0120
+ num_timesteps_cond: 1
+ log_every_t: 200
+ timesteps: 1000
+ first_stage_key: "jpg"
+ cond_stage_key: "txt"
+ image_size: 64
+ channels: 4
+ cond_stage_trainable: false # Note: different from the one we trained before
+ conditioning_key: crossattn
+ monitor: val/loss_simple_ema
+ scale_factor: 0.18215
+ use_ema: False
+ cond_stage_config:
+ target: ldm.modules.encoders.modules.FrozenCLIPEmbedder
+ params:
+ device: cpu