Würstchen model #3849

@@ -0,0 +1,140 @@ (new file: Würstchen documentation)

# Würstchen

<img src="https://github.com/dome272/Wuerstchen/assets/61938694/0617c863-165a-43ee-9303-2a17299a0cf9">

[Würstchen: Efficient Pretraining of Text-to-Image Models](https://huggingface.co/papers/2306.00637) is by Pablo Pernias, Dominic Rampas, and Marc Aubreville.

The abstract from the paper is:

*We introduce Würstchen, a novel technique for text-to-image synthesis that unites competitive performance with unprecedented cost-effectiveness and ease of training on constrained hardware. Building on recent advancements in machine learning, our approach, which utilizes latent diffusion strategies at strong latent image compression rates, significantly reduces the computational burden, typically associated with state-of-the-art models, while preserving, if not enhancing, the quality of generated images. Wuerstchen achieves notable speed improvements at inference time, thereby rendering real-time applications more viable. One of the key advantages of our method lies in its modest training requirements of only 9,200 GPU hours, slashing the usual costs significantly without compromising the end performance. In a comparison against the state-of-the-art, we found the approach to yield strong competitiveness. This paper opens the door to a new line of research that prioritizes both performance and computational accessibility, hence democratizing the use of sophisticated AI technologies. Through Wuerstchen, we demonstrate a compelling stride forward in the realm of text-to-image synthesis, offering an innovative path to explore in future research.*

## Würstchen v2 comes to Diffusers

After the initial paper release, we have improved numerous things in the architecture, training, and sampling, making Würstchen competitive with current state-of-the-art models in many ways. We are excited to release this new version together with Diffusers. Here is a list of the improvements:

- Higher resolution (1024x1024 up to 2048x2048)
- Faster inference
- Multi Aspect Resolution Sampling
- Better quality

We are releasing 3 checkpoints for the text-conditional image generation model (Stage C). Those are:

- v2-base
- v2-aesthetic
- v2-interpolated (50% interpolation between v2-base and v2-aesthetic)

We recommend using v2-interpolated, as it combines a touch of photorealism with aesthetic appeal. Use v2-base for fine-tuning, as it does not have a style bias, and use v2-aesthetic for very artistic generations.
A comparison can be seen here:

<img src="https://github.com/dome272/Wuerstchen/assets/61938694/2914830f-cbd3-461c-be64-d50734f4b49d" width=500>
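
The 50% interpolation above is a plain weight-space average of the two Stage C checkpoints. As a hedged sketch (the checkpoint file names below are placeholders, not names used in this document, and it assumes both files store flat state dicts with identical parameter keys), the interpolation could be reproduced like this:

```python
import torch

# Placeholder file names -- the actual v2-base / v2-aesthetic checkpoint files are not named in this document.
base = torch.load("stage_c_v2_base.pt", map_location="cpu")
aesthetic = torch.load("stage_c_v2_aesthetic.pt", map_location="cpu")

# 50% linear interpolation of every parameter tensor.
interpolated = {key: 0.5 * base[key] + 0.5 * aesthetic[key] for key in base}

torch.save(interpolated, "stage_c_v2_interpolated.pt")
```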

## Text-to-Image Generation

For the sake of usability, Würstchen can be used with a single pipeline. This pipeline is called `WuerstchenCombinedPipeline` and can be used as follows:

```python
import torch
from diffusers import AutoPipelineForText2Image

device = "cuda"
dtype = torch.float16
num_images_per_prompt = 2

pipeline = AutoPipelineForText2Image.from_pretrained(
    "warp-diffusion/wuerstchen", torch_dtype=dtype
).to(device)

caption = "Anthropomorphic cat dressed as a fire fighter"
negative_prompt = ""

output = pipeline(
    prompt=caption,
    height=1024,
    width=1024,
    negative_prompt=negative_prompt,
    prior_guidance_scale=4.0,
    decoder_guidance_scale=0.0,
    num_images_per_prompt=num_images_per_prompt,
    output_type="pil",
).images
```

Member (review comment on lines +51 to +52): It might be a good idea to comment about the limitation here a bit, I think. How does this pipeline perform when the resolution is low (as low as 256x256) and high (as high as 1024x1024)? Also, are we mentioning somewhere that we don't need more than 10 inference steps to get good-quality results, to demonstrate the competitive advantage of Wuerstchen?

Contributor (reply): I will do that in the documentation!
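
Following up on the review comment about step counts: as a hedged sketch, the prior and decoder step counts can in principle be set independently. The argument names `prior_num_inference_steps` and `num_inference_steps` below are assumptions about the combined pipeline's interface and should be verified against the released signature.

```python
# Assumed argument names -- verify against the final WuerstchenCombinedPipeline signature.
output = pipeline(
    prompt=caption,
    height=1024,
    width=1024,
    prior_guidance_scale=4.0,
    decoder_guidance_scale=0.0,
    prior_num_inference_steps=60,  # Stage C (prior) sampling steps
    num_inference_steps=12,        # Stage B (decoder) steps; the review above suggests ~10 already give good quality
    output_type="pil",
).images
```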

For explanation purposes, we can also initialize the two main pipelines of Würstchen individually. Würstchen consists of 3 stages: Stage C, Stage B, and Stage A. Each has a different job, and they only work in combination. When generating text-conditional images, Stage C first generates the latents in a very compressed latent space. This is what happens in the `prior_pipeline`. Afterwards, the generated latents are passed to Stage B, which decompresses the latents into the bigger latent space of a VQGAN. These latents can then be decoded by Stage A, which is a VQGAN, into pixel space. Stage B and Stage A are both encapsulated in the `decoder_pipeline`. For more details, take a look at the [paper](https://huggingface.co/papers/2306.00637).

```python
import torch
from diffusers import WuerstchenDecoderPipeline, WuerstchenPriorPipeline

device = "cuda"
dtype = torch.float16
num_images_per_prompt = 2

prior_pipeline = WuerstchenPriorPipeline.from_pretrained(
    "warp-diffusion/wuerstchen-prior", torch_dtype=dtype
).to(device)
decoder_pipeline = WuerstchenDecoderPipeline.from_pretrained(
    "warp-diffusion/wuerstchen", torch_dtype=dtype
).to(device)

caption = "A captivating artwork of a mysterious stone golem"
negative_prompt = ""

prior_output = prior_pipeline(
    prompt=caption,
    height=1024,
    width=1024,
    negative_prompt=negative_prompt,
    guidance_scale=4.0,
    num_images_per_prompt=num_images_per_prompt,
)
decoder_output = decoder_pipeline(
    image_embeddings=prior_output.image_embeddings,
    prompt=caption,
    negative_prompt=negative_prompt,
    num_images_per_prompt=num_images_per_prompt,
    guidance_scale=0.0,
    output_type="pil",
).images
```

Contributor (review comment): We can update this code snippet with a shorter one as stated here: #3849 (comment)

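To get a feel for how compressed the Stage C latent space is, one can inspect the prior output. The shape shown in the comment is only indicative: it assumes the roughly 42:1 spatial compression described in the paper and is not confirmed by this document.

```python
# image_embeddings are the highly compressed Stage C latents that Stage B/A later decode.
print(prior_output.image_embeddings.shape)
# e.g. roughly torch.Size([2, 16, 24, 24]) for two 1024x1024 images (assumed ~42:1 spatial compression)
```
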
## Speed-Up Inference

You can make use of `torch.compile` and gain a speed-up of about 2-3x:

```python
pipeline.prior = torch.compile(pipeline.prior, mode="reduce-overhead", fullgraph=True)
pipeline.decoder = torch.compile(pipeline.decoder, mode="reduce-overhead", fullgraph=True)
```

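Note that `torch.compile` traces and compiles on first use, so the initial call pays a one-time compilation cost; a minimal usage sketch, reusing the combined pipeline and caption from above:

```python
# The first call triggers compilation and is slow; subsequent calls get the 2-3x speed-up.
_ = pipeline(prompt=caption, height=1024, width=1024, output_type="pil").images
images = pipeline(prompt=caption, height=1024, width=1024, output_type="pil").images
```
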
## Limitations

- Due to the high compression employed by Würstchen, generations can lack a good amount of detail. To the human eye, this is especially noticeable in faces, hands, etc.
- **Images can only be generated in 128-pixel steps**, e.g. the next higher resolution after 1024x1024 is 1152x1152 (see the helper sketch after this list)
- The model lacks the ability to render correct text in images
- The model often does not achieve photorealism
- Difficult compositional prompts are hard for the model

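Since supported resolutions move in 128-pixel steps, a small helper (a sketch, not part of the library) can snap an arbitrary target size to the nearest valid one:

```python
def snap_to_128(size: int) -> int:
    """Round a requested edge length to the nearest multiple of 128 (minimum 128)."""
    return max(128, round(size / 128) * 128)

print(snap_to_128(1000))  # 1024
print(snap_to_128(1100))  # 1152
```
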
The original codebase, as well as experimental ideas, can be found at [dome272/Wuerstchen](https://github.com/dome272/Wuerstchen).

## WuerstchenCombinedPipeline

[[autodoc]] WuerstchenCombinedPipeline
- all
- __call__

## WuerstchenPriorPipeline

[[autodoc]] WuerstchenPriorPipeline
- all
- __call__

## WuerstchenPriorPipelineOutput

[[autodoc]] pipelines.wuerstchen.pipeline_wuerstchen_prior.WuerstchenPriorPipelineOutput

## WuerstchenDecoderPipeline

Member (review comment): Maybe let's shift this above the output dataclass?

[[autodoc]] WuerstchenDecoderPipeline
- all
- __call__

@@ -0,0 +1,115 @@ (new file: checkpoint conversion script)

```python
# Run inside root directory of official source code: https://github.com/dome272/wuerstchen/
import os

import torch
from transformers import AutoTokenizer, CLIPTextModel
from vqgan import VQModel

from diffusers import (
    DDPMWuerstchenScheduler,
    WuerstchenCombinedPipeline,
    WuerstchenDecoderPipeline,
    WuerstchenPriorPipeline,
)
from diffusers.pipelines.wuerstchen import PaellaVQModel, WuerstchenDiffNeXt, WuerstchenPrior


model_path = "models/"
device = "cpu"

# Stage A: load the original VQGAN and port its weights to the diffusers PaellaVQModel
# (the codebook key is renamed to the embedding key expected by diffusers).
paella_vqmodel = VQModel()
state_dict = torch.load(os.path.join(model_path, "vqgan_f4_v1_500k.pt"), map_location=device)["state_dict"]
paella_vqmodel.load_state_dict(state_dict)

state_dict["vquantizer.embedding.weight"] = state_dict["vquantizer.codebook.weight"]
state_dict.pop("vquantizer.codebook.weight")
vqmodel = PaellaVQModel(num_vq_embeddings=paella_vqmodel.codebook_size, latent_channels=paella_vqmodel.c_latent)
vqmodel.load_state_dict(state_dict)

# CLIP text encoder and tokenizer (used by the prior pipeline below)
text_encoder = CLIPTextModel.from_pretrained("laion/CLIP-ViT-bigG-14-laion2B-39B-b160k")
tokenizer = AutoTokenizer.from_pretrained("laion/CLIP-ViT-bigG-14-laion2B-39B-b160k")

# Generator: text encoder and tokenizer used by the decoder pipeline below
gen_text_encoder = CLIPTextModel.from_pretrained("laion/CLIP-ViT-H-14-laion2B-s32B-b79K").to("cpu")
gen_tokenizer = AutoTokenizer.from_pretrained("laion/CLIP-ViT-H-14-laion2B-s32B-b79K")

# Stage B (decoder): split the fused attention projections (in_proj_weight / in_proj_bias)
# of the original checkpoint into the separate to_q / to_k / to_v (and to_out) parameters
# expected by diffusers attention blocks.
orig_state_dict = torch.load(os.path.join(model_path, "model_v2_stage_b.pt"), map_location=device)["state_dict"]
state_dict = {}
for key in orig_state_dict.keys():
    if key.endswith("in_proj_weight"):
        weights = orig_state_dict[key].chunk(3, 0)
        state_dict[key.replace("attn.in_proj_weight", "to_q.weight")] = weights[0]
        state_dict[key.replace("attn.in_proj_weight", "to_k.weight")] = weights[1]
        state_dict[key.replace("attn.in_proj_weight", "to_v.weight")] = weights[2]
    elif key.endswith("in_proj_bias"):
        weights = orig_state_dict[key].chunk(3, 0)
        state_dict[key.replace("attn.in_proj_bias", "to_q.bias")] = weights[0]
        state_dict[key.replace("attn.in_proj_bias", "to_k.bias")] = weights[1]
        state_dict[key.replace("attn.in_proj_bias", "to_v.bias")] = weights[2]
    elif key.endswith("out_proj.weight"):
        weights = orig_state_dict[key]
        state_dict[key.replace("attn.out_proj.weight", "to_out.0.weight")] = weights
    elif key.endswith("out_proj.bias"):
        weights = orig_state_dict[key]
        state_dict[key.replace("attn.out_proj.bias", "to_out.0.bias")] = weights
    else:
        state_dict[key] = orig_state_dict[key]
decoder = WuerstchenDiffNeXt()
decoder.load_state_dict(state_dict)

# Prior (Stage C): same attention remapping, applied to the EMA weights.
orig_state_dict = torch.load(os.path.join(model_path, "model_v3_stage_c.pt"), map_location=device)["ema_state_dict"]
state_dict = {}
for key in orig_state_dict.keys():
    if key.endswith("in_proj_weight"):
        weights = orig_state_dict[key].chunk(3, 0)
        state_dict[key.replace("attn.in_proj_weight", "to_q.weight")] = weights[0]
        state_dict[key.replace("attn.in_proj_weight", "to_k.weight")] = weights[1]
        state_dict[key.replace("attn.in_proj_weight", "to_v.weight")] = weights[2]
    elif key.endswith("in_proj_bias"):
        weights = orig_state_dict[key].chunk(3, 0)
        state_dict[key.replace("attn.in_proj_bias", "to_q.bias")] = weights[0]
        state_dict[key.replace("attn.in_proj_bias", "to_k.bias")] = weights[1]
        state_dict[key.replace("attn.in_proj_bias", "to_v.bias")] = weights[2]
    elif key.endswith("out_proj.weight"):
        weights = orig_state_dict[key]
        state_dict[key.replace("attn.out_proj.weight", "to_out.0.weight")] = weights
    elif key.endswith("out_proj.bias"):
        weights = orig_state_dict[key]
        state_dict[key.replace("attn.out_proj.bias", "to_out.0.bias")] = weights
    else:
        state_dict[key] = orig_state_dict[key]
prior_model = WuerstchenPrior(c_in=16, c=1536, c_cond=1280, c_r=64, depth=32, nhead=24).to(device)
prior_model.load_state_dict(state_dict)

# Scheduler shared by the prior and decoder pipelines
scheduler = DDPMWuerstchenScheduler()

# Prior pipeline
prior_pipeline = WuerstchenPriorPipeline(
    prior=prior_model, text_encoder=text_encoder, tokenizer=tokenizer, scheduler=scheduler
)

prior_pipeline.save_pretrained("warp-diffusion/wuerstchen-prior")

# Decoder pipeline
decoder_pipeline = WuerstchenDecoderPipeline(
    text_encoder=gen_text_encoder, tokenizer=gen_tokenizer, vqgan=vqmodel, decoder=decoder, scheduler=scheduler
)
decoder_pipeline.save_pretrained("warp-diffusion/wuerstchen")

# Wuerstchen combined pipeline
wuerstchen_pipeline = WuerstchenCombinedPipeline(
    # Decoder
    text_encoder=gen_text_encoder,
    tokenizer=gen_tokenizer,
    decoder=decoder,
    scheduler=scheduler,
    vqgan=vqmodel,
    # Prior
    prior_tokenizer=tokenizer,
    prior_text_encoder=text_encoder,
    prior=prior_model,
    prior_scheduler=scheduler,
)
wuerstchen_pipeline.save_pretrained("warp-diffusion/WuerstchenCombinedPipeline")
```

Changes to the auto pipeline module (mapping registrations and `from_pretrained` kwargs):

```diff
@@ -52,6 +52,7 @@
     StableDiffusionXLInpaintPipeline,
     StableDiffusionXLPipeline,
 )
 from .wuerstchen import WuerstchenCombinedPipeline, WuerstchenDecoderPipeline


 AUTO_TEXT2IMAGE_PIPELINES_MAPPING = OrderedDict(

@@ -63,6 +64,7 @@
         ("kandinsky22", KandinskyV22CombinedPipeline),
         ("stable-diffusion-controlnet", StableDiffusionControlNetPipeline),
         ("stable-diffusion-xl-controlnet", StableDiffusionXLControlNetPipeline),
         ("wuerstchen", WuerstchenCombinedPipeline),
     ]
 )

@@ -93,6 +95,7 @@
     [
         ("kandinsky", KandinskyPipeline),
         ("kandinsky22", KandinskyV22Pipeline),
         ("wuerstchen", WuerstchenDecoderPipeline),
     ]
 )
 _AUTO_IMAGE2IMAGE_DECODER_PIPELINES_MAPPING = OrderedDict(
```
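
With these mapping entries registered, `AutoPipelineForText2Image` can resolve Würstchen checkpoints to the new combined pipeline; a minimal sketch using the hub id from the documentation above (the printed class name is the expected outcome, not output reproduced from this PR):

```python
import torch
from diffusers import AutoPipelineForText2Image

pipe = AutoPipelineForText2Image.from_pretrained(
    "warp-diffusion/wuerstchen", torch_dtype=torch.float16
)
print(type(pipe).__name__)  # expected: WuerstchenCombinedPipeline
```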

```diff
@@ -305,8 +308,6 @@ def from_pretrained(cls, pretrained_model_or_path, **kwargs):
         use_auth_token = kwargs.pop("use_auth_token", None)
         local_files_only = kwargs.pop("local_files_only", False)
         revision = kwargs.pop("revision", None)
         subfolder = kwargs.pop("subfolder", None)
         user_agent = kwargs.pop("user_agent", {})

         load_config_kwargs = {
             "cache_dir": cache_dir,
@@ -316,8 +317,6 @@ def from_pretrained(cls, pretrained_model_or_path, **kwargs):
             "use_auth_token": use_auth_token,
             "local_files_only": local_files_only,
             "revision": revision,
             "subfolder": subfolder,
             "user_agent": user_agent,
         }

         config = cls.load_config(pretrained_model_or_path, **load_config_kwargs)
@@ -580,8 +579,6 @@ def from_pretrained(cls, pretrained_model_or_path, **kwargs):
         use_auth_token = kwargs.pop("use_auth_token", None)
         local_files_only = kwargs.pop("local_files_only", False)
         revision = kwargs.pop("revision", None)
         subfolder = kwargs.pop("subfolder", None)
         user_agent = kwargs.pop("user_agent", {})

         load_config_kwargs = {
             "cache_dir": cache_dir,
@@ -591,8 +588,6 @@ def from_pretrained(cls, pretrained_model_or_path, **kwargs):
             "use_auth_token": use_auth_token,
             "local_files_only": local_files_only,
             "revision": revision,
             "subfolder": subfolder,
             "user_agent": user_agent,
         }

         config = cls.load_config(pretrained_model_or_path, **load_config_kwargs)
@@ -856,8 +851,6 @@ def from_pretrained(cls, pretrained_model_or_path, **kwargs):
         use_auth_token = kwargs.pop("use_auth_token", None)
         local_files_only = kwargs.pop("local_files_only", False)
         revision = kwargs.pop("revision", None)
         subfolder = kwargs.pop("subfolder", None)
         user_agent = kwargs.pop("user_agent", {})

         load_config_kwargs = {
             "cache_dir": cache_dir,
@@ -867,8 +860,6 @@ def from_pretrained(cls, pretrained_model_or_path, **kwargs):
             "use_auth_token": use_auth_token,
             "local_files_only": local_files_only,
             "revision": revision,
             "subfolder": subfolder,
             "user_agent": user_agent,
         }

         config = cls.load_config(pretrained_model_or_path, **load_config_kwargs)
```

Contributor (review comment on the `subfolder` kwarg handling): Ideally this should have been fixed in another PR, but doing it here now instead.