From 9668f24277718022aab2a66e77c1a0c8bf62a9ca Mon Sep 17 00:00:00 2001 From: yiyixuxu Date: Mon, 12 Jan 2026 09:10:34 +0100 Subject: [PATCH 01/15] flux2-klein --- scripts/convert_flux2_to_diffusers.py | 62 +- src/diffusers/__init__.py | 2 + .../models/transformers/transformer_flux2.py | 37 +- src/diffusers/pipelines/__init__.py | 4 +- src/diffusers/pipelines/flux2/__init__.py | 2 + .../pipelines/flux2/pipeline_flux2.py | 6 +- .../pipelines/flux2/pipeline_flux2_klein.py | 892 ++++++++++++++++++ 7 files changed, 984 insertions(+), 21 deletions(-) create mode 100644 src/diffusers/pipelines/flux2/pipeline_flux2_klein.py diff --git a/scripts/convert_flux2_to_diffusers.py b/scripts/convert_flux2_to_diffusers.py index 2973913fa215..d057fa491d68 100644 --- a/scripts/convert_flux2_to_diffusers.py +++ b/scripts/convert_flux2_to_diffusers.py @@ -44,7 +44,7 @@ parser = argparse.ArgumentParser() parser.add_argument("--original_state_dict_repo_id", default=None, type=str) parser.add_argument("--vae_filename", default="flux2-vae.sft", type=str) -parser.add_argument("--dit_filename", default="flux-dev-dummy.sft", type=str) +parser.add_argument("--dit_filename", default="flux2-dev.safetensors", type=str) parser.add_argument("--vae", action="store_true") parser.add_argument("--dit", action="store_true") parser.add_argument("--vae_dtype", type=str, default="fp32") @@ -385,9 +385,9 @@ def update_state_dict(state_dict: Dict[str, Any], old_key: str, new_key: str) -> def get_flux2_transformer_config(model_type: str) -> Tuple[Dict[str, Any], ...]: - if model_type == "test" or model_type == "dummy-flux2": + if model_type == "flux2-dev": config = { - "model_id": "diffusers-internal-dev/dummy-flux2", + "model_id": "black-forest-labs/FLUX.2-dev", "diffusers_config": { "patch_size": 1, "in_channels": 128, @@ -405,6 +405,53 @@ def get_flux2_transformer_config(model_type: str) -> Tuple[Dict[str, Any], ...]: } rename_dict = FLUX2_TRANSFORMER_KEYS_RENAME_DICT special_keys_remap = TRANSFORMER_SPECIAL_KEYS_REMAP + elif model_type == "klein-4b": + config = { + "model_id": "diffusers-internal-dev/dummy0115", + "diffusers_config": { + "patch_size": 1, + "in_channels": 128, + "num_layers": 5, + "num_single_layers": 20, + "attention_head_dim": 128, + "num_attention_heads": 24, + "joint_attention_dim": 7680, + "timestep_guidance_channels": 256, + "mlp_ratio": 3.0, + "axes_dims_rope": (32, 32, 32, 32), + "rope_theta": 2000, + "eps": 1e-6, + "guidance_embeds": False, + }, + } + rename_dict = FLUX2_TRANSFORMER_KEYS_RENAME_DICT + special_keys_remap = TRANSFORMER_SPECIAL_KEYS_REMAP + + elif model_type == "klein-9b": + config = { + "model_id": "diffusers-internal-dev/dummy0115", + "diffusers_config": { + "patch_size": 1, + "in_channels": 128, + "num_layers": 8, + "num_single_layers": 24, + "attention_head_dim": 128, + "num_attention_heads": 32, + "joint_attention_dim": 12288, + "timestep_guidance_channels": 256, + "mlp_ratio": 3.0, + "axes_dims_rope": (32, 32, 32, 32), + "rope_theta": 2000, + "eps": 1e-6, + "guidance_embeds": False, + }, + } + rename_dict = FLUX2_TRANSFORMER_KEYS_RENAME_DICT + special_keys_remap = TRANSFORMER_SPECIAL_KEYS_REMAP + + else: + raise ValueError(f"Unknown model_type: {model_type}. Choose from: flux2-dev, klein-4b, klein-9b") + return config, rename_dict, special_keys_remap @@ -447,7 +494,14 @@ def main(args): if args.dit: original_dit_ckpt = load_original_checkpoint(args, filename=args.dit_filename) - transformer = convert_flux2_transformer_to_diffusers(original_dit_ckpt, "test") + + if "klein-4b" in args.dit_filename: + model_type = "klein-4b" + elif "klein-9b" in args.dit_filename: + model_type = "klein-9b" + else: + model_type = "flux2-dev" + transformer = convert_flux2_transformer_to_diffusers(original_dit_ckpt, model_type) if not args.full_pipe: dit_dtype = torch.bfloat16 if args.dit_dtype == "bf16" else torch.float32 transformer.to(dit_dtype).save_pretrained(f"{args.output_path}/transformer") diff --git a/src/diffusers/__init__.py b/src/diffusers/__init__.py index b2754093dd54..6aa0cee150a6 100644 --- a/src/diffusers/__init__.py +++ b/src/diffusers/__init__.py @@ -482,6 +482,7 @@ "EasyAnimateInpaintPipeline", "EasyAnimatePipeline", "Flux2Pipeline", + "Flux2KleinPipeline", "FluxControlImg2ImgPipeline", "FluxControlInpaintPipeline", "FluxControlNetImg2ImgPipeline", @@ -1209,6 +1210,7 @@ EasyAnimateInpaintPipeline, EasyAnimatePipeline, Flux2Pipeline, + Flux2KleinPipeline, FluxControlImg2ImgPipeline, FluxControlInpaintPipeline, FluxControlNetImg2ImgPipeline, diff --git a/src/diffusers/models/transformers/transformer_flux2.py b/src/diffusers/models/transformers/transformer_flux2.py index 8032ec48c104..c22e6e4eeb8e 100644 --- a/src/diffusers/models/transformers/transformer_flux2.py +++ b/src/diffusers/models/transformers/transformer_flux2.py @@ -585,7 +585,13 @@ def forward(self, ids: torch.Tensor) -> torch.Tensor: class Flux2TimestepGuidanceEmbeddings(nn.Module): - def __init__(self, in_channels: int = 256, embedding_dim: int = 6144, bias: bool = False): + def __init__( + self, + in_channels: int = 256, + embedding_dim: int = 6144, + bias: bool = False, + guidance_embeds: bool = True, + ): super().__init__() self.time_proj = Timesteps(num_channels=in_channels, flip_sin_to_cos=True, downscale_freq_shift=0) @@ -593,20 +599,24 @@ def __init__(self, in_channels: int = 256, embedding_dim: int = 6144, bias: bool in_channels=in_channels, time_embed_dim=embedding_dim, sample_proj_bias=bias ) - self.guidance_embedder = TimestepEmbedding( - in_channels=in_channels, time_embed_dim=embedding_dim, sample_proj_bias=bias - ) + if guidance_embeds: + self.guidance_embedder = TimestepEmbedding( + in_channels=in_channels, time_embed_dim=embedding_dim, sample_proj_bias=bias + ) + else: + self.guidance_embedder = None def forward(self, timestep: torch.Tensor, guidance: torch.Tensor) -> torch.Tensor: timesteps_proj = self.time_proj(timestep) timesteps_emb = self.timestep_embedder(timesteps_proj.to(timestep.dtype)) # (N, D) - guidance_proj = self.time_proj(guidance) - guidance_emb = self.guidance_embedder(guidance_proj.to(guidance.dtype)) # (N, D) - - time_guidance_emb = timesteps_emb + guidance_emb - - return time_guidance_emb + if guidance is not None and self.guidance_embedder is not None: + guidance_proj = self.time_proj(guidance) + guidance_emb = self.guidance_embedder(guidance_proj.to(guidance.dtype)) # (N, D) + time_guidance_emb = timesteps_emb + guidance_emb + return time_guidance_emb + else: + return timesteps_emb class Flux2Modulation(nn.Module): @@ -698,6 +708,7 @@ def __init__( axes_dims_rope: Tuple[int, ...] = (32, 32, 32, 32), rope_theta: int = 2000, eps: float = 1e-6, + guidance_embeds: bool = True, ): super().__init__() self.out_channels = out_channels or in_channels @@ -708,7 +719,7 @@ def __init__( # 2. Combined timestep + guidance embedding self.time_guidance_embed = Flux2TimestepGuidanceEmbeddings( - in_channels=timestep_guidance_channels, embedding_dim=self.inner_dim, bias=False + in_channels=timestep_guidance_channels, embedding_dim=self.inner_dim, bias=False, guidance_embeds=guidance_embeds ) # 3. Modulation (double stream and single stream blocks share modulation parameters, resp.) @@ -815,7 +826,9 @@ def forward( # 1. Calculate timestep embedding and modulation parameters timestep = timestep.to(hidden_states.dtype) * 1000 - guidance = guidance.to(hidden_states.dtype) * 1000 + + if guidance is not None: + guidance = guidance.to(hidden_states.dtype) * 1000 temb = self.time_guidance_embed(timestep, guidance) diff --git a/src/diffusers/pipelines/__init__.py b/src/diffusers/pipelines/__init__.py index 8fa7a4f98ccb..6a0aa0c52d6d 100644 --- a/src/diffusers/pipelines/__init__.py +++ b/src/diffusers/pipelines/__init__.py @@ -130,7 +130,7 @@ ] _import_structure["bria"] = ["BriaPipeline"] _import_structure["bria_fibo"] = ["BriaFiboPipeline"] - _import_structure["flux2"] = ["Flux2Pipeline"] + _import_structure["flux2"] = ["Flux2Pipeline", "Flux2KleinPipeline"] _import_structure["flux"] = [ "FluxControlPipeline", "FluxControlInpaintPipeline", @@ -678,7 +678,7 @@ FluxPriorReduxPipeline, ReduxImageEncoder, ) - from .flux2 import Flux2Pipeline + from .flux2 import Flux2Pipeline, Flux2KleinPipeline from .glm_image import GlmImagePipeline from .hidream_image import HiDreamImagePipeline from .hunyuan_image import HunyuanImagePipeline, HunyuanImageRefinerPipeline diff --git a/src/diffusers/pipelines/flux2/__init__.py b/src/diffusers/pipelines/flux2/__init__.py index d986c9a63011..f6e1d5206630 100644 --- a/src/diffusers/pipelines/flux2/__init__.py +++ b/src/diffusers/pipelines/flux2/__init__.py @@ -23,6 +23,7 @@ _dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_objects)) else: _import_structure["pipeline_flux2"] = ["Flux2Pipeline"] + _import_structure["pipeline_flux2_klein"] = ["Flux2KleinPipeline"] if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT: try: if not (is_transformers_available() and is_torch_available()): @@ -31,6 +32,7 @@ from ...utils.dummy_torch_and_transformers_objects import * # noqa F403 else: from .pipeline_flux2 import Flux2Pipeline + from .pipeline_flux2_klein import Flux2KleinPipeline else: import sys diff --git a/src/diffusers/pipelines/flux2/pipeline_flux2.py b/src/diffusers/pipelines/flux2/pipeline_flux2.py index b54a43dd89a5..c01b7137e086 100644 --- a/src/diffusers/pipelines/flux2/pipeline_flux2.py +++ b/src/diffusers/pipelines/flux2/pipeline_flux2.py @@ -725,8 +725,8 @@ def guidance_scale(self): return self._guidance_scale @property - def joint_attention_kwargs(self): - return self._joint_attention_kwargs + def attention_kwargs(self): + return self._attention_kwargs @property def num_timesteps(self): @@ -975,7 +975,7 @@ def __call__( encoder_hidden_states=prompt_embeds, txt_ids=text_ids, # B, text_seq_len, 4 img_ids=latent_image_ids, # B, image_seq_len, 4 - joint_attention_kwargs=self._attention_kwargs, + joint_attention_kwargs=self.attention_kwargs, return_dict=False, )[0] diff --git a/src/diffusers/pipelines/flux2/pipeline_flux2_klein.py b/src/diffusers/pipelines/flux2/pipeline_flux2_klein.py new file mode 100644 index 000000000000..78bfc3d13c16 --- /dev/null +++ b/src/diffusers/pipelines/flux2/pipeline_flux2_klein.py @@ -0,0 +1,892 @@ +# Copyright 2025 Black Forest Labs and The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import inspect +from typing import Any, Callable, Dict, List, Optional, Tuple, Union + +import numpy as np +import PIL +import torch +from transformers import Qwen3ForCausalLM, Qwen2TokenizerFast + +from ...loaders import Flux2LoraLoaderMixin +from ...models import AutoencoderKLFlux2, Flux2Transformer2DModel +from ...schedulers import FlowMatchEulerDiscreteScheduler +from ...utils import is_torch_xla_available, logging, replace_example_docstring +from ...utils.torch_utils import randn_tensor +from ..pipeline_utils import DiffusionPipeline +from .image_processor import Flux2ImageProcessor +from .pipeline_output import Flux2PipelineOutput + + +if is_torch_xla_available(): + import torch_xla.core.xla_model as xm + + XLA_AVAILABLE = True +else: + XLA_AVAILABLE = False + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + +EXAMPLE_DOC_STRING = """ + Examples: + ```py + >>> import torch + >>> from diffusers import Flux2Pipeline + + >>> pipe = Flux2Pipeline.from_pretrained("black-forest-labs/FLUX.2-dev", torch_dtype=torch.bfloat16) + >>> pipe.to("cuda") + >>> prompt = "A cat holding a sign that says hello world" + >>> # Depending on the variant being used, the pipeline call will slightly vary. + >>> # Refer to the pipeline documentation for more details. + >>> image = pipe(prompt, num_inference_steps=50, guidance_scale=2.5).images[0] + >>> image.save("flux.png") + ``` +""" + +# Taken from +# https://github.com/black-forest-labs/flux2/blob/5a5d316b1b42f6b59a8c9194b77c8256be848432/src/flux2/sampling.py#L251 +def compute_empirical_mu(image_seq_len: int, num_steps: int) -> float: + a1, b1 = 8.73809524e-05, 1.89833333 + a2, b2 = 0.00016927, 0.45666666 + + if image_seq_len > 4300: + mu = a2 * image_seq_len + b2 + return float(mu) + + m_200 = a2 * image_seq_len + b2 + m_10 = a1 * image_seq_len + b1 + + a = (m_200 - m_10) / 190.0 + b = m_200 - 200.0 * a + mu = a * num_steps + b + + return float(mu) + + +# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps +def retrieve_timesteps( + scheduler, + num_inference_steps: Optional[int] = None, + device: Optional[Union[str, torch.device]] = None, + timesteps: Optional[List[int]] = None, + sigmas: Optional[List[float]] = None, + **kwargs, +): + r""" + Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles + custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`. + + Args: + scheduler (`SchedulerMixin`): + The scheduler to get timesteps from. + num_inference_steps (`int`): + The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps` + must be `None`. + device (`str` or `torch.device`, *optional*): + The device to which the timesteps should be moved to. If `None`, the timesteps are not moved. + timesteps (`List[int]`, *optional*): + Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed, + `num_inference_steps` and `sigmas` must be `None`. + sigmas (`List[float]`, *optional*): + Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed, + `num_inference_steps` and `timesteps` must be `None`. + + Returns: + `Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the + second element is the number of inference steps. + """ + if timesteps is not None and sigmas is not None: + raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values") + if timesteps is not None: + accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + if not accepts_timesteps: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" timestep schedules. Please check whether you are using the correct scheduler." + ) + scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + elif sigmas is not None: + accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + if not accept_sigmas: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" sigmas schedules. Please check whether you are using the correct scheduler." + ) + scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + else: + scheduler.set_timesteps(num_inference_steps, device=device, **kwargs) + timesteps = scheduler.timesteps + return timesteps, num_inference_steps + + +# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents +def retrieve_latents( + encoder_output: torch.Tensor, generator: Optional[torch.Generator] = None, sample_mode: str = "sample" +): + if hasattr(encoder_output, "latent_dist") and sample_mode == "sample": + return encoder_output.latent_dist.sample(generator) + elif hasattr(encoder_output, "latent_dist") and sample_mode == "argmax": + return encoder_output.latent_dist.mode() + elif hasattr(encoder_output, "latents"): + return encoder_output.latents + else: + raise AttributeError("Could not access latents of provided encoder_output") + + +class Flux2KleinPipeline(DiffusionPipeline, Flux2LoraLoaderMixin): + r""" + The Flux2 klein pipeline for text-to-image generation. + + Reference: [https://bfl.ai/blog/flux-2](https://bfl.ai/blog/flux-2) + + Args: + transformer ([`Flux2Transformer2DModel`]): + Conditional Transformer (MMDiT) architecture to denoise the encoded image latents. + scheduler ([`FlowMatchEulerDiscreteScheduler`]): + A scheduler to be used in combination with `transformer` to denoise the encoded image latents. + vae ([`AutoencoderKLFlux2`]): + Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations. + text_encoder ([`Qwen3ForCausalLM`]): + [Qwen3ForCausalLM](https://huggingface.co/docs/transformers/en/model_doc/qwen3#transformers.Qwen3ForCausalLM) + tokenizer (`Qwen2TokenizerFast`): + Tokenizer of class + [Qwen2TokenizerFast](https://huggingface.co/docs/transformers/en/model_doc/qwen2#transformers.Qwen2TokenizerFast). + """ + + model_cpu_offload_seq = "text_encoder->transformer->vae" + _callback_tensor_inputs = ["latents", "prompt_embeds"] + + def __init__( + self, + scheduler: FlowMatchEulerDiscreteScheduler, + vae: AutoencoderKLFlux2, + text_encoder: Qwen3ForCausalLM, + tokenizer: Qwen2TokenizerFast, + transformer: Flux2Transformer2DModel, + ): + super().__init__() + + self.register_modules( + vae=vae, + text_encoder=text_encoder, + tokenizer=tokenizer, + scheduler=scheduler, + transformer=transformer, + ) + self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8 + # Flux latents are turned into 2x2 patches and packed. This means the latent width and height has to be divisible + # by the patch size. So the vae scale factor is multiplied by the patch size to account for this + self.image_processor = Flux2ImageProcessor(vae_scale_factor=self.vae_scale_factor * 2) + self.tokenizer_max_length = 512 + self.default_sample_size = 128 + + + @staticmethod + def _get_qwen3_prompt_embeds( + text_encoder: Qwen3ForCausalLM, + tokenizer: Qwen2TokenizerFast, + prompt: Union[str, List[str]], + dtype: Optional[torch.dtype] = None, + device: Optional[torch.device] = None, + max_sequence_length: int = 512, + hidden_states_layers: List[int] = (9, 18, 27), + ): + dtype = text_encoder.dtype if dtype is None else dtype + device = text_encoder.device if device is None else device + + prompt = [prompt] if isinstance(prompt, str) else prompt + + all_input_ids = [] + all_attention_masks = [] + + for single_prompt in prompt: + + messages = [{"role": "user", "content": single_prompt}] + text = tokenizer.apply_chat_template( + messages, + tokenize=False, + add_generation_prompt=True, + enable_thinking=False, + ) + inputs = tokenizer( + text, + return_tensors="pt", + padding="max_length", + truncation=True, + max_length=max_sequence_length, + ) + + all_input_ids.append(inputs["input_ids"]) + all_attention_masks.append(inputs["attention_mask"]) + + input_ids = torch.cat(all_input_ids, dim=0).to(device) + attention_mask = torch.cat(all_attention_masks, dim=0).to(device) + + # Forward pass through the model + output = text_encoder( + input_ids=input_ids, + attention_mask=attention_mask, + output_hidden_states=True, + use_cache=False, + ) + + # Only use outputs from intermediate layers and stack them + out = torch.stack([output.hidden_states[k] for k in hidden_states_layers], dim=1) + out = out.to(dtype=dtype, device=device) + + batch_size, num_channels, seq_len, hidden_dim = out.shape + prompt_embeds = out.permute(0, 2, 1, 3).reshape(batch_size, seq_len, num_channels * hidden_dim) + + return prompt_embeds + + @staticmethod + def _prepare_text_ids( + x: torch.Tensor, # (B, L, D) or (L, D) + t_coord: Optional[torch.Tensor] = None, + ): + B, L, _ = x.shape + out_ids = [] + + for i in range(B): + t = torch.arange(1) if t_coord is None else t_coord[i] + h = torch.arange(1) + w = torch.arange(1) + l = torch.arange(L) + + coords = torch.cartesian_prod(t, h, w, l) + out_ids.append(coords) + + return torch.stack(out_ids) + + @staticmethod + def _prepare_latent_ids( + latents: torch.Tensor, # (B, C, H, W) + ): + r""" + Generates 4D position coordinates (T, H, W, L) for latent tensors. + + Args: + latents (torch.Tensor): + Latent tensor of shape (B, C, H, W) + + Returns: + torch.Tensor: + Position IDs tensor of shape (B, H*W, 4) All batches share the same coordinate structure: T=0, + H=[0..H-1], W=[0..W-1], L=0 + """ + + batch_size, _, height, width = latents.shape + + t = torch.arange(1) # [0] - time dimension + h = torch.arange(height) + w = torch.arange(width) + l = torch.arange(1) # [0] - layer dimension + + # Create position IDs: (H*W, 4) + latent_ids = torch.cartesian_prod(t, h, w, l) + + # Expand to batch: (B, H*W, 4) + latent_ids = latent_ids.unsqueeze(0).expand(batch_size, -1, -1) + + return latent_ids + + @staticmethod + def _prepare_image_ids( + image_latents: List[torch.Tensor], # [(1, C, H, W), (1, C, H, W), ...] + scale: int = 10, + ): + r""" + Generates 4D time-space coordinates (T, H, W, L) for a sequence of image latents. + + This function creates a unique coordinate for every pixel/patch across all input latent with different + dimensions. + + Args: + image_latents (List[torch.Tensor]): + A list of image latent feature tensors, typically of shape (C, H, W). + scale (int, optional): + A factor used to define the time separation (T-coordinate) between latents. T-coordinate for the i-th + latent is: 'scale + scale * i'. Defaults to 10. + + Returns: + torch.Tensor: + The combined coordinate tensor. Shape: (1, N_total, 4) Where N_total is the sum of (H * W) for all + input latents. + + Coordinate Components (Dimension 4): + - T (Time): The unique index indicating which latent image the coordinate belongs to. + - H (Height): The row index within that latent image. + - W (Width): The column index within that latent image. + - L (Seq. Length): A sequence length dimension, which is always fixed at 0 (size 1) + """ + + if not isinstance(image_latents, list): + raise ValueError(f"Expected `image_latents` to be a list, got {type(image_latents)}.") + + # create time offset for each reference image + t_coords = [scale + scale * t for t in torch.arange(0, len(image_latents))] + t_coords = [t.view(-1) for t in t_coords] + + image_latent_ids = [] + for x, t in zip(image_latents, t_coords): + x = x.squeeze(0) + _, height, width = x.shape + + x_ids = torch.cartesian_prod(t, torch.arange(height), torch.arange(width), torch.arange(1)) + image_latent_ids.append(x_ids) + + image_latent_ids = torch.cat(image_latent_ids, dim=0) + image_latent_ids = image_latent_ids.unsqueeze(0) + + return image_latent_ids + + @staticmethod + def _patchify_latents(latents): + batch_size, num_channels_latents, height, width = latents.shape + latents = latents.view(batch_size, num_channels_latents, height // 2, 2, width // 2, 2) + latents = latents.permute(0, 1, 3, 5, 2, 4) + latents = latents.reshape(batch_size, num_channels_latents * 4, height // 2, width // 2) + return latents + + @staticmethod + def _unpatchify_latents(latents): + batch_size, num_channels_latents, height, width = latents.shape + latents = latents.reshape(batch_size, num_channels_latents // (2 * 2), 2, 2, height, width) + latents = latents.permute(0, 1, 4, 2, 5, 3) + latents = latents.reshape(batch_size, num_channels_latents // (2 * 2), height * 2, width * 2) + return latents + + @staticmethod + def _pack_latents(latents): + """ + pack latents: (batch_size, num_channels, height, width) -> (batch_size, height * width, num_channels) + """ + + batch_size, num_channels, height, width = latents.shape + latents = latents.reshape(batch_size, num_channels, height * width).permute(0, 2, 1) + + return latents + + @staticmethod + def _unpack_latents_with_ids(x: torch.Tensor, x_ids: torch.Tensor) -> list[torch.Tensor]: + """ + using position ids to scatter tokens into place + """ + x_list = [] + for data, pos in zip(x, x_ids): + _, ch = data.shape # noqa: F841 + h_ids = pos[:, 1].to(torch.int64) + w_ids = pos[:, 2].to(torch.int64) + + h = torch.max(h_ids) + 1 + w = torch.max(w_ids) + 1 + + flat_ids = h_ids * w + w_ids + + out = torch.zeros((h * w, ch), device=data.device, dtype=data.dtype) + out.scatter_(0, flat_ids.unsqueeze(1).expand(-1, ch), data) + + # reshape from (H * W, C) to (H, W, C) and permute to (C, H, W) + + out = out.view(h, w, ch).permute(2, 0, 1) + x_list.append(out) + + return torch.stack(x_list, dim=0) + + + def encode_prompt( + self, + prompt: Union[str, List[str]], + device: Optional[torch.device] = None, + num_images_per_prompt: int = 1, + prompt_embeds: Optional[torch.Tensor] = None, + max_sequence_length: int = 512, + ): + device = device or self._execution_device + + if prompt is None: + prompt = "" + + prompt = [prompt] if isinstance(prompt, str) else prompt + + if prompt_embeds is None: + prompt_embeds = self._get_qwen3_prompt_embeds( + text_encoder=self.text_encoder, + tokenizer=self.tokenizer, + prompt=prompt, + device=device, + max_sequence_length=max_sequence_length, + ) + + batch_size, seq_len, _ = prompt_embeds.shape + prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) + prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) + + text_ids = self._prepare_text_ids(prompt_embeds) + text_ids = text_ids.to(device) + return prompt_embeds, text_ids + + def _encode_vae_image(self, image: torch.Tensor, generator: torch.Generator): + if image.ndim != 4: + raise ValueError(f"Expected image dims 4, got {image.ndim}.") + + image_latents = retrieve_latents(self.vae.encode(image), generator=generator, sample_mode="argmax") + image_latents = self._patchify_latents(image_latents) + + latents_bn_mean = self.vae.bn.running_mean.view(1, -1, 1, 1).to(image_latents.device, image_latents.dtype) + latents_bn_std = torch.sqrt(self.vae.bn.running_var.view(1, -1, 1, 1) + self.vae.config.batch_norm_eps) + image_latents = (image_latents - latents_bn_mean) / latents_bn_std + + return image_latents + + def prepare_latents( + self, + batch_size, + num_latents_channels, + height, + width, + dtype, + device, + generator: torch.Generator, + latents: Optional[torch.Tensor] = None, + ): + # VAE applies 8x compression on images but we must also account for packing which requires + # latent height and width to be divisible by 2. + height = 2 * (int(height) // (self.vae_scale_factor * 2)) + width = 2 * (int(width) // (self.vae_scale_factor * 2)) + + shape = (batch_size, num_latents_channels * 4, height // 2, width // 2) + if isinstance(generator, list) and len(generator) != batch_size: + raise ValueError( + f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" + f" size of {batch_size}. Make sure the batch size matches the length of the generators." + ) + if latents is None: + latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + else: + latents = latents.to(device=device, dtype=dtype) + + latent_ids = self._prepare_latent_ids(latents) + latent_ids = latent_ids.to(device) + + latents = self._pack_latents(latents) # [B, C, H, W] -> [B, H*W, C] + return latents, latent_ids + + def prepare_image_latents( + self, + images: List[torch.Tensor], + batch_size, + generator: torch.Generator, + device, + dtype, + ): + image_latents = [] + for image in images: + image = image.to(device=device, dtype=dtype) + imagge_latent = self._encode_vae_image(image=image, generator=generator) + image_latents.append(imagge_latent) # (1, 128, 32, 32) + + image_latent_ids = self._prepare_image_ids(image_latents) + + # Pack each latent and concatenate + packed_latents = [] + for latent in image_latents: + # latent: (1, 128, 32, 32) + packed = self._pack_latents(latent) # (1, 1024, 128) + packed = packed.squeeze(0) # (1024, 128) - remove batch dim + packed_latents.append(packed) + + # Concatenate all reference tokens along sequence dimension + image_latents = torch.cat(packed_latents, dim=0) # (N*1024, 128) + image_latents = image_latents.unsqueeze(0) # (1, N*1024, 128) + + image_latents = image_latents.repeat(batch_size, 1, 1) + image_latent_ids = image_latent_ids.repeat(batch_size, 1, 1) + image_latent_ids = image_latent_ids.to(device) + + return image_latents, image_latent_ids + + def check_inputs( + self, + prompt, + height, + width, + prompt_embeds=None, + callback_on_step_end_tensor_inputs=None, + ): + if ( + height is not None + and height % (self.vae_scale_factor * 2) != 0 + or width is not None + and width % (self.vae_scale_factor * 2) != 0 + ): + logger.warning( + f"`height` and `width` have to be divisible by {self.vae_scale_factor * 2} but are {height} and {width}. Dimensions will be resized accordingly" + ) + + if callback_on_step_end_tensor_inputs is not None and not all( + k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs + ): + raise ValueError( + f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}" + ) + + if prompt is not None and prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to" + " only forward one of the two." + ) + elif prompt is None and prompt_embeds is None: + raise ValueError( + "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined." + ) + elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): + raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") + + @property + def guidance_scale(self): + return self._guidance_scale + + @property + def do_classifier_free_guidance(self): + return self._guidance_scale > 1 and not self.transformer.config.guidance_embeds + + @property + def attention_kwargs(self): + return self._attention_kwargs + + @property + def num_timesteps(self): + return self._num_timesteps + + @property + def current_timestep(self): + return self._current_timestep + + @property + def interrupt(self): + return self._interrupt + + @torch.no_grad() + @replace_example_docstring(EXAMPLE_DOC_STRING) + def __call__( + self, + image: Optional[Union[List[PIL.Image.Image], PIL.Image.Image]] = None, + prompt: Union[str, List[str]] = None, + height: Optional[int] = None, + width: Optional[int] = None, + num_inference_steps: int = 50, + sigmas: Optional[List[float]] = None, + guidance_scale: Optional[float] = 4.0, + num_images_per_prompt: int = 1, + generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, + latents: Optional[torch.Tensor] = None, + prompt_embeds: Optional[torch.Tensor] = None, + output_type: Optional[str] = "pil", + return_dict: bool = True, + attention_kwargs: Optional[Dict[str, Any]] = None, + callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None, + callback_on_step_end_tensor_inputs: List[str] = ["latents"], + max_sequence_length: int = 512, + ): + r""" + Function invoked when calling the pipeline for generation. + + Args: + image (`torch.Tensor`, `PIL.Image.Image`, `np.ndarray`, `List[torch.Tensor]`, `List[PIL.Image.Image]`, or `List[np.ndarray]`): + `Image`, numpy array or tensor representing an image batch to be used as the starting point. For both + numpy array and pytorch tensor, the expected value range is between `[0, 1]` If it's a tensor or a list + or tensors, the expected shape should be `(B, C, H, W)` or `(C, H, W)`. If it is a numpy array or a + list of arrays, the expected shape should be `(B, H, W, C)` or `(H, W, C)` It can also accept image + latents as `image`, but if passing latents directly it is not encoded again. + prompt (`str` or `List[str]`, *optional*): + The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`. + instead. + guidance_scale (`float`, *optional*, defaults to 1.0): + Embedded guiddance scale is enabled by setting `guidance_scale` > 1. Higher `guidance_scale` encourages + a model to generate images more aligned with `prompt` at the expense of lower image quality. + + Guidance-distilled models approximates true classifer-free guidance for `guidance_scale` > 1. Refer to + the [paper](https://huggingface.co/papers/2210.03142) to learn more. + height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor): + The height in pixels of the generated image. This is set to 1024 by default for the best results. + width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor): + The width in pixels of the generated image. This is set to 1024 by default for the best results. + num_inference_steps (`int`, *optional*, defaults to 50): + The number of denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. + sigmas (`List[float]`, *optional*): + Custom sigmas to use for the denoising process with schedulers which support a `sigmas` argument in + their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is passed + will be used. + num_images_per_prompt (`int`, *optional*, defaults to 1): + The number of images to generate per prompt. + generator (`torch.Generator` or `List[torch.Generator]`, *optional*): + One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html) + to make generation deterministic. + latents (`torch.Tensor`, *optional*): + Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image + generation. Can be used to tweak the same generation with different prompts. If not provided, a latents + tensor will be generated by sampling using the supplied random `generator`. + prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + output_type (`str`, *optional*, defaults to `"pil"`): + The output format of the generate image. Choose between + [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~pipelines.qwenimage.QwenImagePipelineOutput`] instead of a plain tuple. + attention_kwargs (`dict`, *optional*): + A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under + `self.processor` in + [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py). + callback_on_step_end (`Callable`, *optional*): + A function that calls at the end of each denoising steps during the inference. The function is called + with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int, + callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by + `callback_on_step_end_tensor_inputs`. + callback_on_step_end_tensor_inputs (`List`, *optional*): + The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list + will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the + `._callback_tensor_inputs` attribute of your pipeline class. + max_sequence_length (`int` defaults to 512): Maximum sequence length to use with the `prompt`. + + Examples: + + Returns: + [`~pipelines.flux2.Flux2PipelineOutput`] or `tuple`: [`~pipelines.flux2.Flux2PipelineOutput`] if + `return_dict` is True, otherwise a `tuple`. When returning a tuple, the first element is a list with the + generated images. + """ + + # 1. Check inputs. Raise error if not correct + self.check_inputs( + prompt=prompt, + height=height, + width=width, + prompt_embeds=prompt_embeds, + callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs, + ) + + self._guidance_scale = guidance_scale + self._attention_kwargs = attention_kwargs + self._current_timestep = None + self._interrupt = False + + # 2. Define call parameters + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + device = self._execution_device + + # 3. prepare text embeddings + prompt_embeds, text_ids = self.encode_prompt( + prompt=prompt, + prompt_embeds=prompt_embeds, + device=device, + num_images_per_prompt=num_images_per_prompt, + max_sequence_length=max_sequence_length, + ) + + if self.do_classifier_free_guidance: + negative_prompt_embeds, negative_text_ids = self.encode_prompt( + prompt="", + prompt_embeds=None, + device=device, + num_images_per_prompt=num_images_per_prompt, + max_sequence_length=max_sequence_length, + ) + + # 4. process images + if image is not None and not isinstance(image, list): + image = [image] + + condition_images = None + if image is not None: + for img in image: + self.image_processor.check_image_input(img) + + condition_images = [] + for img in image: + image_width, image_height = img.size + if image_width * image_height > 1024 * 1024: + img = self.image_processor._resize_to_target_area(img, 1024 * 1024) + image_width, image_height = img.size + + multiple_of = self.vae_scale_factor * 2 + image_width = (image_width // multiple_of) * multiple_of + image_height = (image_height // multiple_of) * multiple_of + img = self.image_processor.preprocess(img, height=image_height, width=image_width, resize_mode="crop") + condition_images.append(img) + height = height or image_height + width = width or image_width + + height = height or self.default_sample_size * self.vae_scale_factor + width = width or self.default_sample_size * self.vae_scale_factor + + # 5. prepare latent variables + num_channels_latents = self.transformer.config.in_channels // 4 + latents, latent_ids = self.prepare_latents( + batch_size=batch_size * num_images_per_prompt, + num_latents_channels=num_channels_latents, + height=height, + width=width, + dtype=prompt_embeds.dtype, + device=device, + generator=generator, + latents=latents, + ) + + image_latents = None + image_latent_ids = None + if condition_images is not None: + image_latents, image_latent_ids = self.prepare_image_latents( + images=condition_images, + batch_size=batch_size * num_images_per_prompt, + generator=generator, + device=device, + dtype=self.vae.dtype, + ) + + # 6. Prepare timesteps + sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps) if sigmas is None else sigmas + if hasattr(self.scheduler.config, "use_flow_sigmas") and self.scheduler.config.use_flow_sigmas: + sigmas = None + image_seq_len = latents.shape[1] + mu = compute_empirical_mu(image_seq_len=image_seq_len, num_steps=num_inference_steps) + timesteps, num_inference_steps = retrieve_timesteps( + self.scheduler, + num_inference_steps, + device, + sigmas=sigmas, + mu=mu, + ) + num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0) + self._num_timesteps = len(timesteps) + + if self.transformer.config.guidance_embeds: + guidance = torch.full([1], guidance_scale, device=device, dtype=torch.float32) + guidance = guidance.expand(latents.shape[0]) + else: + guidance = None + + + # 7. Denoising loop + # We set the index here to remove DtoH sync, helpful especially during compilation. + # Check out more details here: https://github.com/huggingface/diffusers/pull/11696 + self.scheduler.set_begin_index(0) + with self.progress_bar(total=num_inference_steps) as progress_bar: + for i, t in enumerate(timesteps): + if self.interrupt: + continue + + self._current_timestep = t + # broadcast to batch dimension in a way that's compatible with ONNX/Core ML + timestep = t.expand(latents.shape[0]).to(latents.dtype) + + latent_model_input = latents.to(self.transformer.dtype) + latent_image_ids = latent_ids + + if image_latents is not None: + latent_model_input = torch.cat([latents, image_latents], dim=1).to(self.transformer.dtype) + latent_image_ids = torch.cat([latent_ids, image_latent_ids], dim=1) + + with self.transformer.cache_context("cond"): + noise_pred = self.transformer( + hidden_states=latent_model_input, # (B, image_seq_len, C) + timestep=timestep / 1000, + guidance=guidance, + encoder_hidden_states=prompt_embeds, + txt_ids=text_ids, # B, text_seq_len, 4 + img_ids=latent_image_ids, # B, image_seq_len, 4 + joint_attention_kwargs=self.attention_kwargs, + return_dict=False, + )[0] + + noise_pred = noise_pred[:, : latents.size(1) :] + + if self.do_classifier_free_guidance: + with self.transformer.cache_context("uncond"): + neg_noise_pred = self.transformer( + hidden_states=latent_model_input, + timestep=timestep / 1000, + guidance=guidance, + encoder_hidden_states=negative_prompt_embeds, + txt_ids=negative_text_ids, + img_ids=latent_image_ids, + joint_attention_kwargs=self._attention_kwargs, + return_dict=False, + )[0] + neg_noise_pred = neg_noise_pred[:, : latents.size(1) :] + noise_pred = neg_noise_pred + guidance_scale * (noise_pred - neg_noise_pred) + + # compute the previous noisy sample x_t -> x_t-1 + latents_dtype = latents.dtype + latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0] + + if latents.dtype != latents_dtype: + if torch.backends.mps.is_available(): + # some platforms (eg. apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272 + latents = latents.to(latents_dtype) + + if callback_on_step_end is not None: + callback_kwargs = {} + for k in callback_on_step_end_tensor_inputs: + callback_kwargs[k] = locals()[k] + callback_outputs = callback_on_step_end(self, i, t, callback_kwargs) + + latents = callback_outputs.pop("latents", latents) + prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds) + + # call the callback, if provided + if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): + progress_bar.update() + + if XLA_AVAILABLE: + xm.mark_step() + + self._current_timestep = None + + if output_type == "latent": + image = latents + else: + latents = self._unpack_latents_with_ids(latents, latent_ids) + + latents_bn_mean = self.vae.bn.running_mean.view(1, -1, 1, 1).to(latents.device, latents.dtype) + latents_bn_std = torch.sqrt(self.vae.bn.running_var.view(1, -1, 1, 1) + self.vae.config.batch_norm_eps).to( + latents.device, latents.dtype + ) + latents = latents * latents_bn_std + latents_bn_mean + latents = self._unpatchify_latents(latents) + + image = self.vae.decode(latents, return_dict=False)[0] + image = self.image_processor.postprocess(image, output_type=output_type) + + # Offload all models + self.maybe_free_model_hooks() + + if not return_dict: + return (image,) + + return Flux2PipelineOutput(images=image) From ba3aaef4ec2dbc1cecab714594f667d03e057756 Mon Sep 17 00:00:00 2001 From: YiYi Xu Date: Mon, 12 Jan 2026 07:08:19 -1000 Subject: [PATCH 02/15] Apply suggestions from code review Co-authored-by: Sayak Paul --- src/diffusers/pipelines/flux2/pipeline_flux2_klein.py | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/src/diffusers/pipelines/flux2/pipeline_flux2_klein.py b/src/diffusers/pipelines/flux2/pipeline_flux2_klein.py index 78bfc3d13c16..5603fef0c502 100644 --- a/src/diffusers/pipelines/flux2/pipeline_flux2_klein.py +++ b/src/diffusers/pipelines/flux2/pipeline_flux2_klein.py @@ -44,9 +44,9 @@ Examples: ```py >>> import torch - >>> from diffusers import Flux2Pipeline + >>> from diffusers import Flux2KleinPipeline - >>> pipe = Flux2Pipeline.from_pretrained("black-forest-labs/FLUX.2-dev", torch_dtype=torch.bfloat16) + >>> pipe = Flux2KleinPipeline.from_pretrained("TODO", torch_dtype=torch.bfloat16) >>> pipe.to("cuda") >>> prompt = "A cat holding a sign that says hello world" >>> # Depending on the variant being used, the pipeline call will slightly vary. @@ -56,8 +56,7 @@ ``` """ -# Taken from -# https://github.com/black-forest-labs/flux2/blob/5a5d316b1b42f6b59a8c9194b77c8256be848432/src/flux2/sampling.py#L251 +# Copied from diffusers.pipelines.flux2.pipeline_flux2.compute_empiricial_mu def compute_empirical_mu(image_seq_len: int, num_steps: int) -> float: a1, b1 = 8.73809524e-05, 1.89833333 a2, b2 = 0.00016927, 0.45666666 From 24fb03a1504fcae489d3ff90610c85ecd06a1c7e Mon Sep 17 00:00:00 2001 From: Sayak Paul Date: Tue, 13 Jan 2026 08:28:11 +0530 Subject: [PATCH 03/15] Klein tests (#2) * tests * up * tests * up --- .../pipelines/flux2/pipeline_flux2_klein.py | 12 +- .../flux2/test_pipeline_flux2_klein.py | 186 ++++++++++++++++++ 2 files changed, 197 insertions(+), 1 deletion(-) create mode 100644 tests/pipelines/flux2/test_pipeline_flux2_klein.py diff --git a/src/diffusers/pipelines/flux2/pipeline_flux2_klein.py b/src/diffusers/pipelines/flux2/pipeline_flux2_klein.py index 5603fef0c502..cd8a1bc9ec2e 100644 --- a/src/diffusers/pipelines/flux2/pipeline_flux2_klein.py +++ b/src/diffusers/pipelines/flux2/pipeline_flux2_klein.py @@ -417,6 +417,7 @@ def encode_prompt( num_images_per_prompt: int = 1, prompt_embeds: Optional[torch.Tensor] = None, max_sequence_length: int = 512, + text_encoder_out_layers: Tuple[int] = (9, 18, 27), ): device = device or self._execution_device @@ -432,6 +433,7 @@ def encode_prompt( prompt=prompt, device=device, max_sequence_length=max_sequence_length, + hidden_states_layers=text_encoder_out_layers ) batch_size, seq_len, _ = prompt_embeds.shape @@ -604,6 +606,7 @@ def __call__( callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None, callback_on_step_end_tensor_inputs: List[str] = ["latents"], max_sequence_length: int = 512, + text_encoder_out_layers: Tuple[int] = (9, 18, 27) ): r""" Function invoked when calling the pipeline for generation. @@ -666,6 +669,8 @@ def __call__( will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the `._callback_tensor_inputs` attribute of your pipeline class. max_sequence_length (`int` defaults to 512): Maximum sequence length to use with the `prompt`. + text_encoder_out_layers (`Tuple[int]`): + Layer indices to use in the `text_encoder` to derive the final prompt embeddings. Examples: @@ -706,15 +711,20 @@ def __call__( device=device, num_images_per_prompt=num_images_per_prompt, max_sequence_length=max_sequence_length, + text_encoder_out_layers=text_encoder_out_layers ) if self.do_classifier_free_guidance: + negative_prompt = "" + if prompt is not None and isinstance(prompt, list): + negative_prompt = [negative_prompt] * len(prompt) negative_prompt_embeds, negative_text_ids = self.encode_prompt( - prompt="", + prompt=negative_prompt, prompt_embeds=None, device=device, num_images_per_prompt=num_images_per_prompt, max_sequence_length=max_sequence_length, + text_encoder_out_layers=text_encoder_out_layers, ) # 4. process images diff --git a/tests/pipelines/flux2/test_pipeline_flux2_klein.py b/tests/pipelines/flux2/test_pipeline_flux2_klein.py new file mode 100644 index 000000000000..d634d251ac20 --- /dev/null +++ b/tests/pipelines/flux2/test_pipeline_flux2_klein.py @@ -0,0 +1,186 @@ +import unittest + +import numpy as np +import torch +from PIL import Image +from transformers import Qwen2TokenizerFast, Qwen3Config, Qwen3ForCausalLM + +from diffusers import ( + AutoencoderKLFlux2, + FlowMatchEulerDiscreteScheduler, + Flux2KleinPipeline, + Flux2Transformer2DModel, +) + +from ...testing_utils import ( + torch_device +) +from ..test_pipelines_common import ( + PipelineTesterMixin, + check_qkv_fused_layers_exist +) + + +class Flux2KleinPipelineFastTests(PipelineTesterMixin, unittest.TestCase): + pipeline_class = Flux2KleinPipeline + params = frozenset(["prompt", "height", "width", "guidance_scale", "prompt_embeds"]) + batch_params = frozenset(["prompt"]) + + test_xformers_attention = False + test_layerwise_casting = True + test_group_offloading = True + + supports_dduf = False + + def get_dummy_components(self, num_layers: int = 1, num_single_layers: int = 1): + torch.manual_seed(0) + transformer = Flux2Transformer2DModel( + patch_size=1, + in_channels=4, + num_layers=num_layers, + num_single_layers=num_single_layers, + attention_head_dim=16, + num_attention_heads=2, + joint_attention_dim=16, + timestep_guidance_channels=256, + axes_dims_rope=[4, 4, 4, 4], + guidance_embeds=False + ) + + # Create minimal Qwen3 config + config = Qwen3Config( + intermediate_size=16, + hidden_size=16, + num_hidden_layers=2, + num_attention_heads=2, + num_key_value_heads=2, + vocab_size=151936, + max_position_embeddings=512, + ) + torch.manual_seed(0) + text_encoder = Qwen3ForCausalLM(config) + + # Use a simple tokenizer for testing + tokenizer = Qwen2TokenizerFast.from_pretrained("hf-internal-testing/tiny-random-Qwen2VLForConditionalGeneration") + + torch.manual_seed(0) + vae = AutoencoderKLFlux2( + sample_size=32, + in_channels=3, + out_channels=3, + down_block_types=("DownEncoderBlock2D",), + up_block_types=("UpDecoderBlock2D",), + block_out_channels=(4,), + layers_per_block=1, + latent_channels=1, + norm_num_groups=1, + use_quant_conv=False, + use_post_quant_conv=False, + ) + + scheduler = FlowMatchEulerDiscreteScheduler() + + return { + "scheduler": scheduler, + "text_encoder": text_encoder, + "tokenizer": tokenizer, + "transformer": transformer, + "vae": vae, + } + + def get_dummy_inputs(self, device, seed=0): + if str(device).startswith("mps"): + generator = torch.manual_seed(seed) + else: + generator = torch.Generator(device="cpu").manual_seed(seed) + + inputs = { + "prompt": "a dog is dancing", + "generator": generator, + "num_inference_steps": 2, + "guidance_scale": 4.0, + "height": 8, + "width": 8, + "max_sequence_length": 64, + "output_type": "np", + "text_encoder_out_layers": (1,) + } + return inputs + + def test_fused_qkv_projections(self): + device = "cpu" # ensure determinism for the device-dependent torch.Generator + components = self.get_dummy_components() + pipe = self.pipeline_class(**components) + pipe = pipe.to(device) + pipe.set_progress_bar_config(disable=None) + + inputs = self.get_dummy_inputs(device) + image = pipe(**inputs).images + original_image_slice = image[0, -3:, -3:, -1] + + pipe.transformer.fuse_qkv_projections() + self.assertTrue( + check_qkv_fused_layers_exist(pipe.transformer, ["to_qkv"]), + ("Something wrong with the fused attention layers. Expected all the attention projections to be fused."), + ) + + inputs = self.get_dummy_inputs(device) + image = pipe(**inputs).images + image_slice_fused = image[0, -3:, -3:, -1] + + pipe.transformer.unfuse_qkv_projections() + inputs = self.get_dummy_inputs(device) + image = pipe(**inputs).images + image_slice_disabled = image[0, -3:, -3:, -1] + + self.assertTrue( + np.allclose(original_image_slice, image_slice_fused, atol=1e-3, rtol=1e-3), + ("Fusion of QKV projections shouldn't affect the outputs."), + ) + self.assertTrue( + np.allclose(image_slice_fused, image_slice_disabled, atol=1e-3, rtol=1e-3), + ("Outputs, with QKV projection fusion enabled, shouldn't change when fused QKV projections are disabled."), + ) + self.assertTrue( + np.allclose(original_image_slice, image_slice_disabled, atol=1e-2, rtol=1e-2), + ("Original outputs should match when fused QKV projections are disabled."), + ) + + def test_image_output_shape(self): + pipe = self.pipeline_class(**self.get_dummy_components()).to(torch_device) + inputs = self.get_dummy_inputs(torch_device) + + height_width_pairs = [(32, 32), (72, 57)] + for height, width in height_width_pairs: + expected_height = height - height % (pipe.vae_scale_factor * 2) + expected_width = width - width % (pipe.vae_scale_factor * 2) + + inputs.update({"height": height, "width": width}) + image = pipe(**inputs).images[0] + output_height, output_width, _ = image.shape + self.assertEqual( + (output_height, output_width), + (expected_height, expected_width), + f"Output shape {image.shape} does not match expected shape {(expected_height, expected_width)}", + ) + + def test_image_input(self): + device = "cpu" + pipe = self.pipeline_class(**self.get_dummy_components()).to(device) + inputs = self.get_dummy_inputs(device) + + inputs["image"] = Image.new("RGB", (64, 64)) + image = pipe(**inputs).images.flatten() + generated_slice = np.concatenate([image[:8], image[-8:]]) + # fmt: off + expected_slice = np.array( + [ + 0.8255048 , 0.66054785, 0.6643694 , 0.67462724, 0.5494932 , 0.3480271 , 0.52535003, 0.44510138, 0.23549396, 0.21372932, 0.21166152, 0.63198495, 0.49942136, 0.39147034, 0.49156153, 0.3713916 + ] + ) + # fmt: on + assert np.allclose(expected_slice, generated_slice, atol=1e-4, rtol=1e-4) + + @unittest.skip("Needs to be revisited") + def test_encode_prompt_works_in_isolation(self): + pass \ No newline at end of file From ad58f0af2ffefe05db212c39a7fc53ce3be416e8 Mon Sep 17 00:00:00 2001 From: yiyixuxu Date: Wed, 14 Jan 2026 02:19:09 +0100 Subject: [PATCH 04/15] support step-distilled --- .../pipelines/flux2/pipeline_flux2_klein.py | 21 +++++++++++-------- 1 file changed, 12 insertions(+), 9 deletions(-) diff --git a/src/diffusers/pipelines/flux2/pipeline_flux2_klein.py b/src/diffusers/pipelines/flux2/pipeline_flux2_klein.py index cd8a1bc9ec2e..d2229598a49f 100644 --- a/src/diffusers/pipelines/flux2/pipeline_flux2_klein.py +++ b/src/diffusers/pipelines/flux2/pipeline_flux2_klein.py @@ -179,6 +179,7 @@ def __init__( text_encoder: Qwen3ForCausalLM, tokenizer: Qwen2TokenizerFast, transformer: Flux2Transformer2DModel, + is_distilled: bool = False, ): super().__init__() @@ -189,6 +190,9 @@ def __init__( scheduler=scheduler, transformer=transformer, ) + + self.register_to_config(is_distilled=is_distilled) + self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8 # Flux latents are turned into 2x2 patches and packed. This means the latent width and height has to be divisible # by the patch size. So the vae scale factor is multiplied by the patch size to account for this @@ -531,6 +535,7 @@ def check_inputs( width, prompt_embeds=None, callback_on_step_end_tensor_inputs=None, + guidance_scale=None, ): if ( height is not None @@ -561,13 +566,16 @@ def check_inputs( elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") + if guidance_scale > 1.0 and self.config.is_distilled: + logger.warning(f"Guidance scale {guidance_scale} is ignored for step-wise distilled models.") + @property def guidance_scale(self): return self._guidance_scale @property def do_classifier_free_guidance(self): - return self._guidance_scale > 1 and not self.transformer.config.guidance_embeds + return self._guidance_scale > 1 and not self.config.is_distilled @property def attention_kwargs(self): @@ -687,6 +695,7 @@ def __call__( width=width, prompt_embeds=prompt_embeds, callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs, + guidance_scale=guidance_scale, ) self._guidance_scale = guidance_scale @@ -794,12 +803,6 @@ def __call__( num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0) self._num_timesteps = len(timesteps) - if self.transformer.config.guidance_embeds: - guidance = torch.full([1], guidance_scale, device=device, dtype=torch.float32) - guidance = guidance.expand(latents.shape[0]) - else: - guidance = None - # 7. Denoising loop # We set the index here to remove DtoH sync, helpful especially during compilation. @@ -825,7 +828,7 @@ def __call__( noise_pred = self.transformer( hidden_states=latent_model_input, # (B, image_seq_len, C) timestep=timestep / 1000, - guidance=guidance, + guidance=None, encoder_hidden_states=prompt_embeds, txt_ids=text_ids, # B, text_seq_len, 4 img_ids=latent_image_ids, # B, image_seq_len, 4 @@ -840,7 +843,7 @@ def __call__( neg_noise_pred = self.transformer( hidden_states=latent_model_input, timestep=timestep / 1000, - guidance=guidance, + guidance=None, encoder_hidden_states=negative_prompt_embeds, txt_ids=negative_text_ids, img_ids=latent_image_ids, From bc1a19ef3ac4a3fed8ab44e3f32013765158f968 Mon Sep 17 00:00:00 2001 From: YiYi Xu Date: Tue, 13 Jan 2026 15:23:29 -1000 Subject: [PATCH 05/15] Apply suggestions from code review Co-authored-by: dg845 <58458699+dg845@users.noreply.github.com> --- .../pipelines/flux2/pipeline_flux2_klein.py | 26 ++++++++++++------- 1 file changed, 17 insertions(+), 9 deletions(-) diff --git a/src/diffusers/pipelines/flux2/pipeline_flux2_klein.py b/src/diffusers/pipelines/flux2/pipeline_flux2_klein.py index d2229598a49f..3e11efd1cdf0 100644 --- a/src/diffusers/pipelines/flux2/pipeline_flux2_klein.py +++ b/src/diffusers/pipelines/flux2/pipeline_flux2_klein.py @@ -260,6 +260,7 @@ def _get_qwen3_prompt_embeds( return prompt_embeds @staticmethod + # Copied from diffusers.pipelines.flux2.pipeline_flux2.Flux2Pipeline._prepare_text_ids def _prepare_text_ids( x: torch.Tensor, # (B, L, D) or (L, D) t_coord: Optional[torch.Tensor] = None, @@ -279,6 +280,7 @@ def _prepare_text_ids( return torch.stack(out_ids) @staticmethod + # Copied from diffusers.pipelines.flux2.pipeline_flux2.Flux2Pipeline._prepare_latent_ids def _prepare_latent_ids( latents: torch.Tensor, # (B, C, H, W) ): @@ -311,6 +313,7 @@ def _prepare_latent_ids( return latent_ids @staticmethod + # Copied from diffusers.pipelines.flux2.pipeline_flux2.Flux2Pipeline._prepare_image_ids def _prepare_image_ids( image_latents: List[torch.Tensor], # [(1, C, H, W), (1, C, H, W), ...] scale: int = 10, @@ -361,6 +364,7 @@ def _prepare_image_ids( return image_latent_ids @staticmethod + # Copied from diffusers.pipelines.flux2.pipeline_flux2.Flux2Pipeline._patchify_latents def _patchify_latents(latents): batch_size, num_channels_latents, height, width = latents.shape latents = latents.view(batch_size, num_channels_latents, height // 2, 2, width // 2, 2) @@ -377,6 +381,7 @@ def _unpatchify_latents(latents): return latents @staticmethod + # Copied from diffusers.pipelines.flux2.pipeline_flux2.Flux2Pipeline._pack_latents def _pack_latents(latents): """ pack latents: (batch_size, num_channels, height, width) -> (batch_size, height * width, num_channels) @@ -388,6 +393,7 @@ def _pack_latents(latents): return latents @staticmethod + # Copied from diffusers.pipelines.flux2.pipeline_flux2.Flux2Pipeline._unpack_latents_with_ids def _unpack_latents_with_ids(x: torch.Tensor, x_ids: torch.Tensor) -> list[torch.Tensor]: """ using position ids to scatter tokens into place @@ -448,6 +454,7 @@ def encode_prompt( text_ids = text_ids.to(device) return prompt_embeds, text_ids + # Copied from diffusers.pipelines.flux2.pipeline_flux2.Flux2Pipeline._encode_vae_image def _encode_vae_image(self, image: torch.Tensor, generator: torch.Generator): if image.ndim != 4: raise ValueError(f"Expected image dims 4, got {image.ndim}.") @@ -461,6 +468,7 @@ def _encode_vae_image(self, image: torch.Tensor, generator: torch.Generator): return image_latents + # Copied from diffusers.pipelines.flux2.pipeline_flux2.Flux2Pipeline.prepare_latents def prepare_latents( self, batch_size, @@ -494,6 +502,7 @@ def prepare_latents( latents = self._pack_latents(latents) # [B, C, H, W] -> [B, H*W, C] return latents, latent_ids + # Copied from diffusers.pipelines.flux2.pipeline_flux2.Flux2Pipeline.prepare_image_latents def prepare_image_latents( self, images: List[torch.Tensor], @@ -880,18 +889,17 @@ def __call__( self._current_timestep = None + latents = self._unpack_latents_with_ids(latents, latent_ids) + + latents_bn_mean = self.vae.bn.running_mean.view(1, -1, 1, 1).to(latents.device, latents.dtype) + latents_bn_std = torch.sqrt(self.vae.bn.running_var.view(1, -1, 1, 1) + self.vae.config.batch_norm_eps).to( + latents.device, latents.dtype + ) + latents = latents * latents_bn_std + latents_bn_mean + latents = self._unpatchify_latents(latents) if output_type == "latent": image = latents else: - latents = self._unpack_latents_with_ids(latents, latent_ids) - - latents_bn_mean = self.vae.bn.running_mean.view(1, -1, 1, 1).to(latents.device, latents.dtype) - latents_bn_std = torch.sqrt(self.vae.bn.running_var.view(1, -1, 1, 1) + self.vae.config.batch_norm_eps).to( - latents.device, latents.dtype - ) - latents = latents * latents_bn_std + latents_bn_mean - latents = self._unpatchify_latents(latents) - image = self.vae.decode(latents, return_dict=False)[0] image = self.image_processor.postprocess(image, output_type=output_type) From f2945314be68b3393e8216dfdd978ec7378e3886 Mon Sep 17 00:00:00 2001 From: YiYi Xu Date: Tue, 13 Jan 2026 15:36:28 -1000 Subject: [PATCH 06/15] Apply suggestions from code review Co-authored-by: dg845 <58458699+dg845@users.noreply.github.com> --- src/diffusers/pipelines/flux2/pipeline_flux2_klein.py | 1 + 1 file changed, 1 insertion(+) diff --git a/src/diffusers/pipelines/flux2/pipeline_flux2_klein.py b/src/diffusers/pipelines/flux2/pipeline_flux2_klein.py index 3e11efd1cdf0..6fb36de86c19 100644 --- a/src/diffusers/pipelines/flux2/pipeline_flux2_klein.py +++ b/src/diffusers/pipelines/flux2/pipeline_flux2_klein.py @@ -373,6 +373,7 @@ def _patchify_latents(latents): return latents @staticmethod + # Copied from diffusers.pipelines.flux2.pipeline_flux2.Flux2Pipeline._unpatchify_latents def _unpatchify_latents(latents): batch_size, num_channels_latents, height, width = latents.shape latents = latents.reshape(batch_size, num_channels_latents // (2 * 2), 2, 2, height, width) From 637e983b8aa301166ed535e581269d6bffa9d3aa Mon Sep 17 00:00:00 2001 From: yiyixuxu Date: Wed, 14 Jan 2026 04:12:56 +0100 Subject: [PATCH 07/15] doc string etc --- .../pipelines/flux2/pipeline_flux2_klein.py | 22 +++++++++++-------- 1 file changed, 13 insertions(+), 9 deletions(-) diff --git a/src/diffusers/pipelines/flux2/pipeline_flux2_klein.py b/src/diffusers/pipelines/flux2/pipeline_flux2_klein.py index 6fb36de86c19..df058ef9f7d9 100644 --- a/src/diffusers/pipelines/flux2/pipeline_flux2_klein.py +++ b/src/diffusers/pipelines/flux2/pipeline_flux2_klein.py @@ -51,8 +51,8 @@ >>> prompt = "A cat holding a sign that says hello world" >>> # Depending on the variant being used, the pipeline call will slightly vary. >>> # Refer to the pipeline documentation for more details. - >>> image = pipe(prompt, num_inference_steps=50, guidance_scale=2.5).images[0] - >>> image.save("flux.png") + >>> image = pipe(prompt, num_inference_steps=50, guidance_scale=4.0).images[0] + >>> image.save("flux2_output.png") ``` """ @@ -618,6 +618,7 @@ def __call__( generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, latents: Optional[torch.Tensor] = None, prompt_embeds: Optional[torch.Tensor] = None, + negative_prompt_embeds: Optional[Union[str, List[str]]] = None, output_type: Optional[str] = "pil", return_dict: bool = True, attention_kwargs: Optional[Dict[str, Any]] = None, @@ -639,12 +640,12 @@ def __call__( prompt (`str` or `List[str]`, *optional*): The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`. instead. - guidance_scale (`float`, *optional*, defaults to 1.0): - Embedded guiddance scale is enabled by setting `guidance_scale` > 1. Higher `guidance_scale` encourages - a model to generate images more aligned with `prompt` at the expense of lower image quality. - - Guidance-distilled models approximates true classifer-free guidance for `guidance_scale` > 1. Refer to - the [paper](https://huggingface.co/papers/2210.03142) to learn more. + guidance_scale (`float`, *optional*, defaults to 4.0): + Guidance scale as defined in [Classifier-Free Diffusion + Guidance](https://huggingface.co/papers/2207.12598). `guidance_scale` is defined as `w` of equation 2. + of [Imagen Paper](https://huggingface.co/papers/2205.11487). Guidance scale is enabled by setting + `guidance_scale > 1`. Higher guidance scale encourages to generate images that are closely linked to + the text `prompt`, usually at the expense of lower image quality. For step-wise distilled models, `guidance_scale` is ignored. height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor): The height in pixels of the generated image. This is set to 1024 by default for the best results. width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor): @@ -668,6 +669,9 @@ def __call__( prompt_embeds (`torch.Tensor`, *optional*): Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not provided, text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated negative text embeddings. Note that "" is used as the negative prompt in this pipeline. + If not provided, will be generated from "". output_type (`str`, *optional*, defaults to `"pil"`): The output format of the generate image. Choose between [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. @@ -739,7 +743,7 @@ def __call__( negative_prompt = [negative_prompt] * len(prompt) negative_prompt_embeds, negative_text_ids = self.encode_prompt( prompt=negative_prompt, - prompt_embeds=None, + prompt_embeds=negative_prompt_embeds, device=device, num_images_per_prompt=num_images_per_prompt, max_sequence_length=max_sequence_length, From e18c1342304d1686639e4bde0d6294400cf24de4 Mon Sep 17 00:00:00 2001 From: yiyixuxu Date: Wed, 14 Jan 2026 04:14:12 +0100 Subject: [PATCH 08/15] style --- src/diffusers/__init__.py | 4 ++-- .../models/transformers/transformer_flux2.py | 17 ++++++++++------- src/diffusers/pipelines/__init__.py | 2 +- .../pipelines/flux2/pipeline_flux2_klein.py | 16 +++++++--------- 4 files changed, 20 insertions(+), 19 deletions(-) diff --git a/src/diffusers/__init__.py b/src/diffusers/__init__.py index 6aa0cee150a6..5617e9b65da6 100644 --- a/src/diffusers/__init__.py +++ b/src/diffusers/__init__.py @@ -481,8 +481,8 @@ "EasyAnimateControlPipeline", "EasyAnimateInpaintPipeline", "EasyAnimatePipeline", - "Flux2Pipeline", "Flux2KleinPipeline", + "Flux2Pipeline", "FluxControlImg2ImgPipeline", "FluxControlInpaintPipeline", "FluxControlNetImg2ImgPipeline", @@ -1209,8 +1209,8 @@ EasyAnimateControlPipeline, EasyAnimateInpaintPipeline, EasyAnimatePipeline, - Flux2Pipeline, Flux2KleinPipeline, + Flux2Pipeline, FluxControlImg2ImgPipeline, FluxControlInpaintPipeline, FluxControlNetImg2ImgPipeline, diff --git a/src/diffusers/models/transformers/transformer_flux2.py b/src/diffusers/models/transformers/transformer_flux2.py index c22e6e4eeb8e..9cadfcefc497 100644 --- a/src/diffusers/models/transformers/transformer_flux2.py +++ b/src/diffusers/models/transformers/transformer_flux2.py @@ -586,10 +586,10 @@ def forward(self, ids: torch.Tensor) -> torch.Tensor: class Flux2TimestepGuidanceEmbeddings(nn.Module): def __init__( - self, - in_channels: int = 256, - embedding_dim: int = 6144, - bias: bool = False, + self, + in_channels: int = 256, + embedding_dim: int = 6144, + bias: bool = False, guidance_embeds: bool = True, ): super().__init__() @@ -601,8 +601,8 @@ def __init__( if guidance_embeds: self.guidance_embedder = TimestepEmbedding( - in_channels=in_channels, time_embed_dim=embedding_dim, sample_proj_bias=bias - ) + in_channels=in_channels, time_embed_dim=embedding_dim, sample_proj_bias=bias + ) else: self.guidance_embedder = None @@ -719,7 +719,10 @@ def __init__( # 2. Combined timestep + guidance embedding self.time_guidance_embed = Flux2TimestepGuidanceEmbeddings( - in_channels=timestep_guidance_channels, embedding_dim=self.inner_dim, bias=False, guidance_embeds=guidance_embeds + in_channels=timestep_guidance_channels, + embedding_dim=self.inner_dim, + bias=False, + guidance_embeds=guidance_embeds, ) # 3. Modulation (double stream and single stream blocks share modulation parameters, resp.) diff --git a/src/diffusers/pipelines/__init__.py b/src/diffusers/pipelines/__init__.py index 6a0aa0c52d6d..ca77ee9036e3 100644 --- a/src/diffusers/pipelines/__init__.py +++ b/src/diffusers/pipelines/__init__.py @@ -678,7 +678,7 @@ FluxPriorReduxPipeline, ReduxImageEncoder, ) - from .flux2 import Flux2Pipeline, Flux2KleinPipeline + from .flux2 import Flux2KleinPipeline, Flux2Pipeline from .glm_image import GlmImagePipeline from .hidream_image import HiDreamImagePipeline from .hunyuan_image import HunyuanImagePipeline, HunyuanImageRefinerPipeline diff --git a/src/diffusers/pipelines/flux2/pipeline_flux2_klein.py b/src/diffusers/pipelines/flux2/pipeline_flux2_klein.py index df058ef9f7d9..327714751862 100644 --- a/src/diffusers/pipelines/flux2/pipeline_flux2_klein.py +++ b/src/diffusers/pipelines/flux2/pipeline_flux2_klein.py @@ -18,7 +18,7 @@ import numpy as np import PIL import torch -from transformers import Qwen3ForCausalLM, Qwen2TokenizerFast +from transformers import Qwen2TokenizerFast, Qwen3ForCausalLM from ...loaders import Flux2LoraLoaderMixin from ...models import AutoencoderKLFlux2, Flux2Transformer2DModel @@ -56,6 +56,7 @@ ``` """ + # Copied from diffusers.pipelines.flux2.pipeline_flux2.compute_empiricial_mu def compute_empirical_mu(image_seq_len: int, num_steps: int) -> float: a1, b1 = 8.73809524e-05, 1.89833333 @@ -200,7 +201,6 @@ def __init__( self.tokenizer_max_length = 512 self.default_sample_size = 128 - @staticmethod def _get_qwen3_prompt_embeds( text_encoder: Qwen3ForCausalLM, @@ -220,7 +220,6 @@ def _get_qwen3_prompt_embeds( all_attention_masks = [] for single_prompt in prompt: - messages = [{"role": "user", "content": single_prompt}] text = tokenizer.apply_chat_template( messages, @@ -420,7 +419,6 @@ def _unpack_latents_with_ids(x: torch.Tensor, x_ids: torch.Tensor) -> list[torch return torch.stack(x_list, dim=0) - def encode_prompt( self, prompt: Union[str, List[str]], @@ -444,7 +442,7 @@ def encode_prompt( prompt=prompt, device=device, max_sequence_length=max_sequence_length, - hidden_states_layers=text_encoder_out_layers + hidden_states_layers=text_encoder_out_layers, ) batch_size, seq_len, _ = prompt_embeds.shape @@ -625,7 +623,7 @@ def __call__( callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None, callback_on_step_end_tensor_inputs: List[str] = ["latents"], max_sequence_length: int = 512, - text_encoder_out_layers: Tuple[int] = (9, 18, 27) + text_encoder_out_layers: Tuple[int] = (9, 18, 27), ): r""" Function invoked when calling the pipeline for generation. @@ -645,7 +643,8 @@ def __call__( Guidance](https://huggingface.co/papers/2207.12598). `guidance_scale` is defined as `w` of equation 2. of [Imagen Paper](https://huggingface.co/papers/2205.11487). Guidance scale is enabled by setting `guidance_scale > 1`. Higher guidance scale encourages to generate images that are closely linked to - the text `prompt`, usually at the expense of lower image quality. For step-wise distilled models, `guidance_scale` is ignored. + the text `prompt`, usually at the expense of lower image quality. For step-wise distilled models, + `guidance_scale` is ignored. height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor): The height in pixels of the generated image. This is set to 1024 by default for the best results. width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor): @@ -734,7 +733,7 @@ def __call__( device=device, num_images_per_prompt=num_images_per_prompt, max_sequence_length=max_sequence_length, - text_encoder_out_layers=text_encoder_out_layers + text_encoder_out_layers=text_encoder_out_layers, ) if self.do_classifier_free_guidance: @@ -817,7 +816,6 @@ def __call__( num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0) self._num_timesteps = len(timesteps) - # 7. Denoising loop # We set the index here to remove DtoH sync, helpful especially during compilation. # Check out more details here: https://github.com/huggingface/diffusers/pull/11696 From 0c0730af8a5b4347ed660480bb0177a129f8188d Mon Sep 17 00:00:00 2001 From: yiyixuxu Date: Wed, 14 Jan 2026 04:14:24 +0100 Subject: [PATCH 09/15] more --- .../flux2/test_pipeline_flux2_klein.py | 19 ++++++++----------- 1 file changed, 8 insertions(+), 11 deletions(-) diff --git a/tests/pipelines/flux2/test_pipeline_flux2_klein.py b/tests/pipelines/flux2/test_pipeline_flux2_klein.py index d634d251ac20..8ed9bf3d1e91 100644 --- a/tests/pipelines/flux2/test_pipeline_flux2_klein.py +++ b/tests/pipelines/flux2/test_pipeline_flux2_klein.py @@ -12,13 +12,8 @@ Flux2Transformer2DModel, ) -from ...testing_utils import ( - torch_device -) -from ..test_pipelines_common import ( - PipelineTesterMixin, - check_qkv_fused_layers_exist -) +from ...testing_utils import torch_device +from ..test_pipelines_common import PipelineTesterMixin, check_qkv_fused_layers_exist class Flux2KleinPipelineFastTests(PipelineTesterMixin, unittest.TestCase): @@ -44,7 +39,7 @@ def get_dummy_components(self, num_layers: int = 1, num_single_layers: int = 1): joint_attention_dim=16, timestep_guidance_channels=256, axes_dims_rope=[4, 4, 4, 4], - guidance_embeds=False + guidance_embeds=False, ) # Create minimal Qwen3 config @@ -61,7 +56,9 @@ def get_dummy_components(self, num_layers: int = 1, num_single_layers: int = 1): text_encoder = Qwen3ForCausalLM(config) # Use a simple tokenizer for testing - tokenizer = Qwen2TokenizerFast.from_pretrained("hf-internal-testing/tiny-random-Qwen2VLForConditionalGeneration") + tokenizer = Qwen2TokenizerFast.from_pretrained( + "hf-internal-testing/tiny-random-Qwen2VLForConditionalGeneration" + ) torch.manual_seed(0) vae = AutoencoderKLFlux2( @@ -103,7 +100,7 @@ def get_dummy_inputs(self, device, seed=0): "width": 8, "max_sequence_length": 64, "output_type": "np", - "text_encoder_out_layers": (1,) + "text_encoder_out_layers": (1,), } return inputs @@ -183,4 +180,4 @@ def test_image_input(self): @unittest.skip("Needs to be revisited") def test_encode_prompt_works_in_isolation(self): - pass \ No newline at end of file + pass From eb84d50546743535b8dee735d8adb45d624ed937 Mon Sep 17 00:00:00 2001 From: yiyixuxu Date: Wed, 14 Jan 2026 04:18:47 +0100 Subject: [PATCH 10/15] copies --- .../pipelines/flux2/pipeline_flux2_klein.py | 2 +- .../utils/dummy_torch_and_transformers_objects.py | 15 +++++++++++++++ 2 files changed, 16 insertions(+), 1 deletion(-) diff --git a/src/diffusers/pipelines/flux2/pipeline_flux2_klein.py b/src/diffusers/pipelines/flux2/pipeline_flux2_klein.py index 327714751862..59ae687dd06c 100644 --- a/src/diffusers/pipelines/flux2/pipeline_flux2_klein.py +++ b/src/diffusers/pipelines/flux2/pipeline_flux2_klein.py @@ -57,7 +57,7 @@ """ -# Copied from diffusers.pipelines.flux2.pipeline_flux2.compute_empiricial_mu +# Copied from diffusers.pipelines.flux2.pipeline_flux2.compute_empirical_mu def compute_empirical_mu(image_seq_len: int, num_steps: int) -> float: a1, b1 = 8.73809524e-05, 1.89833333 a2, b2 = 0.00016927, 0.45666666 diff --git a/src/diffusers/utils/dummy_torch_and_transformers_objects.py b/src/diffusers/utils/dummy_torch_and_transformers_objects.py index 40c9bd5aa1ab..d48981f380ac 100644 --- a/src/diffusers/utils/dummy_torch_and_transformers_objects.py +++ b/src/diffusers/utils/dummy_torch_and_transformers_objects.py @@ -947,6 +947,21 @@ def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["torch", "transformers"]) +class Flux2KleinPipeline(metaclass=DummyObject): + _backends = ["torch", "transformers"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch", "transformers"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + class Flux2Pipeline(metaclass=DummyObject): _backends = ["torch", "transformers"] From 439abe9d72bc7d7a75353d2e05ce315121b4443f Mon Sep 17 00:00:00 2001 From: Linoy Tsaban <57615435+linoytsaban@users.noreply.github.com> Date: Thu, 15 Jan 2026 14:34:37 +0200 Subject: [PATCH 11/15] klein lora training scripts (#3) * initial commit * initial commit * remove remote text encoder * initial commit * initial commit * initial commit * revert * img2img fix * text encoder + tokenizer * text encoder + tokenizer * update readme * guidance * guidance * guidance * test * test * revert changes not needed for the non klein model * Update examples/dreambooth/train_dreambooth_lora_flux2_klein.py Co-authored-by: Sayak Paul * fix guidance * fix validation * fix validation * fix validation * fix path * space --------- Co-authored-by: Sayak Paul --- examples/dreambooth/README_flux2.md | 129 +- .../test_dreambooth_lora_flux2_klein.py | 262 +++ .../train_dreambooth_lora_flux2_img2img.py | 2 +- .../train_dreambooth_lora_flux2_klein.py | 1944 +++++++++++++++++ ...ain_dreambooth_lora_flux2_klein_img2img.py | 1890 ++++++++++++++++ 5 files changed, 4204 insertions(+), 23 deletions(-) create mode 100644 examples/dreambooth/test_dreambooth_lora_flux2_klein.py create mode 100644 examples/dreambooth/train_dreambooth_lora_flux2_klein.py create mode 100644 examples/dreambooth/train_dreambooth_lora_flux2_klein_img2img.py diff --git a/examples/dreambooth/README_flux2.md b/examples/dreambooth/README_flux2.md index 7b6df9c0202d..ad5d61f1f9e2 100644 --- a/examples/dreambooth/README_flux2.md +++ b/examples/dreambooth/README_flux2.md @@ -1,14 +1,22 @@ -# DreamBooth training example for FLUX.2 [dev] +# DreamBooth training example for FLUX.2 [dev] and FLUX 2 [klein] [DreamBooth](https://huggingface.co/papers/2208.12242) is a method to personalize image generation models given just a few (3~5) images of a subject/concept. +[LoRA](https://huggingface.co/docs/peft/conceptual_guides/adapter#low-rank-adaptation-lora) is a popular parameter-efficient fine-tuning technique that allows you to achieve full-finetuning like performance but with a fraction of learnable parameters. + +The `train_dreambooth_lora_flux2.py`, `train_dreambooth_lora_flux2_klein.py` scripts shows how to implement the training procedure for [LoRAs](https://huggingface.co/blog/lora) and adapt it for [FLUX.2 [dev]](https://huggingface.co/black-forest-labs/FLUX.2-dev) and [FLUX 2 [klein]](https://huggingface.co/black-forest-labs/FLUX.2-klein). -The `train_dreambooth_lora_flux2.py` script shows how to implement the training procedure for [LoRAs](https://huggingface.co/blog/lora) and adapt it for [FLUX.2 [dev]](https://github.com/black-forest-labs/flux2). +> [!NOTE] +> **Model Variants** +> +> We support two FLUX model families: +> - **FLUX.2 [dev]**: The full-size model using Mistral Small 3.1 as the text encoder. Very capable but memory intensive. +> - **FLUX 2 [klein]**: Available in 4B and 9B parameter variants, using Qwen VL as the text encoder. Much more memory efficient and suitable for consumer hardware. > [!NOTE] > **Memory consumption** > -> Flux can be quite expensive to run on consumer hardware devices and as a result finetuning it comes with high memory requirements - -> a LoRA with a rank of 16 can exceed XXGB of VRAM for training. below we provide some tips and tricks to reduce memory consumption during training. +> FLUX.2 [dev] can be quite expensive to run on consumer hardware devices and as a result finetuning it comes with high memory requirements - +> a LoRA with a rank of 16 can exceed XXGB of VRAM for training. FLUX 2 [klein] models (4B and 9B) are significantly more memory efficient alternatives. Below we provide some tips and tricks to reduce memory consumption during training. > For more tips & guidance on training on a resource-constrained device and general good practices please check out these great guides and trainers for FLUX: > 1) [`@bghira`'s guide](https://github.com/bghira/SimpleTuner/blob/main/documentation/quickstart/FLUX2.md) @@ -17,7 +25,7 @@ The `train_dreambooth_lora_flux2.py` script shows how to implement the training > [!NOTE] > **Gated model** > -> As the model is gated, before using it with diffusers you first need to go to the [FLUX.2 [dev] Hugging Face page](https://huggingface.co/black-forest-labs/FLUX.2-dev), fill in the form and accept the gate. Once you are in, you need to log in so that your system knows you’ve accepted the gate. Use the command below to log in: +> As the model is gated, before using it with diffusers you first need to go to the [FLUX.2 [dev] Hugging Face page](https://huggingface.co/black-forest-labs/FLUX.2-dev), fill in the form and accept the gate. Once you are in, you need to log in so that your system knows you've accepted the gate. Use the command below to log in: ```bash hf auth login @@ -88,23 +96,32 @@ snapshot_download( This will also allow us to push the trained LoRA parameters to the Hugging Face Hub platform. -As mentioned, Flux2 LoRA training is *very* memory intensive. Here are memory optimizations we can use (some still experimental) for a more memory efficient training: +As mentioned, Flux2 LoRA training is *very* memory intensive (especially for FLUX.2 [dev]). Here are memory optimizations we can use (some still experimental) for a more memory efficient training: ## Memory Optimizations > [!NOTE] many of these techniques complement each other and can be used together to further reduce memory consumption. > However some techniques may be mutually exclusive so be sure to check before launching a training run. + ### Remote Text Encoder -Flux.2 uses Mistral Small 3.1 as text encoder which is quite large and can take up a lot of memory. To mitigate this, we can use the `--remote_text_encoder` flag to enable remote computation of the prompt embeddings using the HuggingFace Inference API. +FLUX.2 [dev] uses Mistral Small 3.1 as text encoder which is quite large and can take up a lot of memory. To mitigate this, we can use the `--remote_text_encoder` flag to enable remote computation of the prompt embeddings using the HuggingFace Inference API. This way, the text encoder model is not loaded into memory during training. + +> [!IMPORTANT] +> **Remote text encoder is only supported for FLUX.2 [dev]**. FLUX 2 [klein] models use the Qwen VL text encoder and do not support remote text encoding. + > [!NOTE] > to enable remote text encoding you must either be logged in to your HuggingFace account (`hf auth login`) OR pass a token with `--hub_token`. + ### FSDP Text Encoder -Flux.2 uses Mistral Small 3.1 as text encoder which is quite large and can take up a lot of memory. To mitigate this, we can use the `--fsdp_text_encoder` flag to enable distributed computation of the prompt embeddings. +FLUX.2 [dev] uses Mistral Small 3.1 as text encoder which is quite large and can take up a lot of memory. To mitigate this, we can use the `--fsdp_text_encoder` flag to enable distributed computation of the prompt embeddings. This way, it distributes the memory cost across multiple nodes. + ### CPU Offloading To offload parts of the model to CPU memory, you can use `--offload` flag. This will offload the vae and text encoder to CPU memory and only move them to GPU when needed. + ### Latent Caching Pre-encode the training images with the vae, and then delete it to free up some memory. To enable `latent_caching` simply pass `--cache_latents`. + ### QLoRA: Low Precision Training with Quantization Perform low precision training using 8-bit or 4-bit quantization to reduce memory usage. You can use the following flags: - **FP8 training** with `torchao`: @@ -114,22 +131,29 @@ enable FP8 training by passing `--do_fp8_training`. - **NF4 training** with `bitsandbytes`: Alternatively, you can use 8-bit or 4-bit quantization with `bitsandbytes` by passing: `--bnb_quantization_config_path` to enable 4-bit NF4 quantization. + ### Gradient Checkpointing and Accumulation * `--gradient accumulation` refers to the number of updates steps to accumulate before performing a backward/update pass. by passing a value > 1 you can reduce the amount of backward/update passes and hence also memory reqs. * with `--gradient checkpointing` we can save memory by not storing all intermediate activations during the forward pass. Instead, only a subset of these activations (the checkpoints) are stored and the rest is recomputed as needed during the backward pass. Note that this comes at the expanse of a slower backward pass. + ### 8-bit-Adam Optimizer When training with `AdamW`(doesn't apply to `prodigy`) You can pass `--use_8bit_adam` to reduce the memory requirements of training. Make sure to install `bitsandbytes` if you want to do so. + ### Image Resolution An easy way to mitigate some of the memory requirements is through `--resolution`. `--resolution` refers to the resolution for input images, all the images in the train/validation dataset are resized to this. Note that by default, images are resized to resolution of 512, but it's good to keep in mind in case you're accustomed to training on higher resolutions. + ### Precision of saved LoRA layers By default, trained transformer layers are saved in the precision dtype in which training was performed. E.g. when training in mixed precision is enabled with `--mixed_precision="bf16"`, final finetuned layers will be saved in `torch.bfloat16` as well. This reduces memory requirements significantly w/o a significant quality loss. Note that if you do wish to save the final layers in float32 at the expanse of more memory usage, you can do so by passing `--upcast_before_saving`. +## Training Examples +### FLUX.2 [dev] Training +To perform DreamBooth with LoRA on FLUX.2 [dev], run: ```bash export MODEL_NAME="black-forest-labs/FLUX.2-dev" export INSTANCE_DIR="dog" @@ -161,13 +185,84 @@ accelerate launch train_dreambooth_lora_flux2.py \ --push_to_hub ``` +### FLUX 2 [klein] Training + +FLUX 2 [klein] models are more memory efficient alternatives available in 4B and 9B parameter variants. They use the Qwen VL text encoder instead of Mistral Small 3.1. + +> [!NOTE] +> The `--remote_text_encoder` flag is **not supported** for FLUX 2 [klein] models. The Qwen VL text encoder must be loaded locally, but offloading is still supported. + +**FLUX 2 [klein] 4B:** + +```bash +export MODEL_NAME="black-forest-labs/FLUX.2-klein-4B" +export INSTANCE_DIR="dog" +export OUTPUT_DIR="trained-flux2-klein-4b" + +accelerate launch train_dreambooth_lora_flux2_klein.py \ + --pretrained_model_name_or_path=$MODEL_NAME \ + --instance_data_dir=$INSTANCE_DIR \ + --output_dir=$OUTPUT_DIR \ + --do_fp8_training \ + --gradient_checkpointing \ + --cache_latents \ + --instance_prompt="a photo of sks dog" \ + --resolution=1024 \ + --train_batch_size=1 \ + --guidance_scale=1 \ + --use_8bit_adam \ + --gradient_accumulation_steps=4 \ + --optimizer="adamW" \ + --learning_rate=1e-4 \ + --report_to="wandb" \ + --lr_scheduler="constant" \ + --lr_warmup_steps=100 \ + --max_train_steps=500 \ + --validation_prompt="A photo of sks dog in a bucket" \ + --validation_epochs=25 \ + --seed="0" \ + --push_to_hub +``` + +**FLUX 2 [klein] 9B:** + +```bash +export MODEL_NAME="black-forest-labs/FLUX.2-klein-9B" +export INSTANCE_DIR="dog" +export OUTPUT_DIR="trained-flux2-klein-9b" + +accelerate launch train_dreambooth_lora_flux2_klein.py \ + --pretrained_model_name_or_path=$MODEL_NAME \ + --instance_data_dir=$INSTANCE_DIR \ + --output_dir=$OUTPUT_DIR \ + --do_fp8_training \ + --gradient_checkpointing \ + --cache_latents \ + --instance_prompt="a photo of sks dog" \ + --resolution=1024 \ + --train_batch_size=1 \ + --guidance_scale=1 \ + --use_8bit_adam \ + --gradient_accumulation_steps=4 \ + --optimizer="adamW" \ + --learning_rate=1e-4 \ + --report_to="wandb" \ + --lr_scheduler="constant" \ + --lr_warmup_steps=100 \ + --max_train_steps=500 \ + --validation_prompt="A photo of sks dog in a bucket" \ + --validation_epochs=25 \ + --seed="0" \ + --push_to_hub +``` + To better track our training experiments, we're using the following flags in the command above: * `report_to="wandb` will ensure the training runs are tracked on [Weights and Biases](https://wandb.ai/site). To use it, be sure to install `wandb` with `pip install wandb`. Don't forget to call `wandb login ` before training if you haven't done it before. * `validation_prompt` and `validation_epochs` to allow the script to do a few validation inference runs. This allows us to qualitatively check if the training is progressing as expected. > [!NOTE] -> If you want to train using long prompts with the T5 text encoder, you can use `--max_sequence_length` to set the token limit. The default is 77, but it can be increased to as high as 512. Note that this will use more resources and may slow down the training in some cases. +> If you want to train using long prompts, you can use `--max_sequence_length` to set the token limit. Note that this will use more resources and may slow down the training in some cases. ### FSDP on the transformer By setting the accelerate configuration with FSDP, the transformer block will be wrapped automatically. E.g. set the configuration to: @@ -189,12 +284,6 @@ fsdp_config: fsdp_cpu_ram_efficient_loading: false ``` -## LoRA + DreamBooth - -[LoRA](https://huggingface.co/docs/peft/conceptual_guides/adapter#low-rank-adaptation-lora) is a popular parameter-efficient fine-tuning technique that allows you to achieve full-finetuning like performance but with a fraction of learnable parameters. - -Note also that we use PEFT library as backend for LoRA training, make sure to have `peft>=0.6.0` installed in your environment. - ### Prodigy Optimizer Prodigy is an adaptive optimizer that dynamically adjusts the learning rate learned parameters based on past gradients, allowing for more efficient convergence. By using prodigy we can "eliminate" the need for manual learning rate tuning. read more [here](https://huggingface.co/blog/sdxl_lora_advanced_script#adaptive-optimizers). @@ -206,8 +295,6 @@ to use prodigy, first make sure to install the prodigyopt library: `pip install > [!TIP] > When using prodigy it's generally good practice to set- `--learning_rate=1.0` -To perform DreamBooth with LoRA, run: - ```bash export MODEL_NAME="black-forest-labs/FLUX.2-dev" export INSTANCE_DIR="dog" @@ -271,13 +358,10 @@ the exact modules for LoRA training. Here are some examples of target modules yo > keep in mind that while training more layers can improve quality and expressiveness, it also increases the size of the output LoRA weights. - ## Training Image-to-Image Flux.2 lets us perform image editing as well as image generation. We provide a simple script for image-to-image(I2I) LoRA fine-tuning in [train_dreambooth_lora_flux2_img2img.py](./train_dreambooth_lora_flux2_img2img.py) for both T2I and I2I. The optimizations discussed above apply this script, too. -**important** - **Important** To make sure you can successfully run the latest version of the image-to-image example script, we highly recommend installing from source, specifically from the commit mentioned below. To do this, execute the following steps in a new virtual environment: @@ -334,5 +418,6 @@ we've added aspect ratio bucketing support which allows training on images with To enable aspect ratio bucketing, pass `--aspect_ratio_buckets` argument with a semicolon-separated list of height,width pairs, such as: `--aspect_ratio_buckets="672,1568;688,1504;720,1456;752,1392;800,1328;832,1248;880,1184;944,1104;1024,1024;1104,944;1184,880;1248,832;1328,800;1392,752;1456,720;1504,688;1568,672" -` -Since Flux.2 finetuning is still an experimental phase, we encourage you to explore different settings and share your insights! 🤗 + + +Since Flux.2 finetuning is still an experimental phase, we encourage you to explore different settings and share your insights! 🤗 \ No newline at end of file diff --git a/examples/dreambooth/test_dreambooth_lora_flux2_klein.py b/examples/dreambooth/test_dreambooth_lora_flux2_klein.py new file mode 100644 index 000000000000..0e5506e1a3eb --- /dev/null +++ b/examples/dreambooth/test_dreambooth_lora_flux2_klein.py @@ -0,0 +1,262 @@ +# coding=utf-8 +# Copyright 2025 HuggingFace Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import json +import logging +import os +import sys +import tempfile + +import safetensors + +from diffusers.loaders.lora_base import LORA_ADAPTER_METADATA_KEY + + +sys.path.append("..") +from test_examples_utils import ExamplesTestsAccelerate, run_command # noqa: E402 + + +logging.basicConfig(level=logging.DEBUG) + +logger = logging.getLogger() +stream_handler = logging.StreamHandler(sys.stdout) +logger.addHandler(stream_handler) + + +class DreamBoothLoRAFlux2Klein(ExamplesTestsAccelerate): + instance_data_dir = "docs/source/en/imgs" + instance_prompt = "dog" + pretrained_model_name_or_path = "hf-internal-testing/tiny-flux2-klein" + script_path = "examples/dreambooth/train_dreambooth_lora_flux2_klein.py" + transformer_layer_type = "single_transformer_blocks.0.attn.to_qkv_mlp_proj" + + def test_dreambooth_lora_flux2(self): + with tempfile.TemporaryDirectory() as tmpdir: + test_args = f""" + {self.script_path} + --pretrained_model_name_or_path {self.pretrained_model_name_or_path} + --instance_data_dir {self.instance_data_dir} + --instance_prompt {self.instance_prompt} + --resolution 64 + --train_batch_size 1 + --gradient_accumulation_steps 1 + --max_train_steps 2 + --learning_rate 5.0e-04 + --scale_lr + --lr_scheduler constant + --lr_warmup_steps 0 + --max_sequence_length 8 + --text_encoder_out_layers 1 + --output_dir {tmpdir} + """.split() + + run_command(self._launch_args + test_args) + # save_pretrained smoke test + self.assertTrue(os.path.isfile(os.path.join(tmpdir, "pytorch_lora_weights.safetensors"))) + + # make sure the state_dict has the correct naming in the parameters. + lora_state_dict = safetensors.torch.load_file(os.path.join(tmpdir, "pytorch_lora_weights.safetensors")) + is_lora = all("lora" in k for k in lora_state_dict.keys()) + self.assertTrue(is_lora) + + # when not training the text encoder, all the parameters in the state dict should start + # with `"transformer"` in their names. + starts_with_transformer = all(key.startswith("transformer") for key in lora_state_dict.keys()) + self.assertTrue(starts_with_transformer) + + def test_dreambooth_lora_latent_caching(self): + with tempfile.TemporaryDirectory() as tmpdir: + test_args = f""" + {self.script_path} + --pretrained_model_name_or_path {self.pretrained_model_name_or_path} + --instance_data_dir {self.instance_data_dir} + --instance_prompt {self.instance_prompt} + --resolution 64 + --train_batch_size 1 + --gradient_accumulation_steps 1 + --max_train_steps 2 + --cache_latents + --learning_rate 5.0e-04 + --scale_lr + --lr_scheduler constant + --lr_warmup_steps 0 + --max_sequence_length 8 + --text_encoder_out_layers 1 + --output_dir {tmpdir} + """.split() + + run_command(self._launch_args + test_args) + # save_pretrained smoke test + self.assertTrue(os.path.isfile(os.path.join(tmpdir, "pytorch_lora_weights.safetensors"))) + + # make sure the state_dict has the correct naming in the parameters. + lora_state_dict = safetensors.torch.load_file(os.path.join(tmpdir, "pytorch_lora_weights.safetensors")) + is_lora = all("lora" in k for k in lora_state_dict.keys()) + self.assertTrue(is_lora) + + # when not training the text encoder, all the parameters in the state dict should start + # with `"transformer"` in their names. + starts_with_transformer = all(key.startswith("transformer") for key in lora_state_dict.keys()) + self.assertTrue(starts_with_transformer) + + def test_dreambooth_lora_layers(self): + with tempfile.TemporaryDirectory() as tmpdir: + test_args = f""" + {self.script_path} + --pretrained_model_name_or_path {self.pretrained_model_name_or_path} + --instance_data_dir {self.instance_data_dir} + --instance_prompt {self.instance_prompt} + --resolution 64 + --train_batch_size 1 + --gradient_accumulation_steps 1 + --max_train_steps 2 + --cache_latents + --learning_rate 5.0e-04 + --scale_lr + --lora_layers {self.transformer_layer_type} + --lr_scheduler constant + --lr_warmup_steps 0 + --max_sequence_length 8 + --text_encoder_out_layers 1 + --output_dir {tmpdir} + """.split() + + run_command(self._launch_args + test_args) + # save_pretrained smoke test + self.assertTrue(os.path.isfile(os.path.join(tmpdir, "pytorch_lora_weights.safetensors"))) + + # make sure the state_dict has the correct naming in the parameters. + lora_state_dict = safetensors.torch.load_file(os.path.join(tmpdir, "pytorch_lora_weights.safetensors")) + is_lora = all("lora" in k for k in lora_state_dict.keys()) + self.assertTrue(is_lora) + + # when not training the text encoder, all the parameters in the state dict should start + # with `"transformer"` in their names. In this test, we only params of + # transformer.single_transformer_blocks.0.attn.to_k should be in the state dict + starts_with_transformer = all( + key.startswith(f"transformer.{self.transformer_layer_type}") for key in lora_state_dict.keys() + ) + self.assertTrue(starts_with_transformer) + + def test_dreambooth_lora_flux2_checkpointing_checkpoints_total_limit(self): + with tempfile.TemporaryDirectory() as tmpdir: + test_args = f""" + {self.script_path} + --pretrained_model_name_or_path={self.pretrained_model_name_or_path} + --instance_data_dir={self.instance_data_dir} + --output_dir={tmpdir} + --instance_prompt={self.instance_prompt} + --resolution=64 + --train_batch_size=1 + --gradient_accumulation_steps=1 + --max_train_steps=6 + --checkpoints_total_limit=2 + --max_sequence_length 8 + --checkpointing_steps=2 + --text_encoder_out_layers 1 + """.split() + + run_command(self._launch_args + test_args) + + self.assertEqual( + {x for x in os.listdir(tmpdir) if "checkpoint" in x}, + {"checkpoint-4", "checkpoint-6"}, + ) + + def test_dreambooth_lora_flux2_checkpointing_checkpoints_total_limit_removes_multiple_checkpoints(self): + with tempfile.TemporaryDirectory() as tmpdir: + test_args = f""" + {self.script_path} + --pretrained_model_name_or_path={self.pretrained_model_name_or_path} + --instance_data_dir={self.instance_data_dir} + --output_dir={tmpdir} + --instance_prompt={self.instance_prompt} + --resolution=64 + --train_batch_size=1 + --gradient_accumulation_steps=1 + --max_train_steps=4 + --checkpointing_steps=2 + --max_sequence_length 8 + --text_encoder_out_layers 1 + """.split() + + run_command(self._launch_args + test_args) + + self.assertEqual({x for x in os.listdir(tmpdir) if "checkpoint" in x}, {"checkpoint-2", "checkpoint-4"}) + + resume_run_args = f""" + {self.script_path} + --pretrained_model_name_or_path={self.pretrained_model_name_or_path} + --instance_data_dir={self.instance_data_dir} + --output_dir={tmpdir} + --instance_prompt={self.instance_prompt} + --resolution=64 + --train_batch_size=1 + --gradient_accumulation_steps=1 + --max_train_steps=8 + --checkpointing_steps=2 + --resume_from_checkpoint=checkpoint-4 + --checkpoints_total_limit=2 + --max_sequence_length 8 + --text_encoder_out_layers 1 + """.split() + + run_command(self._launch_args + resume_run_args) + + self.assertEqual({x for x in os.listdir(tmpdir) if "checkpoint" in x}, {"checkpoint-6", "checkpoint-8"}) + + def test_dreambooth_lora_with_metadata(self): + # Use a `lora_alpha` that is different from `rank`. + lora_alpha = 8 + rank = 4 + with tempfile.TemporaryDirectory() as tmpdir: + test_args = f""" + {self.script_path} + --pretrained_model_name_or_path {self.pretrained_model_name_or_path} + --instance_data_dir {self.instance_data_dir} + --instance_prompt {self.instance_prompt} + --resolution 64 + --train_batch_size 1 + --gradient_accumulation_steps 1 + --max_train_steps 2 + --lora_alpha={lora_alpha} + --rank={rank} + --learning_rate 5.0e-04 + --scale_lr + --lr_scheduler constant + --lr_warmup_steps 0 + --max_sequence_length 8 + --text_encoder_out_layers 1 + --output_dir {tmpdir} + """.split() + + run_command(self._launch_args + test_args) + # save_pretrained smoke test + state_dict_file = os.path.join(tmpdir, "pytorch_lora_weights.safetensors") + self.assertTrue(os.path.isfile(state_dict_file)) + + # Check if the metadata was properly serialized. + with safetensors.torch.safe_open(state_dict_file, framework="pt", device="cpu") as f: + metadata = f.metadata() or {} + + metadata.pop("format", None) + raw = metadata.get(LORA_ADAPTER_METADATA_KEY) + if raw: + raw = json.loads(raw) + + loaded_lora_alpha = raw["transformer.lora_alpha"] + self.assertTrue(loaded_lora_alpha == lora_alpha) + loaded_lora_rank = raw["transformer.r"] + self.assertTrue(loaded_lora_rank == rank) diff --git a/examples/dreambooth/train_dreambooth_lora_flux2_img2img.py b/examples/dreambooth/train_dreambooth_lora_flux2_img2img.py index 5af0906642bf..16a3863c881d 100644 --- a/examples/dreambooth/train_dreambooth_lora_flux2_img2img.py +++ b/examples/dreambooth/train_dreambooth_lora_flux2_img2img.py @@ -127,7 +127,7 @@ def save_model_card( ) model_description = f""" -# Flux DreamBooth LoRA - {repo_id} +# Flux.2 DreamBooth LoRA - {repo_id} diff --git a/examples/dreambooth/train_dreambooth_lora_flux2_klein.py b/examples/dreambooth/train_dreambooth_lora_flux2_klein.py new file mode 100644 index 000000000000..0141d61d8594 --- /dev/null +++ b/examples/dreambooth/train_dreambooth_lora_flux2_klein.py @@ -0,0 +1,1944 @@ +#!/usr/bin/env python +# coding=utf-8 +# Copyright 2025 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# /// script +# dependencies = [ +# "diffusers @ git+https://github.com/huggingface/diffusers.git", +# "torch>=2.0.0", +# "accelerate>=0.31.0", +# "transformers>=4.41.2", +# "ftfy", +# "tensorboard", +# "Jinja2", +# "peft>=0.11.1", +# "sentencepiece", +# "torchvision", +# "datasets", +# "bitsandbytes", +# "prodigyopt", +# ] +# /// + +import argparse +import copy +import itertools +import json +import logging +import math +import os +import random +import shutil +import warnings +from contextlib import nullcontext +from pathlib import Path +from typing import Any + +import numpy as np +import torch +import transformers +from accelerate import Accelerator +from accelerate.logging import get_logger +from accelerate.utils import DistributedDataParallelKwargs, ProjectConfiguration, set_seed +from huggingface_hub import create_repo, upload_folder +from huggingface_hub.utils import insecure_hashlib +from peft import LoraConfig, prepare_model_for_kbit_training, set_peft_model_state_dict +from peft.utils import get_peft_model_state_dict +from PIL import Image +from PIL.ImageOps import exif_transpose +from torch.utils.data import Dataset +from torch.utils.data.sampler import BatchSampler +from torchvision import transforms +from torchvision.transforms import functional as TF +from tqdm.auto import tqdm +from transformers import Qwen3ForCausalLM, Qwen2TokenizerFast + +import diffusers +from diffusers import ( + AutoencoderKLFlux2, + BitsAndBytesConfig, + FlowMatchEulerDiscreteScheduler, + Flux2KleinPipeline, + Flux2Transformer2DModel, +) +from diffusers.optimization import get_scheduler +from diffusers.training_utils import ( + _collate_lora_metadata, + _to_cpu_contiguous, + cast_training_params, + compute_density_for_timestep_sampling, + compute_loss_weighting_for_sd3, + find_nearest_bucket, + free_memory, + get_fsdp_kwargs_from_accelerator, + offload_models, + parse_buckets_string, + wrap_with_fsdp, +) +from diffusers.utils import ( + check_min_version, + convert_unet_state_dict_to_peft, + is_wandb_available, +) +from diffusers.utils.hub_utils import load_or_create_model_card, populate_model_card +from diffusers.utils.import_utils import is_torch_npu_available +from diffusers.utils.torch_utils import is_compiled_module + + +if getattr(torch, "distributed", None) is not None: + import torch.distributed as dist + +if is_wandb_available(): + import wandb + +# Will error if the minimal version of diffusers is not installed. Remove at your own risks. +check_min_version("0.37.0.dev0") + +logger = get_logger(__name__) + + +def save_model_card( + repo_id: str, + images=None, + base_model: str = None, + instance_prompt=None, + validation_prompt=None, + repo_folder=None, + quant_training=None, +): + widget_dict = [] + if images is not None: + for i, image in enumerate(images): + image.save(os.path.join(repo_folder, f"image_{i}.png")) + widget_dict.append( + {"text": validation_prompt if validation_prompt else " ", "output": {"url": f"image_{i}.png"}} + ) + + model_description = f""" +# Flux.2 [Klein] DreamBooth LoRA - {repo_id} + + + +## Model description + +These are {repo_id} DreamBooth LoRA weights for {base_model}. + +The weights were trained using [DreamBooth](https://dreambooth.github.io/) with the [Flux2 diffusers trainer](https://github.com/huggingface/diffusers/blob/main/examples/dreambooth/README_flux2.md). + +Quant training? {quant_training} + +## Trigger words + +You should use `{instance_prompt}` to trigger the image generation. + +## Download model + +[Download the *.safetensors LoRA]({repo_id}/tree/main) in the Files & versions tab. + +## Use it with the [🧨 diffusers library](https://github.com/huggingface/diffusers) + +```py +from diffusers import AutoPipelineForText2Image +import torch +pipeline = AutoPipelineForText2Image.from_pretrained("black-forest-labs/FLUX.2", torch_dtype=torch.bfloat16).to('cuda') +pipeline.load_lora_weights('{repo_id}', weight_name='pytorch_lora_weights.safetensors') +image = pipeline('{validation_prompt if validation_prompt else instance_prompt}').images[0] +``` + +For more details, including weighting, merging and fusing LoRAs, check the [documentation on loading LoRAs in diffusers](https://huggingface.co/docs/diffusers/main/en/using-diffusers/loading_adapters) + +## License + +Please adhere to the licensing terms as described [here](https://huggingface.co/black-forest-labs/FLUX.2/blob/main/LICENSE.md). +""" + model_card = load_or_create_model_card( + repo_id_or_path=repo_id, + from_training=True, + license="other", + base_model=base_model, + prompt=instance_prompt, + model_description=model_description, + widget=widget_dict, + ) + tags = [ + "text-to-image", + "diffusers-training", + "diffusers", + "lora", + "flux2-klein", + "flux2-klein-diffusers", + "template:sd-lora", + ] + + model_card = populate_model_card(model_card, tags=tags) + model_card.save(os.path.join(repo_folder, "README.md")) + + +def log_validation( + pipeline, + args, + accelerator, + pipeline_args, + epoch, + torch_dtype, + is_final_validation=False, +): + args.num_validation_images = args.num_validation_images if args.num_validation_images else 1 + logger.info( + f"Running validation... \n Generating {args.num_validation_images} images with prompt:" + f" {args.validation_prompt}." + ) + pipeline = pipeline.to(dtype=torch_dtype) + pipeline.enable_model_cpu_offload() + pipeline.set_progress_bar_config(disable=True) + + # run inference + generator = torch.Generator(device=accelerator.device).manual_seed(args.seed) if args.seed is not None else None + autocast_ctx = torch.autocast(accelerator.device.type) if not is_final_validation else nullcontext() + + images = [] + for _ in range(args.num_validation_images): + with autocast_ctx: + image = pipeline( + prompt_embeds=pipeline_args["prompt_embeds"], + generator=generator, + ).images[0] + images.append(image) + + for tracker in accelerator.trackers: + phase_name = "test" if is_final_validation else "validation" + if tracker.name == "tensorboard": + np_images = np.stack([np.asarray(img) for img in images]) + tracker.writer.add_images(phase_name, np_images, epoch, dataformats="NHWC") + if tracker.name == "wandb": + tracker.log( + { + phase_name: [ + wandb.Image(image, caption=f"{i}: {args.validation_prompt}") for i, image in enumerate(images) + ] + } + ) + + del pipeline + free_memory() + + return images + + +def module_filter_fn(mod: torch.nn.Module, fqn: str): + # don't convert the output module + if fqn == "proj_out": + return False + # don't convert linear modules with weight dimensions not divisible by 16 + if isinstance(mod, torch.nn.Linear): + if mod.in_features % 16 != 0 or mod.out_features % 16 != 0: + return False + return True + + +def parse_args(input_args=None): + parser = argparse.ArgumentParser(description="Simple example of a training script.") + parser.add_argument( + "--pretrained_model_name_or_path", + type=str, + default=None, + required=True, + help="Path to pretrained model or model identifier from huggingface.co/models.", + ) + parser.add_argument( + "--revision", + type=str, + default=None, + required=False, + help="Revision of pretrained model identifier from huggingface.co/models.", + ) + parser.add_argument( + "--bnb_quantization_config_path", + type=str, + default=None, + help="Quantization config in a JSON file that will be used to define the bitsandbytes quant config of the DiT.", + ) + parser.add_argument( + "--do_fp8_training", + action="store_true", + help="if we are doing FP8 training.", + ) + parser.add_argument( + "--variant", + type=str, + default=None, + help="Variant of the model files of the pretrained model identifier from huggingface.co/models, 'e.g.' fp16", + ) + parser.add_argument( + "--dataset_name", + type=str, + default=None, + help=( + "The name of the Dataset (from the HuggingFace hub) containing the training data of instance images (could be your own, possibly private," + " dataset). It can also be a path pointing to a local copy of a dataset in your filesystem," + " or to a folder containing files that 🤗 Datasets can understand." + ), + ) + parser.add_argument( + "--dataset_config_name", + type=str, + default=None, + help="The config of the Dataset, leave as None if there's only one config.", + ) + parser.add_argument( + "--instance_data_dir", + type=str, + default=None, + help=("A folder containing the training data. "), + ) + + parser.add_argument( + "--cache_dir", + type=str, + default=None, + help="The directory where the downloaded models and datasets will be stored.", + ) + + parser.add_argument( + "--image_column", + type=str, + default="image", + help="The column of the dataset containing the target image. By " + "default, the standard Image Dataset maps out 'file_name' " + "to 'image'.", + ) + parser.add_argument( + "--caption_column", + type=str, + default=None, + help="The column of the dataset containing the instance prompt for each image", + ) + + parser.add_argument("--repeats", type=int, default=1, help="How many times to repeat the training data.") + + parser.add_argument( + "--class_data_dir", + type=str, + default=None, + required=False, + help="A folder containing the training data of class images.", + ) + parser.add_argument( + "--instance_prompt", + type=str, + default=None, + required=True, + help="The prompt with identifier specifying the instance, e.g. 'photo of a TOK dog', 'in the style of TOK'", + ) + parser.add_argument( + "--class_prompt", + type=str, + default=None, + help="The prompt to specify images in the same class as provided instance images.", + ) + parser.add_argument( + "--max_sequence_length", + type=int, + default=512, + help="Maximum sequence length to use with with the T5 text encoder", + ) + parser.add_argument( + "--text_encoder_out_layers", + type=int, + nargs="+", + default=[10, 20, 30], + help="Text encoder hidden layers to compute the final text embeddings.", + ) + parser.add_argument( + "--validation_prompt", + type=str, + default=None, + help="A prompt that is used during validation to verify that the model is learning.", + ) + parser.add_argument( + "--skip_final_inference", + default=False, + action="store_true", + help="Whether to skip the final inference step with loaded lora weights upon training completion. This will run intermediate validation inference if `validation_prompt` is provided. Specify to reduce memory.", + ) + parser.add_argument( + "--final_validation_prompt", + type=str, + default=None, + help="A prompt that is used during a final validation to verify that the model is learning. Ignored if `--validation_prompt` is provided.", + ) + parser.add_argument( + "--num_validation_images", + type=int, + default=4, + help="Number of images that should be generated during validation with `validation_prompt`.", + ) + parser.add_argument( + "--validation_epochs", + type=int, + default=50, + help=( + "Run dreambooth validation every X epochs. Dreambooth validation consists of running the prompt" + " `args.validation_prompt` multiple times: `args.num_validation_images`." + ), + ) + parser.add_argument( + "--rank", + type=int, + default=4, + help=("The dimension of the LoRA update matrices."), + ) + parser.add_argument( + "--lora_alpha", + type=int, + default=4, + help="LoRA alpha to be used for additional scaling.", + ) + parser.add_argument("--lora_dropout", type=float, default=0.0, help="Dropout probability for LoRA layers") + + parser.add_argument( + "--with_prior_preservation", + default=False, + action="store_true", + help="Flag to add prior preservation loss.", + ) + parser.add_argument("--prior_loss_weight", type=float, default=1.0, help="The weight of prior preservation loss.") + parser.add_argument( + "--num_class_images", + type=int, + default=100, + help=( + "Minimal class images for prior preservation loss. If there are not enough images already present in" + " class_data_dir, additional images will be sampled with class_prompt." + ), + ) + parser.add_argument( + "--output_dir", + type=str, + default="flux-dreambooth-lora", + help="The output directory where the model predictions and checkpoints will be written.", + ) + parser.add_argument("--seed", type=int, default=None, help="A seed for reproducible training.") + parser.add_argument( + "--resolution", + type=int, + default=512, + help=( + "The resolution for input images, all the images in the train/validation dataset will be resized to this" + " resolution" + ), + ) + parser.add_argument( + "--aspect_ratio_buckets", + type=str, + default=None, + help=( + "Aspect ratio buckets to use for training. Define as a string of 'h1,w1;h2,w2;...'. " + "e.g. '1024,1024;768,1360;1360,768;880,1168;1168,880;1248,832;832,1248'" + "Images will be resized and cropped to fit the nearest bucket. If provided, --resolution is ignored." + ), + ) + parser.add_argument( + "--center_crop", + default=False, + action="store_true", + help=( + "Whether to center crop the input images to the resolution. If not set, the images will be randomly" + " cropped. The images will be resized to the resolution first before cropping." + ), + ) + parser.add_argument( + "--random_flip", + action="store_true", + help="whether to randomly flip images horizontally", + ) + parser.add_argument( + "--train_batch_size", type=int, default=4, help="Batch size (per device) for the training dataloader." + ) + parser.add_argument( + "--sample_batch_size", type=int, default=4, help="Batch size (per device) for sampling images." + ) + parser.add_argument("--num_train_epochs", type=int, default=1) + parser.add_argument( + "--max_train_steps", + type=int, + default=None, + help="Total number of training steps to perform. If provided, overrides num_train_epochs.", + ) + parser.add_argument( + "--checkpointing_steps", + type=int, + default=500, + help=( + "Save a checkpoint of the training state every X updates. These checkpoints can be used both as final" + " checkpoints in case they are better than the last checkpoint, and are also suitable for resuming" + " training using `--resume_from_checkpoint`." + ), + ) + parser.add_argument( + "--checkpoints_total_limit", + type=int, + default=None, + help=("Max number of checkpoints to store."), + ) + parser.add_argument( + "--resume_from_checkpoint", + type=str, + default=None, + help=( + "Whether training should be resumed from a previous checkpoint. Use a path saved by" + ' `--checkpointing_steps`, or `"latest"` to automatically select the last available checkpoint.' + ), + ) + parser.add_argument( + "--gradient_accumulation_steps", + type=int, + default=1, + help="Number of updates steps to accumulate before performing a backward/update pass.", + ) + parser.add_argument( + "--gradient_checkpointing", + action="store_true", + help="Whether or not to use gradient checkpointing to save memory at the expense of slower backward pass.", + ) + parser.add_argument( + "--learning_rate", + type=float, + default=1e-4, + help="Initial learning rate (after the potential warmup period) to use.", + ) + + parser.add_argument( + "--guidance_scale", + type=float, + default=3.5, + help="the FLUX.1 dev variant is a guidance distilled model", + ) + + parser.add_argument( + "--text_encoder_lr", + type=float, + default=5e-6, + help="Text encoder learning rate to use.", + ) + parser.add_argument( + "--scale_lr", + action="store_true", + default=False, + help="Scale the learning rate by the number of GPUs, gradient accumulation steps, and batch size.", + ) + parser.add_argument( + "--lr_scheduler", + type=str, + default="constant", + help=( + 'The scheduler type to use. Choose between ["linear", "cosine", "cosine_with_restarts", "polynomial",' + ' "constant", "constant_with_warmup"]' + ), + ) + parser.add_argument( + "--lr_warmup_steps", type=int, default=500, help="Number of steps for the warmup in the lr scheduler." + ) + parser.add_argument( + "--lr_num_cycles", + type=int, + default=1, + help="Number of hard resets of the lr in cosine_with_restarts scheduler.", + ) + parser.add_argument("--lr_power", type=float, default=1.0, help="Power factor of the polynomial scheduler.") + parser.add_argument( + "--dataloader_num_workers", + type=int, + default=0, + help=( + "Number of subprocesses to use for data loading. 0 means that the data will be loaded in the main process." + ), + ) + parser.add_argument( + "--weighting_scheme", + type=str, + default="none", + choices=["sigma_sqrt", "logit_normal", "mode", "cosmap", "none"], + help=('We default to the "none" weighting scheme for uniform sampling and uniform loss'), + ) + parser.add_argument( + "--logit_mean", type=float, default=0.0, help="mean to use when using the `'logit_normal'` weighting scheme." + ) + parser.add_argument( + "--logit_std", type=float, default=1.0, help="std to use when using the `'logit_normal'` weighting scheme." + ) + parser.add_argument( + "--mode_scale", + type=float, + default=1.29, + help="Scale of mode weighting scheme. Only effective when using the `'mode'` as the `weighting_scheme`.", + ) + parser.add_argument( + "--optimizer", + type=str, + default="AdamW", + help=('The optimizer type to use. Choose between ["AdamW", "prodigy"]'), + ) + + parser.add_argument( + "--use_8bit_adam", + action="store_true", + help="Whether or not to use 8-bit Adam from bitsandbytes. Ignored if optimizer is not set to AdamW", + ) + + parser.add_argument( + "--adam_beta1", type=float, default=0.9, help="The beta1 parameter for the Adam and Prodigy optimizers." + ) + parser.add_argument( + "--adam_beta2", type=float, default=0.999, help="The beta2 parameter for the Adam and Prodigy optimizers." + ) + parser.add_argument( + "--prodigy_beta3", + type=float, + default=None, + help="coefficients for computing the Prodigy stepsize using running averages. If set to None, " + "uses the value of square root of beta2. Ignored if optimizer is adamW", + ) + parser.add_argument("--prodigy_decouple", type=bool, default=True, help="Use AdamW style decoupled weight decay") + parser.add_argument("--adam_weight_decay", type=float, default=1e-04, help="Weight decay to use for unet params") + parser.add_argument( + "--adam_weight_decay_text_encoder", type=float, default=1e-03, help="Weight decay to use for text_encoder" + ) + + parser.add_argument( + "--lora_layers", + type=str, + default=None, + help=( + 'The transformer modules to apply LoRA training on. Please specify the layers in a comma separated. E.g. - "to_k,to_q,to_v,to_out.0" will result in lora training of attention layers only' + ), + ) + + parser.add_argument( + "--adam_epsilon", + type=float, + default=1e-08, + help="Epsilon value for the Adam optimizer and Prodigy optimizers.", + ) + + parser.add_argument( + "--prodigy_use_bias_correction", + type=bool, + default=True, + help="Turn on Adam's bias correction. True by default. Ignored if optimizer is adamW", + ) + parser.add_argument( + "--prodigy_safeguard_warmup", + type=bool, + default=True, + help="Remove lr from the denominator of D estimate to avoid issues during warm-up stage. True by default. " + "Ignored if optimizer is adamW", + ) + parser.add_argument("--max_grad_norm", default=1.0, type=float, help="Max gradient norm.") + parser.add_argument("--push_to_hub", action="store_true", help="Whether or not to push the model to the Hub.") + parser.add_argument("--hub_token", type=str, default=None, help="The token to use to push to the Model Hub.") + parser.add_argument( + "--hub_model_id", + type=str, + default=None, + help="The name of the repository to keep in sync with the local `output_dir`.", + ) + parser.add_argument( + "--logging_dir", + type=str, + default="logs", + help=( + "[TensorBoard](https://www.tensorflow.org/tensorboard) log directory. Will default to" + " *output_dir/runs/**CURRENT_DATETIME_HOSTNAME***." + ), + ) + parser.add_argument( + "--allow_tf32", + action="store_true", + help=( + "Whether or not to allow TF32 on Ampere GPUs. Can be used to speed up training. For more information, see" + " https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices" + ), + ) + parser.add_argument( + "--cache_latents", + action="store_true", + default=False, + help="Cache the VAE latents", + ) + parser.add_argument( + "--report_to", + type=str, + default="tensorboard", + help=( + 'The integration to report the results and logs to. Supported platforms are `"tensorboard"`' + ' (default), `"wandb"` and `"comet_ml"`. Use `"all"` to report to all integrations.' + ), + ) + parser.add_argument( + "--mixed_precision", + type=str, + default=None, + choices=["no", "fp16", "bf16"], + help=( + "Whether to use mixed precision. Choose between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >=" + " 1.10.and an Nvidia Ampere GPU. Default to the value of accelerate config of the current system or the" + " flag passed with the `accelerate.launch` command. Use this argument to override the accelerate config." + ), + ) + parser.add_argument( + "--upcast_before_saving", + action="store_true", + default=False, + help=( + "Whether to upcast the trained transformer layers to float32 before saving (at the end of training). " + "Defaults to precision dtype used for training to save memory" + ), + ) + parser.add_argument( + "--offload", + action="store_true", + help="Whether to offload the VAE and the text encoder to CPU when they are not used.", + ) + parser.add_argument( + "--prior_generation_precision", + type=str, + default=None, + choices=["no", "fp32", "fp16", "bf16"], + help=( + "Choose prior generation precision between fp32, fp16 and bf16 (bfloat16). Bf16 requires PyTorch >=" + " 1.10.and an Nvidia Ampere GPU. Default to fp16 if a GPU is available else fp32." + ), + ) + parser.add_argument("--local_rank", type=int, default=-1, help="For distributed training: local_rank") + parser.add_argument("--enable_npu_flash_attention", action="store_true", help="Enabla Flash Attention for NPU") + parser.add_argument("--fsdp_text_encoder", action="store_true", help="Use FSDP for text encoder") + + if input_args is not None: + args = parser.parse_args(input_args) + else: + args = parser.parse_args() + + if args.dataset_name is None and args.instance_data_dir is None: + raise ValueError("Specify either `--dataset_name` or `--instance_data_dir`") + + if args.dataset_name is not None and args.instance_data_dir is not None: + raise ValueError("Specify only one of `--dataset_name` or `--instance_data_dir`") + if args.do_fp8_training and args.bnb_quantization_config_path: + raise ValueError("Both `do_fp8_training` and `bnb_quantization_config_path` cannot be passed.") + + env_local_rank = int(os.environ.get("LOCAL_RANK", -1)) + if env_local_rank != -1 and env_local_rank != args.local_rank: + args.local_rank = env_local_rank + + if args.with_prior_preservation: + if args.class_data_dir is None: + raise ValueError("You must specify a data directory for class images.") + if args.class_prompt is None: + raise ValueError("You must specify prompt for class images.") + else: + # logger is not available yet + if args.class_data_dir is not None: + warnings.warn("You need not use --class_data_dir without --with_prior_preservation.") + if args.class_prompt is not None: + warnings.warn("You need not use --class_prompt without --with_prior_preservation.") + + return args + + +class DreamBoothDataset(Dataset): + """ + A dataset to prepare the instance and class images with the prompts for fine-tuning the model. + It pre-processes the images. + """ + + def __init__( + self, + instance_data_root, + instance_prompt, + class_prompt, + class_data_root=None, + class_num=None, + size=1024, + repeats=1, + center_crop=False, + buckets=None, + ): + self.size = size + self.center_crop = center_crop + + self.instance_prompt = instance_prompt + self.custom_instance_prompts = None + self.class_prompt = class_prompt + + self.buckets = buckets + + # if --dataset_name is provided or a metadata jsonl file is provided in the local --instance_data directory, + # we load the training data using load_dataset + if args.dataset_name is not None: + try: + from datasets import load_dataset + except ImportError: + raise ImportError( + "You are trying to load your data using the datasets library. If you wish to train using custom " + "captions please install the datasets library: `pip install datasets`. If you wish to load a " + "local folder containing images only, specify --instance_data_dir instead." + ) + # Downloading and loading a dataset from the hub. + # See more about loading custom images at + # https://huggingface.co/docs/datasets/v2.0.0/en/dataset_script + dataset = load_dataset( + args.dataset_name, + args.dataset_config_name, + cache_dir=args.cache_dir, + ) + # Preprocessing the datasets. + column_names = dataset["train"].column_names + + # 6. Get the column names for input/target. + if args.image_column is None: + image_column = column_names[0] + logger.info(f"image column defaulting to {image_column}") + else: + image_column = args.image_column + if image_column not in column_names: + raise ValueError( + f"`--image_column` value '{args.image_column}' not found in dataset columns. Dataset columns are: {', '.join(column_names)}" + ) + instance_images = dataset["train"][image_column] + + if args.caption_column is None: + logger.info( + "No caption column provided, defaulting to instance_prompt for all images. If your dataset " + "contains captions/prompts for the images, make sure to specify the " + "column as --caption_column" + ) + self.custom_instance_prompts = None + else: + if args.caption_column not in column_names: + raise ValueError( + f"`--caption_column` value '{args.caption_column}' not found in dataset columns. Dataset columns are: {', '.join(column_names)}" + ) + custom_instance_prompts = dataset["train"][args.caption_column] + # create final list of captions according to --repeats + self.custom_instance_prompts = [] + for caption in custom_instance_prompts: + self.custom_instance_prompts.extend(itertools.repeat(caption, repeats)) + else: + self.instance_data_root = Path(instance_data_root) + if not self.instance_data_root.exists(): + raise ValueError("Instance images root doesn't exists.") + + instance_images = [Image.open(path) for path in list(Path(instance_data_root).iterdir())] + self.custom_instance_prompts = None + + self.instance_images = [] + for img in instance_images: + self.instance_images.extend(itertools.repeat(img, repeats)) + + self.pixel_values = [] + for i, image in enumerate(self.instance_images): + image = exif_transpose(image) + if not image.mode == "RGB": + image = image.convert("RGB") + + width, height = image.size + + # Find the closest bucket + bucket_idx = find_nearest_bucket(height, width, self.buckets) + target_height, target_width = self.buckets[bucket_idx] + self.size = (target_height, target_width) + + # based on the bucket assignment, define the transformations + image = self.train_transform( + image, + size=self.size, + center_crop=args.center_crop, + random_flip=args.random_flip, + ) + self.pixel_values.append((image, bucket_idx)) + + self.num_instance_images = len(self.instance_images) + self._length = self.num_instance_images + + if class_data_root is not None: + self.class_data_root = Path(class_data_root) + self.class_data_root.mkdir(parents=True, exist_ok=True) + self.class_images_path = list(self.class_data_root.iterdir()) + if class_num is not None: + self.num_class_images = min(len(self.class_images_path), class_num) + else: + self.num_class_images = len(self.class_images_path) + self._length = max(self.num_class_images, self.num_instance_images) + else: + self.class_data_root = None + + self.image_transforms = transforms.Compose( + [ + transforms.Resize(size, interpolation=transforms.InterpolationMode.BILINEAR), + transforms.CenterCrop(size) if center_crop else transforms.RandomCrop(size), + transforms.ToTensor(), + transforms.Normalize([0.5], [0.5]), + ] + ) + + def __len__(self): + return self._length + + def __getitem__(self, index): + example = {} + instance_image, bucket_idx = self.pixel_values[index % self.num_instance_images] + example["instance_images"] = instance_image + example["bucket_idx"] = bucket_idx + if self.custom_instance_prompts: + caption = self.custom_instance_prompts[index % self.num_instance_images] + if caption: + example["instance_prompt"] = caption + else: + example["instance_prompt"] = self.instance_prompt + + else: # custom prompts were provided, but length does not match size of image dataset + example["instance_prompt"] = self.instance_prompt + + if self.class_data_root: + class_image = Image.open(self.class_images_path[index % self.num_class_images]) + class_image = exif_transpose(class_image) + + if not class_image.mode == "RGB": + class_image = class_image.convert("RGB") + example["class_images"] = self.image_transforms(class_image) + example["class_prompt"] = self.class_prompt + + return example + + def train_transform(self, image, size=(224, 224), center_crop=False, random_flip=False): + # 1. Resize (deterministic) + resize = transforms.Resize(size, interpolation=transforms.InterpolationMode.BILINEAR) + image = resize(image) + + # 2. Crop: either center or SAME random crop + if center_crop: + crop = transforms.CenterCrop(size) + image = crop(image) + else: + # get_params returns (i, j, h, w) + i, j, h, w = transforms.RandomCrop.get_params(image, output_size=size) + image = TF.crop(image, i, j, h, w) + + # 3. Random horizontal flip with the SAME coin flip + if random_flip: + do_flip = random.random() < 0.5 + if do_flip: + image = TF.hflip(image) + + # 4. ToTensor + Normalize (deterministic) + to_tensor = transforms.ToTensor() + normalize = transforms.Normalize([0.5], [0.5]) + image = normalize(to_tensor(image)) + + return image + + +def collate_fn(examples, with_prior_preservation=False): + pixel_values = [example["instance_images"] for example in examples] + prompts = [example["instance_prompt"] for example in examples] + + # Concat class and instance examples for prior preservation. + # We do this to avoid doing two forward passes. + if with_prior_preservation: + pixel_values += [example["class_images"] for example in examples] + prompts += [example["class_prompt"] for example in examples] + + pixel_values = torch.stack(pixel_values) + pixel_values = pixel_values.to(memory_format=torch.contiguous_format).float() + + batch = {"pixel_values": pixel_values, "prompts": prompts} + return batch + + +class BucketBatchSampler(BatchSampler): + def __init__(self, dataset: DreamBoothDataset, batch_size: int, drop_last: bool = False): + if not isinstance(batch_size, int) or batch_size <= 0: + raise ValueError("batch_size should be a positive integer value, but got batch_size={}".format(batch_size)) + if not isinstance(drop_last, bool): + raise ValueError("drop_last should be a boolean value, but got drop_last={}".format(drop_last)) + + self.dataset = dataset + self.batch_size = batch_size + self.drop_last = drop_last + + # Group indices by bucket + self.bucket_indices = [[] for _ in range(len(self.dataset.buckets))] + for idx, (_, bucket_idx) in enumerate(self.dataset.pixel_values): + self.bucket_indices[bucket_idx].append(idx) + + self.sampler_len = 0 + self.batches = [] + + # Pre-generate batches for each bucket + for indices_in_bucket in self.bucket_indices: + # Shuffle indices within the bucket + random.shuffle(indices_in_bucket) + # Create batches + for i in range(0, len(indices_in_bucket), self.batch_size): + batch = indices_in_bucket[i : i + self.batch_size] + if len(batch) < self.batch_size and self.drop_last: + continue # Skip partial batch if drop_last is True + self.batches.append(batch) + self.sampler_len += 1 # Count the number of batches + + def __iter__(self): + # Shuffle the order of the batches each epoch + random.shuffle(self.batches) + for batch in self.batches: + yield batch + + def __len__(self): + return self.sampler_len + + +class PromptDataset(Dataset): + "A simple dataset to prepare the prompts to generate class images on multiple GPUs." + + def __init__(self, prompt, num_samples): + self.prompt = prompt + self.num_samples = num_samples + + def __len__(self): + return self.num_samples + + def __getitem__(self, index): + example = {} + example["prompt"] = self.prompt + example["index"] = index + return example + + +def main(args): + if args.report_to == "wandb" and args.hub_token is not None: + raise ValueError( + "You cannot use both --report_to=wandb and --hub_token due to a security risk of exposing your token." + " Please use `hf auth login` to authenticate with the Hub." + ) + + if torch.backends.mps.is_available() and args.mixed_precision == "bf16": + # due to pytorch#99272, MPS does not yet support bfloat16. + raise ValueError( + "Mixed precision training with bfloat16 is not supported on MPS. Please use fp16 (recommended) or fp32 instead." + ) + if args.do_fp8_training: + from torchao.float8 import Float8LinearConfig, convert_to_float8_training + + logging_dir = Path(args.output_dir, args.logging_dir) + + accelerator_project_config = ProjectConfiguration(project_dir=args.output_dir, logging_dir=logging_dir) + kwargs = DistributedDataParallelKwargs(find_unused_parameters=True) + accelerator = Accelerator( + gradient_accumulation_steps=args.gradient_accumulation_steps, + mixed_precision=args.mixed_precision, + log_with=args.report_to, + project_config=accelerator_project_config, + kwargs_handlers=[kwargs], + ) + + # Disable AMP for MPS. + if torch.backends.mps.is_available(): + accelerator.native_amp = False + + if args.report_to == "wandb": + if not is_wandb_available(): + raise ImportError("Make sure to install wandb if you want to use it for logging during training.") + + # Make one log on every process with the configuration for debugging. + logging.basicConfig( + format="%(asctime)s - %(levelname)s - %(name)s - %(message)s", + datefmt="%m/%d/%Y %H:%M:%S", + level=logging.INFO, + ) + logger.info(accelerator.state, main_process_only=False) + if accelerator.is_local_main_process: + transformers.utils.logging.set_verbosity_warning() + diffusers.utils.logging.set_verbosity_info() + else: + transformers.utils.logging.set_verbosity_error() + diffusers.utils.logging.set_verbosity_error() + + # If passed along, set the training seed now. + if args.seed is not None: + set_seed(args.seed) + + # Generate class images if prior preservation is enabled. + if args.with_prior_preservation: + class_images_dir = Path(args.class_data_dir) + if not class_images_dir.exists(): + class_images_dir.mkdir(parents=True) + cur_class_images = len(list(class_images_dir.iterdir())) + + if cur_class_images < args.num_class_images: + has_supported_fp16_accelerator = torch.cuda.is_available() or torch.backends.mps.is_available() + torch_dtype = torch.float16 if has_supported_fp16_accelerator else torch.float32 + if args.prior_generation_precision == "fp32": + torch_dtype = torch.float32 + elif args.prior_generation_precision == "fp16": + torch_dtype = torch.float16 + elif args.prior_generation_precision == "bf16": + torch_dtype = torch.bfloat16 + + pipeline = Flux2KleinPipeline.from_pretrained( + args.pretrained_model_name_or_path, + torch_dtype=torch_dtype, + revision=args.revision, + variant=args.variant, + ) + pipeline.set_progress_bar_config(disable=True) + + num_new_images = args.num_class_images - cur_class_images + logger.info(f"Number of class images to sample: {num_new_images}.") + + sample_dataset = PromptDataset(args.class_prompt, num_new_images) + sample_dataloader = torch.utils.data.DataLoader(sample_dataset, batch_size=args.sample_batch_size) + + sample_dataloader = accelerator.prepare(sample_dataloader) + pipeline.to(accelerator.device) + + for example in tqdm( + sample_dataloader, desc="Generating class images", disable=not accelerator.is_local_main_process + ): + with torch.autocast(device_type=accelerator.device.type, dtype=torch_dtype): + images = pipeline(prompt=example["prompt"]).images + + for i, image in enumerate(images): + hash_image = insecure_hashlib.sha1(image.tobytes()).hexdigest() + image_filename = class_images_dir / f"{example['index'][i] + cur_class_images}-{hash_image}.jpg" + image.save(image_filename) + + del pipeline + free_memory() + + # Handle the repository creation + if accelerator.is_main_process: + if args.output_dir is not None: + os.makedirs(args.output_dir, exist_ok=True) + + if args.push_to_hub: + repo_id = create_repo( + repo_id=args.hub_model_id or Path(args.output_dir).name, + exist_ok=True, + ).repo_id + + # Load the tokenizers + tokenizer = Qwen2TokenizerFast.from_pretrained( + args.pretrained_model_name_or_path, + subfolder="tokenizer", + revision=args.revision, + ) + + # For mixed precision training we cast all non-trainable weights (vae, text_encoder and transformer) to half-precision + # as these weights are only used for inference, keeping weights in full precision is not required. + weight_dtype = torch.float32 + if accelerator.mixed_precision == "fp16": + weight_dtype = torch.float16 + elif accelerator.mixed_precision == "bf16": + weight_dtype = torch.bfloat16 + + # Load scheduler and models + noise_scheduler = FlowMatchEulerDiscreteScheduler.from_pretrained( + args.pretrained_model_name_or_path, + subfolder="scheduler", + revision=args.revision, + ) + noise_scheduler_copy = copy.deepcopy(noise_scheduler) + vae = AutoencoderKLFlux2.from_pretrained( + args.pretrained_model_name_or_path, + subfolder="vae", + revision=args.revision, + variant=args.variant, + ) + latents_bn_mean = vae.bn.running_mean.view(1, -1, 1, 1).to(accelerator.device) + latents_bn_std = torch.sqrt(vae.bn.running_var.view(1, -1, 1, 1) + vae.config.batch_norm_eps).to( + accelerator.device + ) + + quantization_config = None + if args.bnb_quantization_config_path is not None: + with open(args.bnb_quantization_config_path, "r") as f: + config_kwargs = json.load(f) + if "load_in_4bit" in config_kwargs and config_kwargs["load_in_4bit"]: + config_kwargs["bnb_4bit_compute_dtype"] = weight_dtype + quantization_config = BitsAndBytesConfig(**config_kwargs) + + transformer = Flux2Transformer2DModel.from_pretrained( + args.pretrained_model_name_or_path, + subfolder="transformer", + revision=args.revision, + variant=args.variant, + quantization_config=quantization_config, + torch_dtype=weight_dtype, + ) + if args.bnb_quantization_config_path is not None: + transformer = prepare_model_for_kbit_training(transformer, use_gradient_checkpointing=False) + + text_encoder = Qwen3ForCausalLM.from_pretrained( + args.pretrained_model_name_or_path, subfolder="text_encoder", revision=args.revision, variant=args.variant + ) + text_encoder.requires_grad_(False) + + # We only train the additional adapter LoRA layers + transformer.requires_grad_(False) + vae.requires_grad_(False) + + if args.enable_npu_flash_attention: + if is_torch_npu_available(): + logger.info("npu flash attention enabled.") + transformer.set_attention_backend("_native_npu") + else: + raise ValueError("npu flash attention requires torch_npu extensions and is supported only on npu device ") + + if torch.backends.mps.is_available() and weight_dtype == torch.bfloat16: + # due to pytorch#99272, MPS does not yet support bfloat16. + raise ValueError( + "Mixed precision training with bfloat16 is not supported on MPS. Please use fp16 (recommended) or fp32 instead." + ) + + to_kwargs = {"dtype": weight_dtype, "device": accelerator.device} if not args.offload else {"dtype": weight_dtype} + # flux vae is stable in bf16 so load it in weight_dtype to reduce memory + vae.to(**to_kwargs) + # we never offload the transformer to CPU, so we can just use the accelerator device + transformer_to_kwargs = ( + {"device": accelerator.device} + if args.bnb_quantization_config_path is not None + else {"device": accelerator.device, "dtype": weight_dtype} + ) + + is_fsdp = getattr(accelerator.state, "fsdp_plugin", None) is not None + if not is_fsdp: + transformer.to(**transformer_to_kwargs) + + if args.do_fp8_training: + convert_to_float8_training( + transformer, module_filter_fn=module_filter_fn, config=Float8LinearConfig(pad_inner_dim=True) + ) + + + text_encoder.to(**to_kwargs) + # Initialize a text encoding pipeline and keep it to CPU for now. + text_encoding_pipeline = Flux2KleinPipeline.from_pretrained( + args.pretrained_model_name_or_path, + vae=None, + transformer=None, + tokenizer=tokenizer, + text_encoder=text_encoder, + scheduler=None, + revision=args.revision, + ) + + if args.gradient_checkpointing: + transformer.enable_gradient_checkpointing() + + if args.lora_layers is not None: + target_modules = [layer.strip() for layer in args.lora_layers.split(",")] + else: + target_modules = ["to_k", "to_q", "to_v", "to_out.0"] + + # now we will add new LoRA weights the transformer layers + transformer_lora_config = LoraConfig( + r=args.rank, + lora_alpha=args.lora_alpha, + lora_dropout=args.lora_dropout, + init_lora_weights="gaussian", + target_modules=target_modules, + ) + transformer.add_adapter(transformer_lora_config) + + def unwrap_model(model): + model = accelerator.unwrap_model(model) + model = model._orig_mod if is_compiled_module(model) else model + return model + + # create custom saving & loading hooks so that `accelerator.save_state(...)` serializes in a nice format + def save_model_hook(models, weights, output_dir): + transformer_cls = type(unwrap_model(transformer)) + + # 1) Validate and pick the transformer model + modules_to_save: dict[str, Any] = {} + transformer_model = None + + for model in models: + if isinstance(unwrap_model(model), transformer_cls): + transformer_model = model + modules_to_save["transformer"] = model + else: + raise ValueError(f"unexpected save model: {model.__class__}") + + if transformer_model is None: + raise ValueError("No transformer model found in 'models'") + + # 2) Optionally gather FSDP state dict once + state_dict = accelerator.get_state_dict(model) if is_fsdp else None + + # 3) Only main process materializes the LoRA state dict + transformer_lora_layers_to_save = None + if accelerator.is_main_process: + peft_kwargs = {} + if is_fsdp: + peft_kwargs["state_dict"] = state_dict + + transformer_lora_layers_to_save = get_peft_model_state_dict( + unwrap_model(transformer_model) if is_fsdp else transformer_model, + **peft_kwargs, + ) + + if is_fsdp: + transformer_lora_layers_to_save = _to_cpu_contiguous(transformer_lora_layers_to_save) + + # make sure to pop weight so that corresponding model is not saved again + if weights: + weights.pop() + + Flux2KleinPipeline.save_lora_weights( + output_dir, + transformer_lora_layers=transformer_lora_layers_to_save, + **_collate_lora_metadata(modules_to_save), + ) + + def load_model_hook(models, input_dir): + transformer_ = None + + if not is_fsdp: + while len(models) > 0: + model = models.pop() + + if isinstance(unwrap_model(model), type(unwrap_model(transformer))): + transformer_ = unwrap_model(model) + else: + raise ValueError(f"unexpected save model: {model.__class__}") + else: + transformer_ = Flux2Transformer2DModel.from_pretrained( + args.pretrained_model_name_or_path, + subfolder="transformer", + ) + transformer_.add_adapter(transformer_lora_config) + + lora_state_dict = Flux2KleinPipeline.lora_state_dict(input_dir) + + transformer_state_dict = { + f"{k.replace('transformer.', '')}": v for k, v in lora_state_dict.items() if k.startswith("transformer.") + } + transformer_state_dict = convert_unet_state_dict_to_peft(transformer_state_dict) + incompatible_keys = set_peft_model_state_dict(transformer_, transformer_state_dict, adapter_name="default") + if incompatible_keys is not None: + # check only for unexpected keys + unexpected_keys = getattr(incompatible_keys, "unexpected_keys", None) + if unexpected_keys: + logger.warning( + f"Loading adapter weights from state_dict led to unexpected keys not found in the model: " + f" {unexpected_keys}. " + ) + + # Make sure the trainable params are in float32. This is again needed since the base models + # are in `weight_dtype`. More details: + # https://github.com/huggingface/diffusers/pull/6514#discussion_r1449796804 + if args.mixed_precision == "fp16": + models = [transformer_] + # only upcast trainable parameters (LoRA) into fp32 + cast_training_params(models) + + accelerator.register_save_state_pre_hook(save_model_hook) + accelerator.register_load_state_pre_hook(load_model_hook) + + # Enable TF32 for faster training on Ampere GPUs, + # cf https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices + if args.allow_tf32 and torch.cuda.is_available(): + torch.backends.cuda.matmul.allow_tf32 = True + + if args.scale_lr: + args.learning_rate = ( + args.learning_rate * args.gradient_accumulation_steps * args.train_batch_size * accelerator.num_processes + ) + + # Make sure the trainable params are in float32. + if args.mixed_precision == "fp16": + models = [transformer] + # only upcast trainable parameters (LoRA) into fp32 + cast_training_params(models, dtype=torch.float32) + + transformer_lora_parameters = list(filter(lambda p: p.requires_grad, transformer.parameters())) + + # Optimization parameters + transformer_parameters_with_lr = {"params": transformer_lora_parameters, "lr": args.learning_rate} + params_to_optimize = [transformer_parameters_with_lr] + + # Optimizer creation + if not (args.optimizer.lower() == "prodigy" or args.optimizer.lower() == "adamw"): + logger.warning( + f"Unsupported choice of optimizer: {args.optimizer}.Supported optimizers include [adamW, prodigy]." + "Defaulting to adamW" + ) + args.optimizer = "adamw" + + if args.use_8bit_adam and not args.optimizer.lower() == "adamw": + logger.warning( + f"use_8bit_adam is ignored when optimizer is not set to 'AdamW'. Optimizer was " + f"set to {args.optimizer.lower()}" + ) + + if args.optimizer.lower() == "adamw": + if args.use_8bit_adam: + try: + import bitsandbytes as bnb + except ImportError: + raise ImportError( + "To use 8-bit Adam, please install the bitsandbytes library: `pip install bitsandbytes`." + ) + + optimizer_class = bnb.optim.AdamW8bit + else: + optimizer_class = torch.optim.AdamW + + optimizer = optimizer_class( + params_to_optimize, + betas=(args.adam_beta1, args.adam_beta2), + weight_decay=args.adam_weight_decay, + eps=args.adam_epsilon, + ) + + if args.optimizer.lower() == "prodigy": + try: + import prodigyopt + except ImportError: + raise ImportError("To use Prodigy, please install the prodigyopt library: `pip install prodigyopt`") + + optimizer_class = prodigyopt.Prodigy + + if args.learning_rate <= 0.1: + logger.warning( + "Learning rate is too low. When using prodigy, it's generally better to set learning rate around 1.0" + ) + + optimizer = optimizer_class( + params_to_optimize, + betas=(args.adam_beta1, args.adam_beta2), + beta3=args.prodigy_beta3, + weight_decay=args.adam_weight_decay, + eps=args.adam_epsilon, + decouple=args.prodigy_decouple, + use_bias_correction=args.prodigy_use_bias_correction, + safeguard_warmup=args.prodigy_safeguard_warmup, + ) + + if args.aspect_ratio_buckets is not None: + buckets = parse_buckets_string(args.aspect_ratio_buckets) + else: + buckets = [(args.resolution, args.resolution)] + logger.info(f"Using parsed aspect ratio buckets: {buckets}") + + # Dataset and DataLoaders creation: + train_dataset = DreamBoothDataset( + instance_data_root=args.instance_data_dir, + instance_prompt=args.instance_prompt, + class_prompt=args.class_prompt, + class_data_root=args.class_data_dir if args.with_prior_preservation else None, + class_num=args.num_class_images, + size=args.resolution, + repeats=args.repeats, + center_crop=args.center_crop, + buckets=buckets, + ) + batch_sampler = BucketBatchSampler(train_dataset, batch_size=args.train_batch_size, drop_last=True) + train_dataloader = torch.utils.data.DataLoader( + train_dataset, + batch_sampler=batch_sampler, + collate_fn=lambda examples: collate_fn(examples, args.with_prior_preservation), + num_workers=args.dataloader_num_workers, + ) + + def compute_text_embeddings(prompt, text_encoding_pipeline): + with torch.no_grad(): + prompt_embeds, text_ids = text_encoding_pipeline.encode_prompt( + prompt=prompt, + max_sequence_length=args.max_sequence_length, + text_encoder_out_layers=args.text_encoder_out_layers, + ) + return prompt_embeds, text_ids + + + # If no type of tuning is done on the text_encoder and custom instance prompts are NOT + # provided (i.e. the --instance_prompt is used for all images), we encode the instance prompt once to avoid + # the redundant encoding. + if not train_dataset.custom_instance_prompts: + with offload_models(text_encoding_pipeline, device=accelerator.device, offload=args.offload): + instance_prompt_hidden_states, instance_text_ids = compute_text_embeddings( + args.instance_prompt, text_encoding_pipeline + ) + + # Handle class prompt for prior-preservation. + if args.with_prior_preservation: + with offload_models(text_encoding_pipeline, device=accelerator.device, offload=args.offload): + class_prompt_hidden_states, class_text_ids = compute_text_embeddings( + args.class_prompt, text_encoding_pipeline + ) + validation_embeddings = {} + if args.validation_prompt is not None: + with offload_models(text_encoding_pipeline, device=accelerator.device, offload=args.offload): + (validation_embeddings["prompt_embeds"], validation_embeddings["text_ids"]) = compute_text_embeddings( + args.validation_prompt, text_encoding_pipeline + ) + + # Init FSDP for text encoder + if args.fsdp_text_encoder: + fsdp_kwargs = get_fsdp_kwargs_from_accelerator(accelerator) + text_encoder_fsdp = wrap_with_fsdp( + model=text_encoding_pipeline.text_encoder, + device=accelerator.device, + offload=args.offload, + limit_all_gathers=True, + use_orig_params=True, + fsdp_kwargs=fsdp_kwargs, + ) + + text_encoding_pipeline.text_encoder = text_encoder_fsdp + dist.barrier() + + # If custom instance prompts are NOT provided (i.e. the instance prompt is used for all images), + # pack the statically computed variables appropriately here. This is so that we don't + # have to pass them to the dataloader. + if not train_dataset.custom_instance_prompts: + prompt_embeds = instance_prompt_hidden_states + text_ids = instance_text_ids + if args.with_prior_preservation: + prompt_embeds = torch.cat([prompt_embeds, class_prompt_hidden_states], dim=0) + text_ids = torch.cat([text_ids, class_text_ids], dim=0) + + # if cache_latents is set to True, we encode images to latents and store them. + # Similar to pre-encoding in the case of a single instance prompt, if custom prompts are provided + # we encode them in advance as well. + precompute_latents = args.cache_latents or train_dataset.custom_instance_prompts + if precompute_latents: + prompt_embeds_cache = [] + text_ids_cache = [] + latents_cache = [] + for batch in tqdm(train_dataloader, desc="Caching latents"): + with torch.no_grad(): + if args.cache_latents: + with offload_models(vae, device=accelerator.device, offload=args.offload): + batch["pixel_values"] = batch["pixel_values"].to( + accelerator.device, non_blocking=True, dtype=vae.dtype + ) + latents_cache.append(vae.encode(batch["pixel_values"]).latent_dist) + if train_dataset.custom_instance_prompts: + if args.fsdp_text_encoder: + prompt_embeds, text_ids = compute_text_embeddings(batch["prompts"], text_encoding_pipeline) + else: + with offload_models(text_encoding_pipeline, device=accelerator.device, offload=args.offload): + prompt_embeds, text_ids = compute_text_embeddings(batch["prompts"], text_encoding_pipeline) + prompt_embeds_cache.append(prompt_embeds) + text_ids_cache.append(text_ids) + + # move back to cpu before deleting to ensure memory is freed see: https://github.com/huggingface/diffusers/issues/11376#issue-3008144624 + if args.cache_latents: + vae = vae.to("cpu") + del vae + + # move back to cpu before deleting to ensure memory is freed see: https://github.com/huggingface/diffusers/issues/11376#issue-3008144624 + text_encoding_pipeline = text_encoding_pipeline.to("cpu") + del text_encoder, tokenizer + free_memory() + + # Scheduler and math around the number of training steps. + # Check the PR https://github.com/huggingface/diffusers/pull/8312 for detailed explanation. + num_warmup_steps_for_scheduler = args.lr_warmup_steps * accelerator.num_processes + if args.max_train_steps is None: + len_train_dataloader_after_sharding = math.ceil(len(train_dataloader) / accelerator.num_processes) + num_update_steps_per_epoch = math.ceil(len_train_dataloader_after_sharding / args.gradient_accumulation_steps) + num_training_steps_for_scheduler = ( + args.num_train_epochs * accelerator.num_processes * num_update_steps_per_epoch + ) + else: + num_training_steps_for_scheduler = args.max_train_steps * accelerator.num_processes + + lr_scheduler = get_scheduler( + args.lr_scheduler, + optimizer=optimizer, + num_warmup_steps=num_warmup_steps_for_scheduler, + num_training_steps=num_training_steps_for_scheduler, + num_cycles=args.lr_num_cycles, + power=args.lr_power, + ) + + # Prepare everything with our `accelerator`. + transformer, optimizer, train_dataloader, lr_scheduler = accelerator.prepare( + transformer, optimizer, train_dataloader, lr_scheduler + ) + + # We need to recalculate our total training steps as the size of the training dataloader may have changed. + num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps) + if args.max_train_steps is None: + args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch + if num_training_steps_for_scheduler != args.max_train_steps: + logger.warning( + f"The length of the 'train_dataloader' after 'accelerator.prepare' ({len(train_dataloader)}) does not match " + f"the expected length ({len_train_dataloader_after_sharding}) when the learning rate scheduler was created. " + f"This inconsistency may result in the learning rate scheduler not functioning properly." + ) + # Afterwards we recalculate our number of training epochs + args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch) + + # We need to initialize the trackers we use, and also store our configuration. + # The trackers initializes automatically on the main process. + if accelerator.is_main_process: + tracker_name = "dreambooth-flux2-klein-lora" + args_cp = vars(args).copy() + args_cp["text_encoder_out_layers"] = str(args_cp["text_encoder_out_layers"]) + accelerator.init_trackers(tracker_name, config=args_cp) + + # Train! + total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps + + logger.info("***** Running training *****") + logger.info(f" Num examples = {len(train_dataset)}") + logger.info(f" Num batches each epoch = {len(train_dataloader)}") + logger.info(f" Num Epochs = {args.num_train_epochs}") + logger.info(f" Instantaneous batch size per device = {args.train_batch_size}") + logger.info(f" Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}") + logger.info(f" Gradient Accumulation steps = {args.gradient_accumulation_steps}") + logger.info(f" Total optimization steps = {args.max_train_steps}") + global_step = 0 + first_epoch = 0 + + # Potentially load in the weights and states from a previous save + if args.resume_from_checkpoint: + if args.resume_from_checkpoint != "latest": + path = os.path.basename(args.resume_from_checkpoint) + else: + # Get the mos recent checkpoint + dirs = os.listdir(args.output_dir) + dirs = [d for d in dirs if d.startswith("checkpoint")] + dirs = sorted(dirs, key=lambda x: int(x.split("-")[1])) + path = dirs[-1] if len(dirs) > 0 else None + + if path is None: + accelerator.print( + f"Checkpoint '{args.resume_from_checkpoint}' does not exist. Starting a new training run." + ) + args.resume_from_checkpoint = None + initial_global_step = 0 + else: + accelerator.print(f"Resuming from checkpoint {path}") + accelerator.load_state(os.path.join(args.output_dir, path)) + global_step = int(path.split("-")[1]) + + initial_global_step = global_step + first_epoch = global_step // num_update_steps_per_epoch + + else: + initial_global_step = 0 + + progress_bar = tqdm( + range(0, args.max_train_steps), + initial=initial_global_step, + desc="Steps", + # Only show the progress bar once on each machine. + disable=not accelerator.is_local_main_process, + ) + + def get_sigmas(timesteps, n_dim=4, dtype=torch.float32): + sigmas = noise_scheduler_copy.sigmas.to(device=accelerator.device, dtype=dtype) + schedule_timesteps = noise_scheduler_copy.timesteps.to(accelerator.device) + timesteps = timesteps.to(accelerator.device) + step_indices = [(schedule_timesteps == t).nonzero().item() for t in timesteps] + + sigma = sigmas[step_indices].flatten() + while len(sigma.shape) < n_dim: + sigma = sigma.unsqueeze(-1) + return sigma + + for epoch in range(first_epoch, args.num_train_epochs): + transformer.train() + + for step, batch in enumerate(train_dataloader): + models_to_accumulate = [transformer] + prompts = batch["prompts"] + + with accelerator.accumulate(models_to_accumulate): + if train_dataset.custom_instance_prompts: + prompt_embeds = prompt_embeds_cache[step] + text_ids = text_ids_cache[step] + else: + num_repeat_elements = len(prompts) + prompt_embeds = prompt_embeds.repeat(num_repeat_elements, 1, 1) + text_ids = text_ids.repeat(num_repeat_elements, 1, 1) + + # Convert images to latent space + if args.cache_latents: + model_input = latents_cache[step].mode() + else: + with offload_models(vae, device=accelerator.device, offload=args.offload): + pixel_values = batch["pixel_values"].to(dtype=vae.dtype) + model_input = vae.encode(pixel_values).latent_dist.mode() + + model_input = Flux2KleinPipeline._patchify_latents(model_input) + model_input = (model_input - latents_bn_mean) / latents_bn_std + + model_input_ids = Flux2KleinPipeline._prepare_latent_ids(model_input).to(device=model_input.device) + # Sample noise that we'll add to the latents + noise = torch.randn_like(model_input) + bsz = model_input.shape[0] + + # Sample a random timestep for each image + # for weighting schemes where we sample timesteps non-uniformly + u = compute_density_for_timestep_sampling( + weighting_scheme=args.weighting_scheme, + batch_size=bsz, + logit_mean=args.logit_mean, + logit_std=args.logit_std, + mode_scale=args.mode_scale, + ) + indices = (u * noise_scheduler_copy.config.num_train_timesteps).long() + timesteps = noise_scheduler_copy.timesteps[indices].to(device=model_input.device) + + # Add noise according to flow matching. + # zt = (1 - texp) * x + texp * z1 + sigmas = get_sigmas(timesteps, n_dim=model_input.ndim, dtype=model_input.dtype) + noisy_model_input = (1.0 - sigmas) * model_input + sigmas * noise + + # [B, C, H, W] -> [B, H*W, C] + packed_noisy_model_input = Flux2KleinPipeline._pack_latents(noisy_model_input) + + # handle guidance + if transformer.config.guidance_embeds: + guidance = torch.full([1], args.guidance_scale, device=accelerator.device) + guidance = guidance.expand(model_input.shape[0]) + else: + guidance = None + + # Predict the noise residual + model_pred = transformer( + hidden_states=packed_noisy_model_input, # (B, image_seq_len, C) + timestep=timesteps / 1000, + guidance=guidance, + encoder_hidden_states=prompt_embeds, + txt_ids=text_ids, # B, text_seq_len, 4 + img_ids=model_input_ids, # B, image_seq_len, 4 + return_dict=False, + )[0] + model_pred = model_pred[:, : packed_noisy_model_input.size(1) :] + + model_pred = Flux2KleinPipeline._unpack_latents_with_ids(model_pred, model_input_ids) + + # these weighting schemes use a uniform timestep sampling + # and instead post-weight the loss + weighting = compute_loss_weighting_for_sd3(weighting_scheme=args.weighting_scheme, sigmas=sigmas) + + # flow matching loss + target = noise - model_input + + if args.with_prior_preservation: + # Chunk the noise and model_pred into two parts and compute the loss on each part separately. + model_pred, model_pred_prior = torch.chunk(model_pred, 2, dim=0) + target, target_prior = torch.chunk(target, 2, dim=0) + + # Compute prior loss + prior_loss = torch.mean( + (weighting.float() * (model_pred_prior.float() - target_prior.float()) ** 2).reshape( + target_prior.shape[0], -1 + ), + 1, + ) + prior_loss = prior_loss.mean() + + # Compute regular loss. + loss = torch.mean( + (weighting.float() * (model_pred.float() - target.float()) ** 2).reshape(target.shape[0], -1), + 1, + ) + loss = loss.mean() + + if args.with_prior_preservation: + # Add the prior loss to the instance loss. + loss = loss + args.prior_loss_weight * prior_loss + + accelerator.backward(loss) + if accelerator.sync_gradients: + params_to_clip = transformer.parameters() + accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm) + + optimizer.step() + lr_scheduler.step() + optimizer.zero_grad() + + # Checks if the accelerator has performed an optimization step behind the scenes + if accelerator.sync_gradients: + progress_bar.update(1) + global_step += 1 + + if accelerator.is_main_process or is_fsdp: + if global_step % args.checkpointing_steps == 0: + # _before_ saving state, check if this save would set us over the `checkpoints_total_limit` + if args.checkpoints_total_limit is not None: + checkpoints = os.listdir(args.output_dir) + checkpoints = [d for d in checkpoints if d.startswith("checkpoint")] + checkpoints = sorted(checkpoints, key=lambda x: int(x.split("-")[1])) + + # before we save the new checkpoint, we need to have at _most_ `checkpoints_total_limit - 1` checkpoints + if len(checkpoints) >= args.checkpoints_total_limit: + num_to_remove = len(checkpoints) - args.checkpoints_total_limit + 1 + removing_checkpoints = checkpoints[0:num_to_remove] + + logger.info( + f"{len(checkpoints)} checkpoints already exist, removing {len(removing_checkpoints)} checkpoints" + ) + logger.info(f"removing checkpoints: {', '.join(removing_checkpoints)}") + + for removing_checkpoint in removing_checkpoints: + removing_checkpoint = os.path.join(args.output_dir, removing_checkpoint) + shutil.rmtree(removing_checkpoint) + + save_path = os.path.join(args.output_dir, f"checkpoint-{global_step}") + accelerator.save_state(save_path) + logger.info(f"Saved state to {save_path}") + + logs = {"loss": loss.detach().item(), "lr": lr_scheduler.get_last_lr()[0]} + progress_bar.set_postfix(**logs) + accelerator.log(logs, step=global_step) + + if global_step >= args.max_train_steps: + break + + if accelerator.is_main_process: + if args.validation_prompt is not None and epoch % args.validation_epochs == 0: + # create pipeline + pipeline = Flux2KleinPipeline.from_pretrained( + args.pretrained_model_name_or_path, + transformer=unwrap_model(transformer), + revision=args.revision, + variant=args.variant, + torch_dtype=weight_dtype, + ) + images = log_validation( + pipeline=pipeline, + args=args, + accelerator=accelerator, + pipeline_args=validation_embeddings, + epoch=epoch, + torch_dtype=weight_dtype, + ) + + del pipeline + free_memory() + + # Save the lora layers + accelerator.wait_for_everyone() + + if is_fsdp: + transformer = unwrap_model(transformer) + state_dict = accelerator.get_state_dict(transformer) + if accelerator.is_main_process: + modules_to_save = {} + if is_fsdp: + if args.bnb_quantization_config_path is None: + if args.upcast_before_saving: + state_dict = { + k: v.to(torch.float32) if isinstance(v, torch.Tensor) else v for k, v in state_dict.items() + } + else: + state_dict = { + k: v.to(weight_dtype) if isinstance(v, torch.Tensor) else v for k, v in state_dict.items() + } + + transformer_lora_layers = get_peft_model_state_dict( + transformer, + state_dict=state_dict, + ) + transformer_lora_layers = { + k: v.detach().cpu().contiguous() if isinstance(v, torch.Tensor) else v + for k, v in transformer_lora_layers.items() + } + + else: + transformer = unwrap_model(transformer) + if args.bnb_quantization_config_path is None: + if args.upcast_before_saving: + transformer.to(torch.float32) + else: + transformer = transformer.to(weight_dtype) + transformer_lora_layers = get_peft_model_state_dict(transformer) + + modules_to_save["transformer"] = transformer + + Flux2KleinPipeline.save_lora_weights( + save_directory=args.output_dir, + transformer_lora_layers=transformer_lora_layers, + **_collate_lora_metadata(modules_to_save), + ) + + images = [] + run_validation = (args.validation_prompt and args.num_validation_images > 0) or (args.final_validation_prompt) + should_run_final_inference = not args.skip_final_inference and run_validation + if should_run_final_inference: + pipeline = Flux2KleinPipeline.from_pretrained( + args.pretrained_model_name_or_path, + revision=args.revision, + variant=args.variant, + torch_dtype=weight_dtype, + ) + # load attention processors + pipeline.load_lora_weights(args.output_dir) + + # run inference + images = [] + if args.validation_prompt and args.num_validation_images > 0: + images = log_validation( + pipeline=pipeline, + args=args, + accelerator=accelerator, + pipeline_args=validation_embeddings, + epoch=epoch, + is_final_validation=True, + torch_dtype=weight_dtype, + ) + images = None + del pipeline + free_memory() + + validation_prompt = args.validation_prompt if args.validation_prompt else args.final_validation_prompt + quant_training = None + if args.do_fp8_training: + quant_training = "FP8 TorchAO" + elif args.bnb_quantization_config_path: + quant_training = "BitsandBytes" + save_model_card( + (args.hub_model_id or Path(args.output_dir).name) if not args.push_to_hub else repo_id, + images=images, + base_model=args.pretrained_model_name_or_path, + instance_prompt=args.instance_prompt, + validation_prompt=validation_prompt, + repo_folder=args.output_dir, + quant_training=quant_training, + ) + + if args.push_to_hub: + upload_folder( + repo_id=repo_id, + folder_path=args.output_dir, + commit_message="End of training", + ignore_patterns=["step_*", "epoch_*"], + ) + + accelerator.end_training() + + +if __name__ == "__main__": + args = parse_args() + main(args) diff --git a/examples/dreambooth/train_dreambooth_lora_flux2_klein_img2img.py b/examples/dreambooth/train_dreambooth_lora_flux2_klein_img2img.py new file mode 100644 index 000000000000..7ebfe4d43ff7 --- /dev/null +++ b/examples/dreambooth/train_dreambooth_lora_flux2_klein_img2img.py @@ -0,0 +1,1890 @@ +#!/usr/bin/env python +# coding=utf-8 +# Copyright 2025 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# /// script +# dependencies = [ +# "diffusers @ git+https://github.com/huggingface/diffusers.git", +# "torch>=2.0.0", +# "accelerate>=0.31.0", +# "transformers>=4.41.2", +# "ftfy", +# "tensorboard", +# "Jinja2", +# "peft>=0.11.1", +# "sentencepiece", +# "torchvision", +# "datasets", +# "bitsandbytes", +# "prodigyopt", +# ] +# /// + +import argparse +import copy +import itertools +import json +import logging +import math +import os +import random +import shutil +from contextlib import nullcontext +from pathlib import Path +from typing import Any + +import numpy as np +import torch +import transformers +from accelerate import Accelerator +from accelerate.logging import get_logger +from accelerate.utils import DistributedDataParallelKwargs, ProjectConfiguration, set_seed +from huggingface_hub import create_repo, upload_folder +from peft import LoraConfig, prepare_model_for_kbit_training, set_peft_model_state_dict +from peft.utils import get_peft_model_state_dict +from PIL import Image +from PIL.ImageOps import exif_transpose +from torch.utils.data import Dataset +from torch.utils.data.sampler import BatchSampler +from torchvision import transforms +from torchvision.transforms import functional as TF +from tqdm.auto import tqdm +from transformers import Qwen3ForCausalLM, Qwen2TokenizerFast + +import diffusers +from diffusers import ( + AutoencoderKLFlux2, + BitsAndBytesConfig, + FlowMatchEulerDiscreteScheduler, + Flux2KleinPipeline, + Flux2Transformer2DModel, +) +from diffusers.optimization import get_scheduler +from diffusers.pipelines.flux2.image_processor import Flux2ImageProcessor +from diffusers.training_utils import ( + _collate_lora_metadata, + _to_cpu_contiguous, + cast_training_params, + compute_density_for_timestep_sampling, + compute_loss_weighting_for_sd3, + find_nearest_bucket, + free_memory, + get_fsdp_kwargs_from_accelerator, + offload_models, + parse_buckets_string, + wrap_with_fsdp, +) +from diffusers.utils import ( + check_min_version, + convert_unet_state_dict_to_peft, + is_wandb_available, + load_image, +) +from diffusers.utils.hub_utils import load_or_create_model_card, populate_model_card +from diffusers.utils.import_utils import is_torch_npu_available +from diffusers.utils.torch_utils import is_compiled_module + + +if getattr(torch, "distributed", None) is not None: + import torch.distributed as dist + +if is_wandb_available(): + import wandb + +# Will error if the minimal version of diffusers is not installed. Remove at your own risks. +check_min_version("0.37.0.dev0") + +logger = get_logger(__name__) + + +def save_model_card( + repo_id: str, + images=None, + base_model: str = None, + instance_prompt=None, + validation_prompt=None, + repo_folder=None, + fp8_training=False, +): + widget_dict = [] + if images is not None: + for i, image in enumerate(images): + image.save(os.path.join(repo_folder, f"image_{i}.png")) + widget_dict.append( + {"text": validation_prompt if validation_prompt else " ", "output": {"url": f"image_{i}.png"}} + ) + + model_description = f""" +# Flux.2 [Klein] DreamBooth LoRA - {repo_id} + + + +## Model description + +These are {repo_id} DreamBooth LoRA weights for {base_model}. + +The weights were trained using [DreamBooth](https://dreambooth.github.io/) with the [Flux2 diffusers trainer](https://github.com/huggingface/diffusers/blob/main/examples/dreambooth/README_flux2.md). + +FP8 training? {fp8_training} + +## Trigger words + +You should use `{instance_prompt}` to trigger the image generation. + +## Download model + +[Download the *.safetensors LoRA]({repo_id}/tree/main) in the Files & versions tab. + +## Use it with the [🧨 diffusers library](https://github.com/huggingface/diffusers) + +```py +from diffusers import AutoPipelineForText2Image +import torch +pipeline = AutoPipelineForText2Image.from_pretrained("black-forest-labs/FLUX.2", torch_dtype=torch.bfloat16).to('cuda') +pipeline.load_lora_weights('{repo_id}', weight_name='pytorch_lora_weights.safetensors') +image = pipeline('{validation_prompt if validation_prompt else instance_prompt}').images[0] +``` + +For more details, including weighting, merging and fusing LoRAs, check the [documentation on loading LoRAs in diffusers](https://huggingface.co/docs/diffusers/main/en/using-diffusers/loading_adapters) + +## License + +Please adhere to the licensing terms as described [here](https://huggingface.co/black-forest-labs/FLUX.2/blob/main/LICENSE.md). +""" + model_card = load_or_create_model_card( + repo_id_or_path=repo_id, + from_training=True, + license="other", + base_model=base_model, + prompt=instance_prompt, + model_description=model_description, + widget=widget_dict, + ) + tags = [ + "text-to-image", + "diffusers-training", + "diffusers", + "lora", + "flux2", + "flux2-diffusers", + "template:sd-lora", + ] + + model_card = populate_model_card(model_card, tags=tags) + model_card.save(os.path.join(repo_folder, "README.md")) + + +def log_validation( + pipeline, + args, + accelerator, + pipeline_args, + epoch, + torch_dtype, + is_final_validation=False, +): + args.num_validation_images = args.num_validation_images if args.num_validation_images else 1 + logger.info( + f"Running validation... \n Generating {args.num_validation_images} images with prompt:" + f" {args.validation_prompt}." + ) + pipeline = pipeline.to(dtype=torch_dtype) + pipeline.enable_model_cpu_offload() + pipeline.set_progress_bar_config(disable=True) + + # run inference + generator = torch.Generator(device=accelerator.device).manual_seed(args.seed) if args.seed is not None else None + autocast_ctx = torch.autocast(accelerator.device.type) if not is_final_validation else nullcontext() + + images = [] + for _ in range(args.num_validation_images): + with autocast_ctx: + image = pipeline( + image=pipeline_args["image"], + prompt_embeds=pipeline_args["prompt_embeds"], + negative_prompt_embeds=pipeline_args["negative_prompt_embeds"], + generator=generator, + ).images[0] + images.append(image) + + for tracker in accelerator.trackers: + phase_name = "test" if is_final_validation else "validation" + if tracker.name == "tensorboard": + np_images = np.stack([np.asarray(img) for img in images]) + tracker.writer.add_images(phase_name, np_images, epoch, dataformats="NHWC") + if tracker.name == "wandb": + tracker.log( + { + phase_name: [ + wandb.Image(image, caption=f"{i}: {args.validation_prompt}") for i, image in enumerate(images) + ] + } + ) + + del pipeline + free_memory() + + return images + + +def module_filter_fn(mod: torch.nn.Module, fqn: str): + # don't convert the output module + if fqn == "proj_out": + return False + # don't convert linear modules with weight dimensions not divisible by 16 + if isinstance(mod, torch.nn.Linear): + if mod.in_features % 16 != 0 or mod.out_features % 16 != 0: + return False + return True + + +def parse_args(input_args=None): + parser = argparse.ArgumentParser(description="Simple example of a training script.") + parser.add_argument( + "--pretrained_model_name_or_path", + type=str, + default=None, + required=True, + help="Path to pretrained model or model identifier from huggingface.co/models.", + ) + parser.add_argument( + "--revision", + type=str, + default=None, + required=False, + help="Revision of pretrained model identifier from huggingface.co/models.", + ) + parser.add_argument( + "--bnb_quantization_config_path", + type=str, + default=None, + help="Quantization config in a JSON file that will be used to define the bitsandbytes quant config of the DiT.", + ) + parser.add_argument( + "--do_fp8_training", + action="store_true", + help="if we are doing FP8 training.", + ) + parser.add_argument( + "--variant", + type=str, + default=None, + help="Variant of the model files of the pretrained model identifier from huggingface.co/models, 'e.g.' fp16", + ) + parser.add_argument( + "--dataset_name", + type=str, + default=None, + help=( + "The name of the Dataset (from the HuggingFace hub) containing the training data of instance images (could be your own, possibly private," + " dataset). It can also be a path pointing to a local copy of a dataset in your filesystem," + " or to a folder containing files that 🤗 Datasets can understand." + ), + ) + parser.add_argument( + "--dataset_config_name", + type=str, + default=None, + help="The config of the Dataset, leave as None if there's only one config.", + ) + parser.add_argument( + "--instance_data_dir", + type=str, + default=None, + help=("A folder containing the training data. "), + ) + + parser.add_argument( + "--cache_dir", + type=str, + default=None, + help="The directory where the downloaded models and datasets will be stored.", + ) + + parser.add_argument( + "--image_column", + type=str, + default="image", + help="The column of the dataset containing the target image. By " + "default, the standard Image Dataset maps out 'file_name' " + "to 'image'.", + ) + parser.add_argument( + "--cond_image_column", + type=str, + default=None, + help="Column in the dataset containing the condition image. Must be specified when performing I2I fine-tuning", + ) + parser.add_argument( + "--caption_column", + type=str, + default=None, + help="The column of the dataset containing the instance prompt for each image", + ) + + parser.add_argument("--repeats", type=int, default=1, help="How many times to repeat the training data.") + + parser.add_argument( + "--class_data_dir", + type=str, + default=None, + required=False, + help="A folder containing the training data of class images.", + ) + parser.add_argument( + "--instance_prompt", + type=str, + default=None, + required=False, + help="The prompt with identifier specifying the instance, e.g. 'photo of a TOK dog', 'in the style of TOK'", + ) + parser.add_argument( + "--max_sequence_length", + type=int, + default=512, + help="Maximum sequence length to use with with the T5 text encoder", + ) + parser.add_argument( + "--validation_prompt", + type=str, + default=None, + help="A prompt that is used during validation to verify that the model is learning.", + ) + parser.add_argument( + "--validation_image", + type=str, + default=None, + help="path to an image that is used during validation as the condition image to verify that the model is learning.", + ) + parser.add_argument( + "--skip_final_inference", + default=False, + action="store_true", + help="Whether to skip the final inference step with loaded lora weights upon training completion. This will run intermediate validation inference if `validation_prompt` is provided. Specify to reduce memory.", + ) + parser.add_argument( + "--final_validation_prompt", + type=str, + default=None, + help="A prompt that is used during a final validation to verify that the model is learning. Ignored if `--validation_prompt` is provided.", + ) + parser.add_argument( + "--num_validation_images", + type=int, + default=4, + help="Number of images that should be generated during validation with `validation_prompt`.", + ) + parser.add_argument( + "--validation_epochs", + type=int, + default=50, + help=( + "Run dreambooth validation every X epochs. Dreambooth validation consists of running the prompt" + " `args.validation_prompt` multiple times: `args.num_validation_images`." + ), + ) + parser.add_argument( + "--rank", + type=int, + default=4, + help=("The dimension of the LoRA update matrices."), + ) + parser.add_argument( + "--lora_alpha", + type=int, + default=4, + help="LoRA alpha to be used for additional scaling.", + ) + parser.add_argument("--lora_dropout", type=float, default=0.0, help="Dropout probability for LoRA layers") + + parser.add_argument( + "--output_dir", + type=str, + default="flux-dreambooth-lora", + help="The output directory where the model predictions and checkpoints will be written.", + ) + parser.add_argument("--seed", type=int, default=None, help="A seed for reproducible training.") + parser.add_argument( + "--resolution", + type=int, + default=512, + help=( + "The resolution for input images, all the images in the train/validation dataset will be resized to this" + " resolution" + ), + ) + parser.add_argument( + "--aspect_ratio_buckets", + type=str, + default=None, + help=( + "Aspect ratio buckets to use for training. Define as a string of 'h1,w1;h2,w2;...'. " + "e.g. '1024,1024;768,1360;1360,768;880,1168;1168,880;1248,832;832,1248'" + "Images will be resized and cropped to fit the nearest bucket. If provided, --resolution is ignored." + ), + ) + parser.add_argument( + "--center_crop", + default=False, + action="store_true", + help=( + "Whether to center crop the input images to the resolution. If not set, the images will be randomly" + " cropped. The images will be resized to the resolution first before cropping." + ), + ) + parser.add_argument( + "--random_flip", + action="store_true", + help="whether to randomly flip images horizontally", + ) + parser.add_argument( + "--train_batch_size", type=int, default=4, help="Batch size (per device) for the training dataloader." + ) + parser.add_argument( + "--sample_batch_size", type=int, default=4, help="Batch size (per device) for sampling images." + ) + parser.add_argument("--num_train_epochs", type=int, default=1) + parser.add_argument( + "--max_train_steps", + type=int, + default=None, + help="Total number of training steps to perform. If provided, overrides num_train_epochs.", + ) + parser.add_argument( + "--checkpointing_steps", + type=int, + default=500, + help=( + "Save a checkpoint of the training state every X updates. These checkpoints can be used both as final" + " checkpoints in case they are better than the last checkpoint, and are also suitable for resuming" + " training using `--resume_from_checkpoint`." + ), + ) + parser.add_argument( + "--checkpoints_total_limit", + type=int, + default=None, + help=("Max number of checkpoints to store."), + ) + parser.add_argument( + "--resume_from_checkpoint", + type=str, + default=None, + help=( + "Whether training should be resumed from a previous checkpoint. Use a path saved by" + ' `--checkpointing_steps`, or `"latest"` to automatically select the last available checkpoint.' + ), + ) + parser.add_argument( + "--gradient_accumulation_steps", + type=int, + default=1, + help="Number of updates steps to accumulate before performing a backward/update pass.", + ) + parser.add_argument( + "--gradient_checkpointing", + action="store_true", + help="Whether or not to use gradient checkpointing to save memory at the expense of slower backward pass.", + ) + parser.add_argument( + "--learning_rate", + type=float, + default=1e-4, + help="Initial learning rate (after the potential warmup period) to use.", + ) + + parser.add_argument( + "--guidance_scale", + type=float, + default=3.5, + help="the FLUX.1 dev variant is a guidance distilled model", + ) + + parser.add_argument( + "--scale_lr", + action="store_true", + default=False, + help="Scale the learning rate by the number of GPUs, gradient accumulation steps, and batch size.", + ) + parser.add_argument( + "--lr_scheduler", + type=str, + default="constant", + help=( + 'The scheduler type to use. Choose between ["linear", "cosine", "cosine_with_restarts", "polynomial",' + ' "constant", "constant_with_warmup"]' + ), + ) + parser.add_argument( + "--lr_warmup_steps", type=int, default=500, help="Number of steps for the warmup in the lr scheduler." + ) + parser.add_argument( + "--lr_num_cycles", + type=int, + default=1, + help="Number of hard resets of the lr in cosine_with_restarts scheduler.", + ) + parser.add_argument("--lr_power", type=float, default=1.0, help="Power factor of the polynomial scheduler.") + parser.add_argument( + "--dataloader_num_workers", + type=int, + default=0, + help=( + "Number of subprocesses to use for data loading. 0 means that the data will be loaded in the main process." + ), + ) + parser.add_argument( + "--weighting_scheme", + type=str, + default="none", + choices=["sigma_sqrt", "logit_normal", "mode", "cosmap", "none"], + help=('We default to the "none" weighting scheme for uniform sampling and uniform loss'), + ) + parser.add_argument( + "--logit_mean", type=float, default=0.0, help="mean to use when using the `'logit_normal'` weighting scheme." + ) + parser.add_argument( + "--logit_std", type=float, default=1.0, help="std to use when using the `'logit_normal'` weighting scheme." + ) + parser.add_argument( + "--mode_scale", + type=float, + default=1.29, + help="Scale of mode weighting scheme. Only effective when using the `'mode'` as the `weighting_scheme`.", + ) + parser.add_argument( + "--optimizer", + type=str, + default="AdamW", + help=('The optimizer type to use. Choose between ["AdamW", "prodigy"]'), + ) + + parser.add_argument( + "--use_8bit_adam", + action="store_true", + help="Whether or not to use 8-bit Adam from bitsandbytes. Ignored if optimizer is not set to AdamW", + ) + + parser.add_argument( + "--adam_beta1", type=float, default=0.9, help="The beta1 parameter for the Adam and Prodigy optimizers." + ) + parser.add_argument( + "--adam_beta2", type=float, default=0.999, help="The beta2 parameter for the Adam and Prodigy optimizers." + ) + parser.add_argument( + "--prodigy_beta3", + type=float, + default=None, + help="coefficients for computing the Prodigy stepsize using running averages. If set to None, " + "uses the value of square root of beta2. Ignored if optimizer is adamW", + ) + parser.add_argument("--prodigy_decouple", type=bool, default=True, help="Use AdamW style decoupled weight decay") + parser.add_argument("--adam_weight_decay", type=float, default=1e-04, help="Weight decay to use for unet params") + parser.add_argument( + "--adam_weight_decay_text_encoder", type=float, default=1e-03, help="Weight decay to use for text_encoder" + ) + + parser.add_argument( + "--lora_layers", + type=str, + default=None, + help=( + 'The transformer modules to apply LoRA training on. Please specify the layers in a comma separated. E.g. - "to_k,to_q,to_v,to_out.0" will result in lora training of attention layers only' + ), + ) + + parser.add_argument( + "--adam_epsilon", + type=float, + default=1e-08, + help="Epsilon value for the Adam optimizer and Prodigy optimizers.", + ) + + parser.add_argument( + "--prodigy_use_bias_correction", + type=bool, + default=True, + help="Turn on Adam's bias correction. True by default. Ignored if optimizer is adamW", + ) + parser.add_argument( + "--prodigy_safeguard_warmup", + type=bool, + default=True, + help="Remove lr from the denominator of D estimate to avoid issues during warm-up stage. True by default. " + "Ignored if optimizer is adamW", + ) + parser.add_argument("--max_grad_norm", default=1.0, type=float, help="Max gradient norm.") + parser.add_argument("--push_to_hub", action="store_true", help="Whether or not to push the model to the Hub.") + parser.add_argument("--hub_token", type=str, default=None, help="The token to use to push to the Model Hub.") + parser.add_argument( + "--hub_model_id", + type=str, + default=None, + help="The name of the repository to keep in sync with the local `output_dir`.", + ) + parser.add_argument( + "--logging_dir", + type=str, + default="logs", + help=( + "[TensorBoard](https://www.tensorflow.org/tensorboard) log directory. Will default to" + " *output_dir/runs/**CURRENT_DATETIME_HOSTNAME***." + ), + ) + parser.add_argument( + "--allow_tf32", + action="store_true", + help=( + "Whether or not to allow TF32 on Ampere GPUs. Can be used to speed up training. For more information, see" + " https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices" + ), + ) + parser.add_argument( + "--cache_latents", + action="store_true", + default=False, + help="Cache the VAE latents", + ) + parser.add_argument( + "--report_to", + type=str, + default="tensorboard", + help=( + 'The integration to report the results and logs to. Supported platforms are `"tensorboard"`' + ' (default), `"wandb"` and `"comet_ml"`. Use `"all"` to report to all integrations.' + ), + ) + parser.add_argument( + "--mixed_precision", + type=str, + default=None, + choices=["no", "fp16", "bf16"], + help=( + "Whether to use mixed precision. Choose between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >=" + " 1.10.and an Nvidia Ampere GPU. Default to the value of accelerate config of the current system or the" + " flag passed with the `accelerate.launch` command. Use this argument to override the accelerate config." + ), + ) + parser.add_argument( + "--upcast_before_saving", + action="store_true", + default=False, + help=( + "Whether to upcast the trained transformer layers to float32 before saving (at the end of training). " + "Defaults to precision dtype used for training to save memory" + ), + ) + parser.add_argument( + "--offload", + action="store_true", + help="Whether to offload the VAE and the text encoder to CPU when they are not used.", + ) + + parser.add_argument("--local_rank", type=int, default=-1, help="For distributed training: local_rank") + parser.add_argument("--enable_npu_flash_attention", action="store_true", help="Enabla Flash Attention for NPU") + parser.add_argument("--fsdp_text_encoder", action="store_true", help="Use FSDP for text encoder") + + if input_args is not None: + args = parser.parse_args(input_args) + else: + args = parser.parse_args() + + if args.cond_image_column is None: + raise ValueError( + "you must provide --cond_image_column for image-to-image training. Otherwise please see Flux2 text-to-image training example." + ) + else: + assert args.image_column is not None + assert args.caption_column is not None + + if args.dataset_name is None and args.instance_data_dir is None: + raise ValueError("Specify either `--dataset_name` or `--instance_data_dir`") + + if args.dataset_name is not None and args.instance_data_dir is not None: + raise ValueError("Specify only one of `--dataset_name` or `--instance_data_dir`") + + env_local_rank = int(os.environ.get("LOCAL_RANK", -1)) + if env_local_rank != -1 and env_local_rank != args.local_rank: + args.local_rank = env_local_rank + + return args + + +class DreamBoothDataset(Dataset): + """ + A dataset to prepare the instance and class images with the prompts for fine-tuning the model. + It pre-processes the images. + """ + + def __init__( + self, + instance_data_root, + instance_prompt, + size=1024, + repeats=1, + center_crop=False, + buckets=None, + ): + self.size = size + self.center_crop = center_crop + + self.instance_prompt = instance_prompt + self.custom_instance_prompts = None + + self.buckets = buckets + + # if --dataset_name is provided or a metadata jsonl file is provided in the local --instance_data directory, + # we load the training data using load_dataset + if args.dataset_name is not None: + try: + from datasets import load_dataset + except ImportError: + raise ImportError( + "You are trying to load your data using the datasets library. If you wish to train using custom " + "captions please install the datasets library: `pip install datasets`. If you wish to load a " + "local folder containing images only, specify --instance_data_dir instead." + ) + # Downloading and loading a dataset from the hub. + # See more about loading custom images at + # https://huggingface.co/docs/datasets/v2.0.0/en/dataset_script + dataset = load_dataset( + args.dataset_name, + args.dataset_config_name, + cache_dir=args.cache_dir, + ) + # Preprocessing the datasets. + column_names = dataset["train"].column_names + + # 6. Get the column names for input/target. + if args.cond_image_column is not None and args.cond_image_column not in column_names: + raise ValueError( + f"`--cond_image_column` value '{args.cond_image_column}' not found in dataset columns. Dataset columns are: {', '.join(column_names)}" + ) + if args.image_column is None: + image_column = column_names[0] + logger.info(f"image column defaulting to {image_column}") + else: + image_column = args.image_column + if image_column not in column_names: + raise ValueError( + f"`--image_column` value '{args.image_column}' not found in dataset columns. Dataset columns are: {', '.join(column_names)}" + ) + instance_images = dataset["train"][image_column] + cond_images = None + cond_image_column = args.cond_image_column + if cond_image_column is not None: + cond_images = [dataset["train"][i][cond_image_column] for i in range(len(dataset["train"]))] + assert len(instance_images) == len(cond_images) + + if args.caption_column is None: + logger.info( + "No caption column provided, defaulting to instance_prompt for all images. If your dataset " + "contains captions/prompts for the images, make sure to specify the " + "column as --caption_column" + ) + self.custom_instance_prompts = None + else: + if args.caption_column not in column_names: + raise ValueError( + f"`--caption_column` value '{args.caption_column}' not found in dataset columns. Dataset columns are: {', '.join(column_names)}" + ) + custom_instance_prompts = dataset["train"][args.caption_column] + # create final list of captions according to --repeats + self.custom_instance_prompts = [] + for caption in custom_instance_prompts: + self.custom_instance_prompts.extend(itertools.repeat(caption, repeats)) + else: + self.instance_data_root = Path(instance_data_root) + if not self.instance_data_root.exists(): + raise ValueError("Instance images root doesn't exists.") + + instance_images = [Image.open(path) for path in list(Path(instance_data_root).iterdir())] + self.custom_instance_prompts = None + + self.instance_images = [] + self.cond_images = [] + for i, img in enumerate(instance_images): + self.instance_images.extend(itertools.repeat(img, repeats)) + if args.dataset_name is not None and cond_images is not None: + self.cond_images.extend(itertools.repeat(cond_images[i], repeats)) + + self.pixel_values = [] + self.cond_pixel_values = [] + for i, image in enumerate(self.instance_images): + image = exif_transpose(image) + if not image.mode == "RGB": + image = image.convert("RGB") + dest_image = None + if self.cond_images: # todo: take care of max area for buckets + dest_image = self.cond_images[i] + image_width, image_height = dest_image.size + if image_width * image_height > 1024 * 1024: + dest_image = Flux2ImageProcessor._resize_to_target_area(dest_image, 1024 * 1024) + image_width, image_height = dest_image.size + + multiple_of = 2 ** (4 - 1) # 2 ** (len(vae.config.block_out_channels) - 1), temp! + image_width = (image_width // multiple_of) * multiple_of + image_height = (image_height // multiple_of) * multiple_of + image_processor = Flux2ImageProcessor() + dest_image = image_processor.preprocess( + dest_image, height=image_height, width=image_width, resize_mode="crop" + ) + # Convert back to PIL + dest_image = dest_image.squeeze(0) + if dest_image.min() < 0: + dest_image = (dest_image + 1) / 2 + dest_image = (torch.clamp(dest_image, 0, 1) * 255).byte().cpu() + + if dest_image.shape[0] == 1: + # Gray scale image + dest_image = Image.fromarray(dest_image.squeeze().numpy(), mode="L") + else: + # RGB scale image: (C, H, W) -> (H, W, C) + dest_image = TF.to_pil_image(dest_image) + + dest_image = exif_transpose(dest_image) + if not dest_image.mode == "RGB": + dest_image = dest_image.convert("RGB") + + width, height = image.size + + # Find the closest bucket + bucket_idx = find_nearest_bucket(height, width, self.buckets) + target_height, target_width = self.buckets[bucket_idx] + self.size = (target_height, target_width) + + # based on the bucket assignment, define the transformations + image, dest_image = self.paired_transform( + image, + dest_image=dest_image, + size=self.size, + center_crop=args.center_crop, + random_flip=args.random_flip, + ) + self.pixel_values.append((image, bucket_idx)) + if dest_image is not None: + self.cond_pixel_values.append((dest_image, bucket_idx)) + + self.num_instance_images = len(self.instance_images) + self._length = self.num_instance_images + + self.image_transforms = transforms.Compose( + [ + transforms.Resize(size, interpolation=transforms.InterpolationMode.BILINEAR), + transforms.CenterCrop(size) if center_crop else transforms.RandomCrop(size), + transforms.ToTensor(), + transforms.Normalize([0.5], [0.5]), + ] + ) + + def __len__(self): + return self._length + + def __getitem__(self, index): + example = {} + instance_image, bucket_idx = self.pixel_values[index % self.num_instance_images] + example["instance_images"] = instance_image + example["bucket_idx"] = bucket_idx + if self.cond_pixel_values: + dest_image, _ = self.cond_pixel_values[index % self.num_instance_images] + example["cond_images"] = dest_image + + if self.custom_instance_prompts: + caption = self.custom_instance_prompts[index % self.num_instance_images] + if caption: + example["instance_prompt"] = caption + else: + example["instance_prompt"] = self.instance_prompt + + else: # custom prompts were provided, but length does not match size of image dataset + example["instance_prompt"] = self.instance_prompt + + return example + + def paired_transform(self, image, dest_image=None, size=(224, 224), center_crop=False, random_flip=False): + # 1. Resize (deterministic) + resize = transforms.Resize(size, interpolation=transforms.InterpolationMode.BILINEAR) + image = resize(image) + if dest_image is not None: + dest_image = resize(dest_image) + + # 2. Crop: either center or SAME random crop + if center_crop: + crop = transforms.CenterCrop(size) + image = crop(image) + if dest_image is not None: + dest_image = crop(dest_image) + else: + # get_params returns (i, j, h, w) + i, j, h, w = transforms.RandomCrop.get_params(image, output_size=size) + image = TF.crop(image, i, j, h, w) + if dest_image is not None: + dest_image = TF.crop(dest_image, i, j, h, w) + + # 3. Random horizontal flip with the SAME coin flip + if random_flip: + do_flip = random.random() < 0.5 + if do_flip: + image = TF.hflip(image) + if dest_image is not None: + dest_image = TF.hflip(dest_image) + + # 4. ToTensor + Normalize (deterministic) + to_tensor = transforms.ToTensor() + normalize = transforms.Normalize([0.5], [0.5]) + image = normalize(to_tensor(image)) + if dest_image is not None: + dest_image = normalize(to_tensor(dest_image)) + + return (image, dest_image) if dest_image is not None else (image, None) + + +def collate_fn(examples): + pixel_values = [example["instance_images"] for example in examples] + prompts = [example["instance_prompt"] for example in examples] + + pixel_values = torch.stack(pixel_values) + pixel_values = pixel_values.to(memory_format=torch.contiguous_format).float() + + batch = {"pixel_values": pixel_values, "prompts": prompts} + if any("cond_images" in example for example in examples): + cond_pixel_values = [example["cond_images"] for example in examples] + cond_pixel_values = torch.stack(cond_pixel_values) + cond_pixel_values = cond_pixel_values.to(memory_format=torch.contiguous_format).float() + batch.update({"cond_pixel_values": cond_pixel_values}) + return batch + + +class BucketBatchSampler(BatchSampler): + def __init__(self, dataset: DreamBoothDataset, batch_size: int, drop_last: bool = False): + if not isinstance(batch_size, int) or batch_size <= 0: + raise ValueError("batch_size should be a positive integer value, but got batch_size={}".format(batch_size)) + if not isinstance(drop_last, bool): + raise ValueError("drop_last should be a boolean value, but got drop_last={}".format(drop_last)) + + self.dataset = dataset + self.batch_size = batch_size + self.drop_last = drop_last + + # Group indices by bucket + self.bucket_indices = [[] for _ in range(len(self.dataset.buckets))] + for idx, (_, bucket_idx) in enumerate(self.dataset.pixel_values): + self.bucket_indices[bucket_idx].append(idx) + + self.sampler_len = 0 + self.batches = [] + + # Pre-generate batches for each bucket + for indices_in_bucket in self.bucket_indices: + # Shuffle indices within the bucket + random.shuffle(indices_in_bucket) + # Create batches + for i in range(0, len(indices_in_bucket), self.batch_size): + batch = indices_in_bucket[i : i + self.batch_size] + if len(batch) < self.batch_size and self.drop_last: + continue # Skip partial batch if drop_last is True + self.batches.append(batch) + self.sampler_len += 1 # Count the number of batches + + def __iter__(self): + # Shuffle the order of the batches each epoch + random.shuffle(self.batches) + for batch in self.batches: + yield batch + + def __len__(self): + return self.sampler_len + + +class PromptDataset(Dataset): + "A simple dataset to prepare the prompts to generate class images on multiple GPUs." + + def __init__(self, prompt, num_samples): + self.prompt = prompt + self.num_samples = num_samples + + def __len__(self): + return self.num_samples + + def __getitem__(self, index): + example = {} + example["prompt"] = self.prompt + example["index"] = index + return example + + +def main(args): + if args.report_to == "wandb" and args.hub_token is not None: + raise ValueError( + "You cannot use both --report_to=wandb and --hub_token due to a security risk of exposing your token." + " Please use `hf auth login` to authenticate with the Hub." + ) + + if torch.backends.mps.is_available() and args.mixed_precision == "bf16": + # due to pytorch#99272, MPS does not yet support bfloat16. + raise ValueError( + "Mixed precision training with bfloat16 is not supported on MPS. Please use fp16 (recommended) or fp32 instead." + ) + if args.do_fp8_training: + from torchao.float8 import Float8LinearConfig, convert_to_float8_training + + logging_dir = Path(args.output_dir, args.logging_dir) + + accelerator_project_config = ProjectConfiguration(project_dir=args.output_dir, logging_dir=logging_dir) + kwargs = DistributedDataParallelKwargs(find_unused_parameters=True) + accelerator = Accelerator( + gradient_accumulation_steps=args.gradient_accumulation_steps, + mixed_precision=args.mixed_precision, + log_with=args.report_to, + project_config=accelerator_project_config, + kwargs_handlers=[kwargs], + ) + + # Disable AMP for MPS. + if torch.backends.mps.is_available(): + accelerator.native_amp = False + + if args.report_to == "wandb": + if not is_wandb_available(): + raise ImportError("Make sure to install wandb if you want to use it for logging during training.") + + # Make one log on every process with the configuration for debugging. + logging.basicConfig( + format="%(asctime)s - %(levelname)s - %(name)s - %(message)s", + datefmt="%m/%d/%Y %H:%M:%S", + level=logging.INFO, + ) + logger.info(accelerator.state, main_process_only=False) + if accelerator.is_local_main_process: + transformers.utils.logging.set_verbosity_warning() + diffusers.utils.logging.set_verbosity_info() + else: + transformers.utils.logging.set_verbosity_error() + diffusers.utils.logging.set_verbosity_error() + + # If passed along, set the training seed now. + if args.seed is not None: + set_seed(args.seed) + + # Handle the repository creation + if accelerator.is_main_process: + if args.output_dir is not None: + os.makedirs(args.output_dir, exist_ok=True) + + if args.push_to_hub: + repo_id = create_repo( + repo_id=args.hub_model_id or Path(args.output_dir).name, + exist_ok=True, + ).repo_id + + # Load the tokenizers + tokenizer = Qwen2TokenizerFast.from_pretrained( + args.pretrained_model_name_or_path, + subfolder="tokenizer", + revision=args.revision, + ) + + # For mixed precision training we cast all non-trainable weights (vae, text_encoder and transformer) to half-precision + # as these weights are only used for inference, keeping weights in full precision is not required. + weight_dtype = torch.float32 + if accelerator.mixed_precision == "fp16": + weight_dtype = torch.float16 + elif accelerator.mixed_precision == "bf16": + weight_dtype = torch.bfloat16 + + # Load scheduler and models + noise_scheduler = FlowMatchEulerDiscreteScheduler.from_pretrained( + args.pretrained_model_name_or_path, + subfolder="scheduler", + revision=args.revision, + ) + noise_scheduler_copy = copy.deepcopy(noise_scheduler) + vae = AutoencoderKLFlux2.from_pretrained( + args.pretrained_model_name_or_path, + subfolder="vae", + revision=args.revision, + variant=args.variant, + ) + latents_bn_mean = vae.bn.running_mean.view(1, -1, 1, 1).to(accelerator.device) + latents_bn_std = torch.sqrt(vae.bn.running_var.view(1, -1, 1, 1) + vae.config.batch_norm_eps).to( + accelerator.device + ) + + quantization_config = None + if args.bnb_quantization_config_path is not None: + with open(args.bnb_quantization_config_path, "r") as f: + config_kwargs = json.load(f) + if "load_in_4bit" in config_kwargs and config_kwargs["load_in_4bit"]: + config_kwargs["bnb_4bit_compute_dtype"] = weight_dtype + quantization_config = BitsAndBytesConfig(**config_kwargs) + + transformer = Flux2Transformer2DModel.from_pretrained( + args.pretrained_model_name_or_path, + subfolder="transformer", + revision=args.revision, + variant=args.variant, + quantization_config=quantization_config, + torch_dtype=weight_dtype, + ) + if args.bnb_quantization_config_path is not None: + transformer = prepare_model_for_kbit_training(transformer, use_gradient_checkpointing=False) + + text_encoder = Qwen3ForCausalLM.from_pretrained( + args.pretrained_model_name_or_path, subfolder="text_encoder", revision=args.revision, variant=args.variant + ) + text_encoder.requires_grad_(False) + + # We only train the additional adapter LoRA layers + transformer.requires_grad_(False) + vae.requires_grad_(False) + + if args.enable_npu_flash_attention: + if is_torch_npu_available(): + logger.info("npu flash attention enabled.") + transformer.set_attention_backend("_native_npu") + else: + raise ValueError("npu flash attention requires torch_npu extensions and is supported only on npu device ") + + if torch.backends.mps.is_available() and weight_dtype == torch.bfloat16: + # due to pytorch#99272, MPS does not yet support bfloat16. + raise ValueError( + "Mixed precision training with bfloat16 is not supported on MPS. Please use fp16 (recommended) or fp32 instead." + ) + + to_kwargs = {"dtype": weight_dtype, "device": accelerator.device} if not args.offload else {"dtype": weight_dtype} + # flux vae is stable in bf16 so load it in weight_dtype to reduce memory + vae.to(**to_kwargs) + # we never offload the transformer to CPU, so we can just use the accelerator device + transformer_to_kwargs = ( + {"device": accelerator.device} + if args.bnb_quantization_config_path is not None + else {"device": accelerator.device, "dtype": weight_dtype} + ) + + is_fsdp = getattr(accelerator.state, "fsdp_plugin", None) is not None + if not is_fsdp: + transformer.to(**transformer_to_kwargs) + + if args.do_fp8_training: + convert_to_float8_training( + transformer, module_filter_fn=module_filter_fn, config=Float8LinearConfig(pad_inner_dim=True) + ) + + text_encoder.to(**to_kwargs) + # Initialize a text encoding pipeline and keep it to CPU for now. + text_encoding_pipeline = Flux2KleinPipeline.from_pretrained( + args.pretrained_model_name_or_path, + vae=None, + transformer=None, + tokenizer=tokenizer, + text_encoder=text_encoder, + scheduler=None, + revision=args.revision, + ) + + if args.gradient_checkpointing: + transformer.enable_gradient_checkpointing() + + if args.lora_layers is not None: + target_modules = [layer.strip() for layer in args.lora_layers.split(",")] + else: + target_modules = ["to_k", "to_q", "to_v", "to_out.0"] + + # now we will add new LoRA weights the transformer layers + transformer_lora_config = LoraConfig( + r=args.rank, + lora_alpha=args.lora_alpha, + lora_dropout=args.lora_dropout, + init_lora_weights="gaussian", + target_modules=target_modules, + ) + transformer.add_adapter(transformer_lora_config) + + def unwrap_model(model): + model = accelerator.unwrap_model(model) + model = model._orig_mod if is_compiled_module(model) else model + return model + + # create custom saving & loading hooks so that `accelerator.save_state(...)` serializes in a nice format + def save_model_hook(models, weights, output_dir): + transformer_cls = type(unwrap_model(transformer)) + + # 1) Validate and pick the transformer model + modules_to_save: dict[str, Any] = {} + transformer_model = None + + for model in models: + if isinstance(unwrap_model(model), transformer_cls): + transformer_model = model + modules_to_save["transformer"] = model + else: + raise ValueError(f"unexpected save model: {model.__class__}") + + if transformer_model is None: + raise ValueError("No transformer model found in 'models'") + + # 2) Optionally gather FSDP state dict once + state_dict = accelerator.get_state_dict(model) if is_fsdp else None + + # 3) Only main process materializes the LoRA state dict + transformer_lora_layers_to_save = None + if accelerator.is_main_process: + peft_kwargs = {} + if is_fsdp: + peft_kwargs["state_dict"] = state_dict + + transformer_lora_layers_to_save = get_peft_model_state_dict( + unwrap_model(transformer_model) if is_fsdp else transformer_model, + **peft_kwargs, + ) + + if is_fsdp: + transformer_lora_layers_to_save = _to_cpu_contiguous(transformer_lora_layers_to_save) + + # make sure to pop weight so that corresponding model is not saved again + if weights: + weights.pop() + + Flux2KleinPipeline.save_lora_weights( + output_dir, + transformer_lora_layers=transformer_lora_layers_to_save, + **_collate_lora_metadata(modules_to_save), + ) + + def load_model_hook(models, input_dir): + transformer_ = None + + if not is_fsdp: + while len(models) > 0: + model = models.pop() + + if isinstance(unwrap_model(model), type(unwrap_model(transformer))): + transformer_ = unwrap_model(model) + else: + raise ValueError(f"unexpected save model: {model.__class__}") + else: + transformer_ = Flux2Transformer2DModel.from_pretrained( + args.pretrained_model_name_or_path, + subfolder="transformer", + ) + transformer_.add_adapter(transformer_lora_config) + + lora_state_dict = Flux2KleinPipeline.lora_state_dict(input_dir) + + transformer_state_dict = { + f"{k.replace('transformer.', '')}": v for k, v in lora_state_dict.items() if k.startswith("transformer.") + } + transformer_state_dict = convert_unet_state_dict_to_peft(transformer_state_dict) + incompatible_keys = set_peft_model_state_dict(transformer_, transformer_state_dict, adapter_name="default") + if incompatible_keys is not None: + # check only for unexpected keys + unexpected_keys = getattr(incompatible_keys, "unexpected_keys", None) + if unexpected_keys: + logger.warning( + f"Loading adapter weights from state_dict led to unexpected keys not found in the model: " + f" {unexpected_keys}. " + ) + + # Make sure the trainable params are in float32. This is again needed since the base models + # are in `weight_dtype`. More details: + # https://github.com/huggingface/diffusers/pull/6514#discussion_r1449796804 + if args.mixed_precision == "fp16": + models = [transformer_] + # only upcast trainable parameters (LoRA) into fp32 + cast_training_params(models) + + accelerator.register_save_state_pre_hook(save_model_hook) + accelerator.register_load_state_pre_hook(load_model_hook) + + # Enable TF32 for faster training on Ampere GPUs, + # cf https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices + if args.allow_tf32 and torch.cuda.is_available(): + torch.backends.cuda.matmul.allow_tf32 = True + + if args.scale_lr: + args.learning_rate = ( + args.learning_rate * args.gradient_accumulation_steps * args.train_batch_size * accelerator.num_processes + ) + + # Make sure the trainable params are in float32. + if args.mixed_precision == "fp16": + models = [transformer] + # only upcast trainable parameters (LoRA) into fp32 + cast_training_params(models, dtype=torch.float32) + + transformer_lora_parameters = list(filter(lambda p: p.requires_grad, transformer.parameters())) + + # Optimization parameters + transformer_parameters_with_lr = {"params": transformer_lora_parameters, "lr": args.learning_rate} + params_to_optimize = [transformer_parameters_with_lr] + + # Optimizer creation + if not (args.optimizer.lower() == "prodigy" or args.optimizer.lower() == "adamw"): + logger.warning( + f"Unsupported choice of optimizer: {args.optimizer}.Supported optimizers include [adamW, prodigy]." + "Defaulting to adamW" + ) + args.optimizer = "adamw" + + if args.use_8bit_adam and not args.optimizer.lower() == "adamw": + logger.warning( + f"use_8bit_adam is ignored when optimizer is not set to 'AdamW'. Optimizer was " + f"set to {args.optimizer.lower()}" + ) + + if args.optimizer.lower() == "adamw": + if args.use_8bit_adam: + try: + import bitsandbytes as bnb + except ImportError: + raise ImportError( + "To use 8-bit Adam, please install the bitsandbytes library: `pip install bitsandbytes`." + ) + + optimizer_class = bnb.optim.AdamW8bit + else: + optimizer_class = torch.optim.AdamW + + optimizer = optimizer_class( + params_to_optimize, + betas=(args.adam_beta1, args.adam_beta2), + weight_decay=args.adam_weight_decay, + eps=args.adam_epsilon, + ) + + if args.optimizer.lower() == "prodigy": + try: + import prodigyopt + except ImportError: + raise ImportError("To use Prodigy, please install the prodigyopt library: `pip install prodigyopt`") + + optimizer_class = prodigyopt.Prodigy + + if args.learning_rate <= 0.1: + logger.warning( + "Learning rate is too low. When using prodigy, it's generally better to set learning rate around 1.0" + ) + + optimizer = optimizer_class( + params_to_optimize, + betas=(args.adam_beta1, args.adam_beta2), + beta3=args.prodigy_beta3, + weight_decay=args.adam_weight_decay, + eps=args.adam_epsilon, + decouple=args.prodigy_decouple, + use_bias_correction=args.prodigy_use_bias_correction, + safeguard_warmup=args.prodigy_safeguard_warmup, + ) + + if args.aspect_ratio_buckets is not None: + buckets = parse_buckets_string(args.aspect_ratio_buckets) + else: + buckets = [(args.resolution, args.resolution)] + logger.info(f"Using parsed aspect ratio buckets: {buckets}") + + # Dataset and DataLoaders creation: + train_dataset = DreamBoothDataset( + instance_data_root=args.instance_data_dir, + instance_prompt=args.instance_prompt, + size=args.resolution, + repeats=args.repeats, + center_crop=args.center_crop, + buckets=buckets, + ) + batch_sampler = BucketBatchSampler(train_dataset, batch_size=args.train_batch_size, drop_last=True) + train_dataloader = torch.utils.data.DataLoader( + train_dataset, + batch_sampler=batch_sampler, + collate_fn=lambda examples: collate_fn(examples), + num_workers=args.dataloader_num_workers, + ) + + def compute_text_embeddings(prompt, text_encoding_pipeline): + with torch.no_grad(): + prompt_embeds, text_ids = text_encoding_pipeline.encode_prompt( + prompt=prompt, max_sequence_length=args.max_sequence_length + ) + return prompt_embeds, text_ids + + + # If no type of tuning is done on the text_encoder and custom instance prompts are NOT + # provided (i.e. the --instance_prompt is used for all images), we encode the instance prompt once to avoid + # the redundant encoding. + if not train_dataset.custom_instance_prompts: + with offload_models(text_encoding_pipeline, device=accelerator.device, offload=args.offload): + instance_prompt_hidden_states, instance_text_ids = compute_text_embeddings( + args.instance_prompt, text_encoding_pipeline + ) + + if args.validation_prompt is not None: + validation_image = load_image(args.validation_image).convert("RGB") + validation_kwargs = {"image": validation_image} + with offload_models(text_encoding_pipeline, device=accelerator.device, offload=args.offload): + validation_kwargs["prompt_embeds"], _text_ids = compute_text_embeddings( + args.validation_prompt, text_encoding_pipeline + ) + validation_kwargs["negative_prompt_embeds"], _text_ids = compute_text_embeddings( + "", text_encoding_pipeline + ) + + # Init FSDP for text encoder + if args.fsdp_text_encoder: + fsdp_kwargs = get_fsdp_kwargs_from_accelerator(accelerator) + text_encoder_fsdp = wrap_with_fsdp( + model=text_encoding_pipeline.text_encoder, + device=accelerator.device, + offload=args.offload, + limit_all_gathers=True, + use_orig_params=True, + fsdp_kwargs=fsdp_kwargs, + ) + + text_encoding_pipeline.text_encoder = text_encoder_fsdp + dist.barrier() + + # If custom instance prompts are NOT provided (i.e. the instance prompt is used for all images), + # pack the statically computed variables appropriately here. This is so that we don't + # have to pass them to the dataloader. + if not train_dataset.custom_instance_prompts: + prompt_embeds = instance_prompt_hidden_states + text_ids = instance_text_ids + + # if cache_latents is set to True, we encode images to latents and store them. + # Similar to pre-encoding in the case of a single instance prompt, if custom prompts are provided + # we encode them in advance as well. + precompute_latents = args.cache_latents or train_dataset.custom_instance_prompts + if precompute_latents: + prompt_embeds_cache = [] + text_ids_cache = [] + latents_cache = [] + cond_latents_cache = [] + for batch in tqdm(train_dataloader, desc="Caching latents"): + with torch.no_grad(): + if args.cache_latents: + with offload_models(vae, device=accelerator.device, offload=args.offload): + batch["pixel_values"] = batch["pixel_values"].to( + accelerator.device, non_blocking=True, dtype=vae.dtype + ) + latents_cache.append(vae.encode(batch["pixel_values"]).latent_dist) + batch["cond_pixel_values"] = batch["cond_pixel_values"].to( + accelerator.device, non_blocking=True, dtype=vae.dtype + ) + cond_latents_cache.append(vae.encode(batch["cond_pixel_values"]).latent_dist) + if train_dataset.custom_instance_prompts: + if args.fsdp_text_encoder: + prompt_embeds, text_ids = compute_text_embeddings(batch["prompts"], text_encoding_pipeline) + else: + with offload_models(text_encoding_pipeline, device=accelerator.device, offload=args.offload): + prompt_embeds, text_ids = compute_text_embeddings(batch["prompts"], text_encoding_pipeline) + prompt_embeds_cache.append(prompt_embeds) + text_ids_cache.append(text_ids) + + # move back to cpu before deleting to ensure memory is freed see: https://github.com/huggingface/diffusers/issues/11376#issue-3008144624 + if args.cache_latents: + vae = vae.to("cpu") + del vae + + # move back to cpu before deleting to ensure memory is freed see: https://github.com/huggingface/diffusers/issues/11376#issue-3008144624 + text_encoding_pipeline = text_encoding_pipeline.to("cpu") + del text_encoder, tokenizer + free_memory() + + # Scheduler and math around the number of training steps. + # Check the PR https://github.com/huggingface/diffusers/pull/8312 for detailed explanation. + num_warmup_steps_for_scheduler = args.lr_warmup_steps * accelerator.num_processes + if args.max_train_steps is None: + len_train_dataloader_after_sharding = math.ceil(len(train_dataloader) / accelerator.num_processes) + num_update_steps_per_epoch = math.ceil(len_train_dataloader_after_sharding / args.gradient_accumulation_steps) + num_training_steps_for_scheduler = ( + args.num_train_epochs * accelerator.num_processes * num_update_steps_per_epoch + ) + else: + num_training_steps_for_scheduler = args.max_train_steps * accelerator.num_processes + + lr_scheduler = get_scheduler( + args.lr_scheduler, + optimizer=optimizer, + num_warmup_steps=num_warmup_steps_for_scheduler, + num_training_steps=num_training_steps_for_scheduler, + num_cycles=args.lr_num_cycles, + power=args.lr_power, + ) + + # Prepare everything with our `accelerator`. + transformer, optimizer, train_dataloader, lr_scheduler = accelerator.prepare( + transformer, optimizer, train_dataloader, lr_scheduler + ) + + # We need to recalculate our total training steps as the size of the training dataloader may have changed. + num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps) + if args.max_train_steps is None: + args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch + if num_training_steps_for_scheduler != args.max_train_steps: + logger.warning( + f"The length of the 'train_dataloader' after 'accelerator.prepare' ({len(train_dataloader)}) does not match " + f"the expected length ({len_train_dataloader_after_sharding}) when the learning rate scheduler was created. " + f"This inconsistency may result in the learning rate scheduler not functioning properly." + ) + # Afterwards we recalculate our number of training epochs + args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch) + + # We need to initialize the trackers we use, and also store our configuration. + # The trackers initializes automatically on the main process. + if accelerator.is_main_process: + tracker_name = "dreambooth-flux2-image2img-lora" + accelerator.init_trackers(tracker_name, config=vars(args)) + + # Train! + total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps + + logger.info("***** Running training *****") + logger.info(f" Num examples = {len(train_dataset)}") + logger.info(f" Num batches each epoch = {len(train_dataloader)}") + logger.info(f" Num Epochs = {args.num_train_epochs}") + logger.info(f" Instantaneous batch size per device = {args.train_batch_size}") + logger.info(f" Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}") + logger.info(f" Gradient Accumulation steps = {args.gradient_accumulation_steps}") + logger.info(f" Total optimization steps = {args.max_train_steps}") + global_step = 0 + first_epoch = 0 + + # Potentially load in the weights and states from a previous save + if args.resume_from_checkpoint: + if args.resume_from_checkpoint != "latest": + path = os.path.basename(args.resume_from_checkpoint) + else: + # Get the mos recent checkpoint + dirs = os.listdir(args.output_dir) + dirs = [d for d in dirs if d.startswith("checkpoint")] + dirs = sorted(dirs, key=lambda x: int(x.split("-")[1])) + path = dirs[-1] if len(dirs) > 0 else None + + if path is None: + accelerator.print( + f"Checkpoint '{args.resume_from_checkpoint}' does not exist. Starting a new training run." + ) + args.resume_from_checkpoint = None + initial_global_step = 0 + else: + accelerator.print(f"Resuming from checkpoint {path}") + accelerator.load_state(os.path.join(args.output_dir, path)) + global_step = int(path.split("-")[1]) + + initial_global_step = global_step + first_epoch = global_step // num_update_steps_per_epoch + + else: + initial_global_step = 0 + + progress_bar = tqdm( + range(0, args.max_train_steps), + initial=initial_global_step, + desc="Steps", + # Only show the progress bar once on each machine. + disable=not accelerator.is_local_main_process, + ) + + def get_sigmas(timesteps, n_dim=4, dtype=torch.float32): + sigmas = noise_scheduler_copy.sigmas.to(device=accelerator.device, dtype=dtype) + schedule_timesteps = noise_scheduler_copy.timesteps.to(accelerator.device) + timesteps = timesteps.to(accelerator.device) + step_indices = [(schedule_timesteps == t).nonzero().item() for t in timesteps] + + sigma = sigmas[step_indices].flatten() + while len(sigma.shape) < n_dim: + sigma = sigma.unsqueeze(-1) + return sigma + + for epoch in range(first_epoch, args.num_train_epochs): + transformer.train() + + for step, batch in enumerate(train_dataloader): + models_to_accumulate = [transformer] + prompts = batch["prompts"] + + with accelerator.accumulate(models_to_accumulate): + if train_dataset.custom_instance_prompts: + prompt_embeds = prompt_embeds_cache[step] + text_ids = text_ids_cache[step] + else: + num_repeat_elements = len(prompts) + prompt_embeds = prompt_embeds.repeat(num_repeat_elements, 1, 1) + text_ids = text_ids.repeat(num_repeat_elements, 1, 1) + + # Convert images to latent space + if args.cache_latents: + model_input = latents_cache[step].mode() + cond_model_input = cond_latents_cache[step].mode() + else: + with offload_models(vae, device=accelerator.device, offload=args.offload): + pixel_values = batch["pixel_values"].to(dtype=vae.dtype) + cond_pixel_values = batch["cond_pixel_values"].to(dtype=vae.dtype) + + model_input = vae.encode(pixel_values).latent_dist.mode() + cond_model_input = vae.encode(cond_pixel_values).latent_dist.mode() + + model_input = Flux2KleinPipeline._patchify_latents(model_input) + model_input = (model_input - latents_bn_mean) / latents_bn_std + + cond_model_input = Flux2KleinPipeline._patchify_latents(cond_model_input) + cond_model_input = (cond_model_input - latents_bn_mean) / latents_bn_std + + model_input_ids = Flux2KleinPipeline._prepare_latent_ids(model_input).to(device=model_input.device) + cond_model_input_list = [cond_model_input[i].unsqueeze(0) for i in range(cond_model_input.shape[0])] + cond_model_input_ids = Flux2KleinPipeline._prepare_image_ids(cond_model_input_list).to( + device=cond_model_input.device + ) + cond_model_input_ids = cond_model_input_ids.view( + cond_model_input.shape[0], -1, model_input_ids.shape[-1] + ) + + # Sample noise that we'll add to the latents + noise = torch.randn_like(model_input) + bsz = model_input.shape[0] + + # Sample a random timestep for each image + # for weighting schemes where we sample timesteps non-uniformly + u = compute_density_for_timestep_sampling( + weighting_scheme=args.weighting_scheme, + batch_size=bsz, + logit_mean=args.logit_mean, + logit_std=args.logit_std, + mode_scale=args.mode_scale, + ) + indices = (u * noise_scheduler_copy.config.num_train_timesteps).long() + timesteps = noise_scheduler_copy.timesteps[indices].to(device=model_input.device) + + # Add noise according to flow matching. + # zt = (1 - texp) * x + texp * z1 + sigmas = get_sigmas(timesteps, n_dim=model_input.ndim, dtype=model_input.dtype) + noisy_model_input = (1.0 - sigmas) * model_input + sigmas * noise + + # [B, C, H, W] -> [B, H*W, C] + # concatenate the model inputs with the cond inputs + packed_noisy_model_input = Flux2KleinPipeline._pack_latents(noisy_model_input) + packed_cond_model_input = Flux2KleinPipeline._pack_latents(cond_model_input) + orig_input_shape = packed_noisy_model_input.shape + orig_input_ids_shape = model_input_ids.shape + + # concatenate the model inputs with the cond inputs + packed_noisy_model_input = torch.cat([packed_noisy_model_input, packed_cond_model_input], dim=1) + model_input_ids = torch.cat([model_input_ids, cond_model_input_ids], dim=1) + + # handle guidance + if transformer.config.guidance_embeds: + guidance = torch.full([1], args.guidance_scale, device=accelerator.device) + guidance = guidance.expand(model_input.shape[0]) + else: + guidance=None + + # Predict the noise residual + model_pred = transformer( + hidden_states=packed_noisy_model_input, # (B, image_seq_len, C) + timestep=timesteps / 1000, + guidance=guidance, + encoder_hidden_states=prompt_embeds, + txt_ids=text_ids, # B, text_seq_len, 4 + img_ids=model_input_ids, # B, image_seq_len, 4 + return_dict=False, + )[0] + # pruning the condition information + model_pred = model_pred[:, : orig_input_shape[1], :] + model_input_ids = model_input_ids[:, : orig_input_ids_shape[1], :] + + model_pred = Flux2KleinPipeline._unpack_latents_with_ids(model_pred, model_input_ids) + + # these weighting schemes use a uniform timestep sampling + # and instead post-weight the loss + weighting = compute_loss_weighting_for_sd3(weighting_scheme=args.weighting_scheme, sigmas=sigmas) + + # flow matching loss + target = noise - model_input + + # Compute regular loss. + loss = torch.mean( + (weighting.float() * (model_pred.float() - target.float()) ** 2).reshape(target.shape[0], -1), + 1, + ) + loss = loss.mean() + + accelerator.backward(loss) + if accelerator.sync_gradients: + params_to_clip = transformer.parameters() + accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm) + + optimizer.step() + lr_scheduler.step() + optimizer.zero_grad() + + # Checks if the accelerator has performed an optimization step behind the scenes + if accelerator.sync_gradients: + progress_bar.update(1) + global_step += 1 + + if accelerator.is_main_process or is_fsdp: + if global_step % args.checkpointing_steps == 0: + # _before_ saving state, check if this save would set us over the `checkpoints_total_limit` + if args.checkpoints_total_limit is not None: + checkpoints = os.listdir(args.output_dir) + checkpoints = [d for d in checkpoints if d.startswith("checkpoint")] + checkpoints = sorted(checkpoints, key=lambda x: int(x.split("-")[1])) + + # before we save the new checkpoint, we need to have at _most_ `checkpoints_total_limit - 1` checkpoints + if len(checkpoints) >= args.checkpoints_total_limit: + num_to_remove = len(checkpoints) - args.checkpoints_total_limit + 1 + removing_checkpoints = checkpoints[0:num_to_remove] + + logger.info( + f"{len(checkpoints)} checkpoints already exist, removing {len(removing_checkpoints)} checkpoints" + ) + logger.info(f"removing checkpoints: {', '.join(removing_checkpoints)}") + + for removing_checkpoint in removing_checkpoints: + removing_checkpoint = os.path.join(args.output_dir, removing_checkpoint) + shutil.rmtree(removing_checkpoint) + + save_path = os.path.join(args.output_dir, f"checkpoint-{global_step}") + accelerator.save_state(save_path) + logger.info(f"Saved state to {save_path}") + + logs = {"loss": loss.detach().item(), "lr": lr_scheduler.get_last_lr()[0]} + progress_bar.set_postfix(**logs) + accelerator.log(logs, step=global_step) + + if global_step >= args.max_train_steps: + break + + if accelerator.is_main_process: + if args.validation_prompt is not None and epoch % args.validation_epochs == 0: + # create pipeline + pipeline = Flux2KleinPipeline.from_pretrained( + args.pretrained_model_name_or_path, + text_encoder=None, + tokenizer=None, + transformer=unwrap_model(transformer), + revision=args.revision, + variant=args.variant, + torch_dtype=weight_dtype, + ) + images = log_validation( + pipeline=pipeline, + args=args, + accelerator=accelerator, + pipeline_args=validation_kwargs, + epoch=epoch, + torch_dtype=weight_dtype, + ) + + del pipeline + free_memory() + + # Save the lora layers + accelerator.wait_for_everyone() + + if is_fsdp: + transformer = unwrap_model(transformer) + state_dict = accelerator.get_state_dict(transformer) + if accelerator.is_main_process: + modules_to_save = {} + if is_fsdp: + if args.bnb_quantization_config_path is None: + if args.upcast_before_saving: + state_dict = { + k: v.to(torch.float32) if isinstance(v, torch.Tensor) else v for k, v in state_dict.items() + } + else: + state_dict = { + k: v.to(weight_dtype) if isinstance(v, torch.Tensor) else v for k, v in state_dict.items() + } + + transformer_lora_layers = get_peft_model_state_dict( + transformer, + state_dict=state_dict, + ) + transformer_lora_layers = { + k: v.detach().cpu().contiguous() if isinstance(v, torch.Tensor) else v + for k, v in transformer_lora_layers.items() + } + + else: + transformer = unwrap_model(transformer) + if args.bnb_quantization_config_path is None: + if args.upcast_before_saving: + transformer.to(torch.float32) + else: + transformer = transformer.to(weight_dtype) + transformer_lora_layers = get_peft_model_state_dict(transformer) + + modules_to_save["transformer"] = transformer + + Flux2KleinPipeline.save_lora_weights( + save_directory=args.output_dir, + transformer_lora_layers=transformer_lora_layers, + **_collate_lora_metadata(modules_to_save), + ) + + images = [] + run_validation = (args.validation_prompt and args.num_validation_images > 0) or (args.final_validation_prompt) + should_run_final_inference = not args.skip_final_inference and run_validation + if should_run_final_inference: + pipeline = Flux2KleinPipeline.from_pretrained( + args.pretrained_model_name_or_path, + revision=args.revision, + variant=args.variant, + torch_dtype=weight_dtype, + ) + # load attention processors + pipeline.load_lora_weights(args.output_dir) + + # run inference + images = [] + if args.validation_prompt and args.num_validation_images > 0: + images = log_validation( + pipeline=pipeline, + args=args, + accelerator=accelerator, + pipeline_args=validation_kwargs, + epoch=epoch, + is_final_validation=True, + torch_dtype=weight_dtype, + ) + del pipeline + free_memory() + + validation_prompt = args.validation_prompt if args.validation_prompt else args.final_validation_prompt + save_model_card( + (args.hub_model_id or Path(args.output_dir).name) if not args.push_to_hub else repo_id, + images=images, + base_model=args.pretrained_model_name_or_path, + instance_prompt=args.instance_prompt, + validation_prompt=validation_prompt, + repo_folder=args.output_dir, + fp8_training=args.do_fp8_training, + ) + + if args.push_to_hub: + upload_folder( + repo_id=repo_id, + folder_path=args.output_dir, + commit_message="End of training", + ignore_patterns=["step_*", "epoch_*"], + ) + + accelerator.end_training() + + +if __name__ == "__main__": + args = parse_args() + main(args) From 292e9c4f3c61f378be4b0b94e18fe45c2d4d2cf4 Mon Sep 17 00:00:00 2001 From: yiyixuxu Date: Thu, 15 Jan 2026 17:18:05 +0100 Subject: [PATCH 12/15] style --- .../train_dreambooth_lora_flux2_klein.py | 22 ++++++++---------- ...ain_dreambooth_lora_flux2_klein_img2img.py | 23 +++++++++---------- 2 files changed, 21 insertions(+), 24 deletions(-) diff --git a/examples/dreambooth/train_dreambooth_lora_flux2_klein.py b/examples/dreambooth/train_dreambooth_lora_flux2_klein.py index 0141d61d8594..278c25900a3a 100644 --- a/examples/dreambooth/train_dreambooth_lora_flux2_klein.py +++ b/examples/dreambooth/train_dreambooth_lora_flux2_klein.py @@ -63,7 +63,7 @@ from torchvision import transforms from torchvision.transforms import functional as TF from tqdm.auto import tqdm -from transformers import Qwen3ForCausalLM, Qwen2TokenizerFast +from transformers import Qwen2TokenizerFast, Qwen3ForCausalLM import diffusers from diffusers import ( @@ -1231,7 +1231,6 @@ def main(args): transformer, module_filter_fn=module_filter_fn, config=Float8LinearConfig(pad_inner_dim=True) ) - text_encoder.to(**to_kwargs) # Initialize a text encoding pipeline and keep it to CPU for now. text_encoding_pipeline = Flux2KleinPipeline.from_pretrained( @@ -1473,28 +1472,27 @@ def compute_text_embeddings(prompt, text_encoding_pipeline): ) return prompt_embeds, text_ids - # If no type of tuning is done on the text_encoder and custom instance prompts are NOT # provided (i.e. the --instance_prompt is used for all images), we encode the instance prompt once to avoid # the redundant encoding. if not train_dataset.custom_instance_prompts: with offload_models(text_encoding_pipeline, device=accelerator.device, offload=args.offload): - instance_prompt_hidden_states, instance_text_ids = compute_text_embeddings( - args.instance_prompt, text_encoding_pipeline - ) + instance_prompt_hidden_states, instance_text_ids = compute_text_embeddings( + args.instance_prompt, text_encoding_pipeline + ) # Handle class prompt for prior-preservation. if args.with_prior_preservation: with offload_models(text_encoding_pipeline, device=accelerator.device, offload=args.offload): - class_prompt_hidden_states, class_text_ids = compute_text_embeddings( - args.class_prompt, text_encoding_pipeline - ) + class_prompt_hidden_states, class_text_ids = compute_text_embeddings( + args.class_prompt, text_encoding_pipeline + ) validation_embeddings = {} if args.validation_prompt is not None: with offload_models(text_encoding_pipeline, device=accelerator.device, offload=args.offload): - (validation_embeddings["prompt_embeds"], validation_embeddings["text_ids"]) = compute_text_embeddings( - args.validation_prompt, text_encoding_pipeline - ) + (validation_embeddings["prompt_embeds"], validation_embeddings["text_ids"]) = compute_text_embeddings( + args.validation_prompt, text_encoding_pipeline + ) # Init FSDP for text encoder if args.fsdp_text_encoder: diff --git a/examples/dreambooth/train_dreambooth_lora_flux2_klein_img2img.py b/examples/dreambooth/train_dreambooth_lora_flux2_klein_img2img.py index 7ebfe4d43ff7..28cbaf8f72e7 100644 --- a/examples/dreambooth/train_dreambooth_lora_flux2_klein_img2img.py +++ b/examples/dreambooth/train_dreambooth_lora_flux2_klein_img2img.py @@ -61,7 +61,7 @@ from torchvision import transforms from torchvision.transforms import functional as TF from tqdm.auto import tqdm -from transformers import Qwen3ForCausalLM, Qwen2TokenizerFast +from transformers import Qwen2TokenizerFast, Qwen3ForCausalLM import diffusers from diffusers import ( @@ -1418,26 +1418,25 @@ def compute_text_embeddings(prompt, text_encoding_pipeline): ) return prompt_embeds, text_ids - # If no type of tuning is done on the text_encoder and custom instance prompts are NOT # provided (i.e. the --instance_prompt is used for all images), we encode the instance prompt once to avoid # the redundant encoding. if not train_dataset.custom_instance_prompts: with offload_models(text_encoding_pipeline, device=accelerator.device, offload=args.offload): - instance_prompt_hidden_states, instance_text_ids = compute_text_embeddings( - args.instance_prompt, text_encoding_pipeline - ) + instance_prompt_hidden_states, instance_text_ids = compute_text_embeddings( + args.instance_prompt, text_encoding_pipeline + ) if args.validation_prompt is not None: validation_image = load_image(args.validation_image).convert("RGB") validation_kwargs = {"image": validation_image} with offload_models(text_encoding_pipeline, device=accelerator.device, offload=args.offload): - validation_kwargs["prompt_embeds"], _text_ids = compute_text_embeddings( - args.validation_prompt, text_encoding_pipeline - ) - validation_kwargs["negative_prompt_embeds"], _text_ids = compute_text_embeddings( - "", text_encoding_pipeline - ) + validation_kwargs["prompt_embeds"], _text_ids = compute_text_embeddings( + args.validation_prompt, text_encoding_pipeline + ) + validation_kwargs["negative_prompt_embeds"], _text_ids = compute_text_embeddings( + "", text_encoding_pipeline + ) # Init FSDP for text encoder if args.fsdp_text_encoder: @@ -1687,7 +1686,7 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32): guidance = torch.full([1], args.guidance_scale, device=accelerator.device) guidance = guidance.expand(model_input.shape[0]) else: - guidance=None + guidance = None # Predict the noise residual model_pred = transformer( From e738c478bf4210c0216ebda59bf41df487b4655a Mon Sep 17 00:00:00 2001 From: YiYi Xu Date: Thu, 15 Jan 2026 07:08:44 -1000 Subject: [PATCH 13/15] Update src/diffusers/pipelines/flux2/pipeline_flux2_klein.py --- src/diffusers/pipelines/flux2/pipeline_flux2_klein.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/diffusers/pipelines/flux2/pipeline_flux2_klein.py b/src/diffusers/pipelines/flux2/pipeline_flux2_klein.py index 59ae687dd06c..391b1f9fc340 100644 --- a/src/diffusers/pipelines/flux2/pipeline_flux2_klein.py +++ b/src/diffusers/pipelines/flux2/pipeline_flux2_klein.py @@ -46,7 +46,7 @@ >>> import torch >>> from diffusers import Flux2KleinPipeline - >>> pipe = Flux2KleinPipeline.from_pretrained("TODO", torch_dtype=torch.bfloat16) + >>> pipe = Flux2KleinPipeline.from_pretrained("black-forest-labs/FLUX.2-klein-base-9B", torch_dtype=torch.bfloat16) >>> pipe.to("cuda") >>> prompt = "A cat holding a sign that says hello world" >>> # Depending on the variant being used, the pipeline call will slightly vary. From 49df6a6683189eb70ccb7dcc9e21c9a1be55f08f Mon Sep 17 00:00:00 2001 From: "github-actions[bot]" Date: Thu, 15 Jan 2026 18:01:32 +0000 Subject: [PATCH 14/15] Apply style fixes --- src/diffusers/pipelines/flux2/pipeline_flux2_klein.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/src/diffusers/pipelines/flux2/pipeline_flux2_klein.py b/src/diffusers/pipelines/flux2/pipeline_flux2_klein.py index 391b1f9fc340..761d6bf2e94c 100644 --- a/src/diffusers/pipelines/flux2/pipeline_flux2_klein.py +++ b/src/diffusers/pipelines/flux2/pipeline_flux2_klein.py @@ -46,7 +46,9 @@ >>> import torch >>> from diffusers import Flux2KleinPipeline - >>> pipe = Flux2KleinPipeline.from_pretrained("black-forest-labs/FLUX.2-klein-base-9B", torch_dtype=torch.bfloat16) + >>> pipe = Flux2KleinPipeline.from_pretrained( + ... "black-forest-labs/FLUX.2-klein-base-9B", torch_dtype=torch.bfloat16 + ... ) >>> pipe.to("cuda") >>> prompt = "A cat holding a sign that says hello world" >>> # Depending on the variant being used, the pipeline call will slightly vary. From ea7736e2caac3fc2179dac2a02c37e562003b6fc Mon Sep 17 00:00:00 2001 From: yiyixuxu Date: Thu, 15 Jan 2026 19:10:50 +0100 Subject: [PATCH 15/15] auto pipeline --- scripts/convert_flux2_to_diffusers.py | 9 ++++++++- src/diffusers/pipelines/auto_pipeline.py | 5 +++++ 2 files changed, 13 insertions(+), 1 deletion(-) diff --git a/scripts/convert_flux2_to_diffusers.py b/scripts/convert_flux2_to_diffusers.py index d057fa491d68..a8fa6f87eee1 100644 --- a/scripts/convert_flux2_to_diffusers.py +++ b/scripts/convert_flux2_to_diffusers.py @@ -519,8 +519,15 @@ def main(args): "black-forest-labs/FLUX.1-dev", subfolder="scheduler" ) + if_distilled = "base" not in args.dit_filename + pipe = Flux2Pipeline( - vae=vae, transformer=transformer, text_encoder=text_encoder, tokenizer=tokenizer, scheduler=scheduler + vae=vae, + transformer=transformer, + text_encoder=text_encoder, + tokenizer=tokenizer, + scheduler=scheduler, + if_distilled=if_distilled, ) pipe.save_pretrained(args.output_path) diff --git a/src/diffusers/pipelines/auto_pipeline.py b/src/diffusers/pipelines/auto_pipeline.py index 8a21f91d4e79..5ee44190e23b 100644 --- a/src/diffusers/pipelines/auto_pipeline.py +++ b/src/diffusers/pipelines/auto_pipeline.py @@ -52,6 +52,7 @@ FluxKontextPipeline, FluxPipeline, ) +from .flux2 import Flux2KleinPipeline, Flux2Pipeline from .glm_image import GlmImagePipeline from .hunyuandit import HunyuanDiTPipeline from .kandinsky import ( @@ -164,6 +165,8 @@ ("flux-control", FluxControlPipeline), ("flux-controlnet", FluxControlNetPipeline), ("flux-kontext", FluxKontextPipeline), + ("flux2-klein", Flux2KleinPipeline), + ("flux2", Flux2Pipeline), ("lumina", LuminaPipeline), ("lumina2", Lumina2Pipeline), ("chroma", ChromaPipeline), @@ -202,6 +205,8 @@ ("flux-controlnet", FluxControlNetImg2ImgPipeline), ("flux-control", FluxControlImg2ImgPipeline), ("flux-kontext", FluxKontextPipeline), + ("flux2-klein", Flux2KleinPipeline), + ("flux2", Flux2Pipeline), ("qwenimage", QwenImageImg2ImgPipeline), ("qwenimage-edit", QwenImageEditPipeline), ("qwenimage-edit-plus", QwenImageEditPlusPipeline),