From f6804da86a680020799ac5e4d1a9fb3effa8a273 Mon Sep 17 00:00:00 2001 From: AliesTaha Date: Wed, 18 Feb 2026 11:56:12 -0800 Subject: [PATCH 1/5] POTQ with QWEN Image Asymmetric Signed-off-by: AliesTaha --- .gitignore | 3 +++ examples/diffusers/quantization/config.py | 22 +++++++++++++++++++ .../diffusers/quantization/models_utils.py | 15 +++++++++++++ examples/diffusers/quantization/quantize.py | 22 ++++++++++++------- examples/diffusers/quantization/utils.py | 8 +++++++ 5 files changed, 62 insertions(+), 8 deletions(-) diff --git a/.gitignore b/.gitignore index ff350799d..44bf5ef48 100644 --- a/.gitignore +++ b/.gitignore @@ -59,3 +59,6 @@ venv/ **.pickle **.tar.gz **.nemo + +# Ignore experiment run +examples/diffusers/quantization/experiment_run \ No newline at end of file diff --git a/examples/diffusers/quantization/config.py b/examples/diffusers/quantization/config.py index d8d8b198b..2090d7eb2 100644 --- a/examples/diffusers/quantization/config.py +++ b/examples/diffusers/quantization/config.py @@ -64,6 +64,28 @@ "algorithm": "max", } +NVFP4_ASYMMETRIC_CONFIG = { + "quant_cfg": { + "*weight_quantizer": { + "num_bits": (2, 1), + "block_sizes": {-1: 16, "type": "dynamic", "scale_bits": (4, 3)}, + "axis": None, + "enable": True, + "bias": {-1: None, "type": "static", "method": "mean"}, + }, + "*input_quantizer": { + "num_bits": (2, 1), + "block_sizes": {-1: 16, "type": "dynamic", "scale_bits": (4, 3)}, + "axis": None, + "enable": True, + "bias": {-1: None, "type": "dynamic", "method": "mean"}, #bias must be dynamic + }, + "*output_quantizer": {"enable": False}, + "default": {"enable": False}, + }, + "algorithm": "max", +} + NVFP4_FP8_MHA_CONFIG = { "quant_cfg": { "**weight_quantizer": { diff --git a/examples/diffusers/quantization/models_utils.py b/examples/diffusers/quantization/models_utils.py index 9a061622e..8aa0ba6a8 100644 --- a/examples/diffusers/quantization/models_utils.py +++ b/examples/diffusers/quantization/models_utils.py @@ -29,6 +29,7 @@ filter_func_default, filter_func_flux_dev, filter_func_ltx_video, + filter_func_qwen_image, filter_func_wan_video, ) @@ -46,6 +47,7 @@ class ModelType(str, Enum): LTX2 = "ltx-2" WAN22_T2V_14b = "wan2.2-t2v-14b" WAN22_T2V_5b = "wan2.2-t2v-5b" + QWEN_IMAGE_2512 = "qwen-image-2512" def get_model_filter_func(model_type: ModelType) -> Callable[[str], bool]: @@ -69,6 +71,7 @@ def get_model_filter_func(model_type: ModelType) -> Callable[[str], bool]: ModelType.LTX2: filter_func_ltx_video, ModelType.WAN22_T2V_14b: filter_func_wan_video, ModelType.WAN22_T2V_5b: filter_func_wan_video, + ModelType.QWEN_IMAGE_2512: filter_func_qwen_image, } return filter_func_map.get(model_type, filter_func_default) @@ -86,6 +89,7 @@ def get_model_filter_func(model_type: ModelType) -> Callable[[str], bool]: ModelType.LTX2: "Lightricks/LTX-2", ModelType.WAN22_T2V_14b: "Wan-AI/Wan2.2-T2V-A14B-Diffusers", ModelType.WAN22_T2V_5b: "Wan-AI/Wan2.2-TI2V-5B-Diffusers", + ModelType.QWEN_IMAGE_2512: "Qwen/Qwen-Image-2512", } MODEL_PIPELINE: dict[ModelType, type[DiffusionPipeline] | None] = { @@ -99,6 +103,7 @@ def get_model_filter_func(model_type: ModelType) -> Callable[[str], bool]: ModelType.LTX2: None, ModelType.WAN22_T2V_14b: WanPipeline, ModelType.WAN22_T2V_5b: WanPipeline, + ModelType.QWEN_IMAGE_2512: DiffusionPipeline, } # Shared dataset configurations @@ -193,6 +198,16 @@ def get_model_filter_func(model_type: ModelType) -> Callable[[str], bool]: ), }, }, + ModelType.QWEN_IMAGE_2512: { + "backbone": "transformer", + "dataset": _SD_PROMPTS_DATASET, + "inference_extra_args": { + "height": 1024, + "width": 1024, + "guidance_scale": 4.0, + "negative_prompt": "低分辨率,低画质,肢体畸形,手指畸形,画面过饱和,蜡像感,人脸无细节,过度光滑,画面具有AI感。构图混乱。文字模糊,扭曲。", + }, + }, ModelType.WAN22_T2V_5b: { **_WAN_BASE_CONFIG, "inference_extra_args": { diff --git a/examples/diffusers/quantization/quantize.py b/examples/diffusers/quantization/quantize.py index bfff207af..3396d5d3e 100644 --- a/examples/diffusers/quantization/quantize.py +++ b/examples/diffusers/quantization/quantize.py @@ -26,13 +26,14 @@ FP8_DEFAULT_CONFIG, INT8_DEFAULT_CONFIG, NVFP4_DEFAULT_CONFIG, + NVFP4_ASYMMETRIC_CONFIG, NVFP4_FP8_MHA_CONFIG, reset_set_int8_config, set_quant_config_attr, ) from diffusers import DiffusionPipeline from models_utils import MODEL_DEFAULTS, ModelType, get_model_filter_func, parse_extra_params -from onnx_utils.export import generate_fp8_scales, modelopt_export_sd +# from onnx_utils.export import generate_fp8_scales, modelopt_export_sd from pipeline_manager import PipelineManager from quantize_config import ( CalibrationConfig, @@ -133,7 +134,7 @@ def get_quant_config(self, n_steps: int, backbone: torch.nn.Module) -> Any: if self.model_config.model_type.value.startswith("flux"): quant_config = NVFP4_FP8_MHA_CONFIG else: - quant_config = NVFP4_DEFAULT_CONFIG + quant_config = NVFP4_ASYMMETRIC_CONFIG else: raise NotImplementedError(f"Unknown format {self.config.format}") if self.config.quantize_mha: @@ -228,8 +229,12 @@ def save_checkpoint(self, backbone: torch.nn.Module) -> None: return ckpt_path = self.config.quantized_torch_ckpt_path - ckpt_path.mkdir(parents=True, exist_ok=True) - target_path = ckpt_path / "backbone.pt" + if ckpt_path.suffix == ".pt": + target_path = ckpt_path + target_path.parent.mkdir(parents=True, exist_ok=True) + else: + ckpt_path.mkdir(parents=True, exist_ok=True) + target_path = ckpt_path / "backbone.pt" self.logger.info(f"Saving backbone to {target_path}") mto.save(backbone, str(target_path)) @@ -260,7 +265,8 @@ def export_onnx( self.logger.info( "Detected quantizing conv layers in backbone. Generating FP8 scales..." ) - generate_fp8_scales(backbone) + # TODO: needs a fix, commenting out for now + # generate_fp8_scales(backbone) self.logger.info("Preparing models for export...") pipe.to("cpu") torch.cuda.empty_cache() @@ -269,9 +275,9 @@ def export_onnx( backbone.eval() with torch.no_grad(): self.logger.info("Exporting to ONNX...") - modelopt_export_sd( - backbone, str(self.config.onnx_dir), model_type.value, quant_format.value - ) + # modelopt_export_sd( + # backbone, str(self.config.onnx_dir), model_type.value, quant_format.value + # ) self.logger.info("ONNX export completed successfully") diff --git a/examples/diffusers/quantization/utils.py b/examples/diffusers/quantization/utils.py index 21fcd87d0..9ae61ca1c 100644 --- a/examples/diffusers/quantization/utils.py +++ b/examples/diffusers/quantization/utils.py @@ -44,6 +44,14 @@ def filter_func_default(name: str) -> bool: return pattern.match(name) is not None +def filter_func_qwen_image(name: str) -> bool: + """Qwen-Image filter: disable only the 6 standalone layers outside transformer blocks.""" + pattern = re.compile( + r".*(time_text_embed|img_in|txt_in|norm_out|proj_out).*" + ) + return pattern.match(name) is not None + + def check_conv_and_mha(backbone, if_fp4, quantize_mha): for name, module in backbone.named_modules(): if isinstance(module, (torch.nn.Conv1d, torch.nn.Conv2d, torch.nn.Conv3d)) and if_fp4: From b019cf524abba6578a9673dc36a03bd152c7f75b Mon Sep 17 00:00:00 2001 From: AliesTaha Date: Wed, 18 Feb 2026 12:08:53 -0800 Subject: [PATCH 2/5] Address CodeRabbit review: fix linter suppress and docstring Signed-off-by: AliesTaha Co-authored-by: Cursor --- examples/diffusers/quantization/models_utils.py | 2 +- examples/diffusers/quantization/utils.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/examples/diffusers/quantization/models_utils.py b/examples/diffusers/quantization/models_utils.py index 8aa0ba6a8..1b3a93050 100644 --- a/examples/diffusers/quantization/models_utils.py +++ b/examples/diffusers/quantization/models_utils.py @@ -205,7 +205,7 @@ def get_model_filter_func(model_type: ModelType) -> Callable[[str], bool]: "height": 1024, "width": 1024, "guidance_scale": 4.0, - "negative_prompt": "低分辨率,低画质,肢体畸形,手指畸形,画面过饱和,蜡像感,人脸无细节,过度光滑,画面具有AI感。构图混乱。文字模糊,扭曲。", + "negative_prompt": "低分辨率,低画质,肢体畸形,手指畸形,画面过饱和,蜡像感,人脸无细节,过度光滑,画面具有AI感。构图混乱。文字模糊,扭曲。", # noqa: RUF001 }, }, ModelType.WAN22_T2V_5b: { diff --git a/examples/diffusers/quantization/utils.py b/examples/diffusers/quantization/utils.py index 9ae61ca1c..e2f6a3a3b 100644 --- a/examples/diffusers/quantization/utils.py +++ b/examples/diffusers/quantization/utils.py @@ -45,7 +45,7 @@ def filter_func_default(name: str) -> bool: def filter_func_qwen_image(name: str) -> bool: - """Qwen-Image filter: disable only the 6 standalone layers outside transformer blocks.""" + """Qwen-Image filter: disable only the 5 standalone layers outside transformer blocks.""" pattern = re.compile( r".*(time_text_embed|img_in|txt_in|norm_out|proj_out).*" ) From 4b25d2ba2f1ec43728a7fb20048da14a7fb21d39 Mon Sep 17 00:00:00 2001 From: AliesTaha Date: Wed, 18 Feb 2026 12:11:57 -0800 Subject: [PATCH 3/5] Revert docstring to 6 layers: time_text_embed covers two sublayers Signed-off-by: AliesTaha Co-authored-by: Cursor Signed-off-by: AliesTaha Co-authored-by: Cursor --- examples/diffusers/quantization/utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/diffusers/quantization/utils.py b/examples/diffusers/quantization/utils.py index e2f6a3a3b..c922ac223 100644 --- a/examples/diffusers/quantization/utils.py +++ b/examples/diffusers/quantization/utils.py @@ -45,7 +45,7 @@ def filter_func_default(name: str) -> bool: def filter_func_qwen_image(name: str) -> bool: - """Qwen-Image filter: disable only the 5 standalone layers outside transformer blocks.""" + """Qwen-Image filter: disable the 5 standalone modules outside transformer blocks (time_text_embed covers 2 sublayers).""" pattern = re.compile( r".*(time_text_embed|img_in|txt_in|norm_out|proj_out).*" ) From 89c409d912af22e421797ce30be53a6d911f38d1 Mon Sep 17 00:00:00 2001 From: AliesTaha Date: Wed, 18 Feb 2026 12:19:29 -0800 Subject: [PATCH 4/5] Move QWEN_IMAGE_2512 to end of MODEL_DEFAULTS for consistency Signed-off-by: AliesTaha Co-authored-by: Cursor --- .../diffusers/quantization/models_utils.py | 20 +++++++++---------- 1 file changed, 10 insertions(+), 10 deletions(-) diff --git a/examples/diffusers/quantization/models_utils.py b/examples/diffusers/quantization/models_utils.py index 1b3a93050..b5f0cae8a 100644 --- a/examples/diffusers/quantization/models_utils.py +++ b/examples/diffusers/quantization/models_utils.py @@ -198,16 +198,6 @@ def get_model_filter_func(model_type: ModelType) -> Callable[[str], bool]: ), }, }, - ModelType.QWEN_IMAGE_2512: { - "backbone": "transformer", - "dataset": _SD_PROMPTS_DATASET, - "inference_extra_args": { - "height": 1024, - "width": 1024, - "guidance_scale": 4.0, - "negative_prompt": "低分辨率,低画质,肢体畸形,手指畸形,画面过饱和,蜡像感,人脸无细节,过度光滑,画面具有AI感。构图混乱。文字模糊,扭曲。", # noqa: RUF001 - }, - }, ModelType.WAN22_T2V_5b: { **_WAN_BASE_CONFIG, "inference_extra_args": { @@ -223,6 +213,16 @@ def get_model_filter_func(model_type: ModelType) -> Callable[[str], bool]: ), }, }, + ModelType.QWEN_IMAGE_2512: { + "backbone": "transformer", + "dataset": _SD_PROMPTS_DATASET, + "inference_extra_args": { + "height": 1024, + "width": 1024, + "guidance_scale": 4.0, + "negative_prompt": "低分辨率,低画质,肢体畸形,手指畸形,画面过饱和,蜡像感,人脸无细节,过度光滑,画面具有AI感。构图混乱。文字模糊,扭曲。", # noqa: RUF001 + }, + }, } From ac50439b11c740e9a9acebc46c329382b1f0bf39 Mon Sep 17 00:00:00 2001 From: AliesTaha Date: Sat, 21 Feb 2026 20:54:37 -0800 Subject: [PATCH 5/5] qad extraction --- .../qad/once_setup/encode_train_val.py | 123 +++++++++ .../qad/once_setup/teacher_latents.py | 257 ++++++++++++++++++ examples/diffusers/quantization/quantize.py | 2 +- 3 files changed, 381 insertions(+), 1 deletion(-) create mode 100644 examples/diffusers/quantization/qad/once_setup/encode_train_val.py create mode 100644 examples/diffusers/quantization/qad/once_setup/teacher_latents.py diff --git a/examples/diffusers/quantization/qad/once_setup/encode_train_val.py b/examples/diffusers/quantization/qad/once_setup/encode_train_val.py new file mode 100644 index 000000000..3c5c5b0a8 --- /dev/null +++ b/examples/diffusers/quantization/qad/once_setup/encode_train_val.py @@ -0,0 +1,123 @@ +import torch +from diffusers import DiffusionPipeline +from datasets import load_dataset +import time + +def log(msg): + """Print timestamped log message.""" + print(f"[{time.strftime('%H:%M:%S')}] {msg}", flush=True) + +def main(): + DEVICE = "cuda" + + teacher_pipe = DiffusionPipeline.from_pretrained( + "Qwen/Qwen-Image-2512", + torch_dtype=torch.bfloat16, + ) + TRAIN_SAMPLES = 40_000 + NUM_VAL_ENCODE = 100 + PROMPT_TEMPLATE_DROP_IDX = 34 # system prompt + TOKENIZER_MAX_LENGTH = 1024 + PROMPT_TEMPLATE = ( + "<|im_start|>system\n" + "Describe the image by detailing the color, shape, size, texture, " + "quantity, text, spatial relationships of the objects and background:" + "<|im_end|>\n<|im_start|>user\n{}<|im_end|>\n<|im_start|>assistant\n" + ) + TOKENIZER_MAX_LENGTH = 1024 + + tokenizer = teacher_pipe.tokenizer + text_encoder = teacher_pipe.text_encoder.to(DEVICE).eval() + + #load dataset, split, and encode training and validation prompts + ds = load_dataset("Gustavosta/Stable-Diffusion-Prompts") + all_prompts = ds["train"]["Prompt"] + total_prompts = len(all_prompts) + log(f"Total prompts in dataset: {total_prompts:,}") + + train_prompts = all_prompts[:TRAIN_SAMPLES] + val_prompts = all_prompts[TRAIN_SAMPLES:TRAIN_SAMPLES + NUM_VAL_ENCODE] + + batch_size = 8 + train_embeds = [] + val_embeds = [] + + for i in range(0, len(train_prompts), batch_size): + batch_prompts = train_prompts[i : i + batch_size] + txt = [PROMPT_TEMPLATE.format(p) for p in batch_prompts] + + txt_tokens = tokenizer( + txt, + max_length=TOKENIZER_MAX_LENGTH + PROMPT_TEMPLATE_DROP_IDX, + padding=True, + truncation=True, + return_tensors="pt", + ).to(DEVICE) + + with torch.no_grad(): + encoder_out = text_encoder( + input_ids=txt_tokens.input_ids, + attention_mask=txt_tokens.attention_mask, + output_hidden_states=True, + ) + + hidden_states = encoder_out.hidden_states[-1] + mask = txt_tokens.attention_mask + bool_mask = mask.bool() + valid_lengths = bool_mask.sum(dim=1) + selected = hidden_states[bool_mask] + split_hidden = torch.split(selected, valid_lengths.tolist(), dim=0) + # Drop the system prompt tokens (first drop_idx tokens) + split_hidden = [e[PROMPT_TEMPLATE_DROP_IDX:].cpu() for e in split_hidden] + train_embeds.extend(split_hidden) + + for i in range(0, len(val_prompts), batch_size): + batch_prompts = val_prompts[i : i + batch_size] + txt = [PROMPT_TEMPLATE.format(p) for p in batch_prompts] + + txt_tokens = tokenizer( + txt, + max_length=TOKENIZER_MAX_LENGTH + PROMPT_TEMPLATE_DROP_IDX, + padding=True, + truncation=True, + return_tensors="pt", + ).to(DEVICE) + + with torch.no_grad(): + encoder_out = text_encoder( + input_ids=txt_tokens.input_ids, + attention_mask=txt_tokens.attention_mask, + output_hidden_states=True, + ) + hidden_states = encoder_out.hidden_states[-1] + mask = txt_tokens.attention_mask + bool_mask = mask.bool() + valid_lengths = bool_mask.sum(dim=1) + selected = hidden_states[bool_mask] + split_hidden = torch.split(selected, valid_lengths.tolist(), dim=0) + # Drop the system prompt tokens (first drop_idx tokens) + split_hidden = [e[PROMPT_TEMPLATE_DROP_IDX:].cpu() for e in split_hidden] + val_embeds.extend(split_hidden) + + torch.save({i: e for i, e in enumerate(train_embeds)}, "train_embeds.pt") + torch.save({i: e for i, e in enumerate(val_embeds)}, "val_embeds.pt") + +if __name__ == "__main__": + main() + +# === Truncation analysis (raw prompts, no template) === +# Train (70000 prompts): +# min=2, max=441, mean=61.2 +# Truncated (raw tokens > 1024): 0 / 70000 (0.00%) +# Val (100 prompts): +# min=11, max=124, mean=60.1 +# Truncated (raw tokens > 1024): 0 / 100 (0.00%) + +# Length distribution (train): +# ( 0, 128]: 68976 (98.5%) +# ( 128, 256]: 933 (1.3%) +# ( 256, 512]: 91 (0.1%) +# ( 512, 768]: 0 (0.0%) +# ( 768, 1024]: 0 (0.0%) +# (1024, 1280]: 0 (0.0%) +# (1280, 1752]: 0 (0.0%) \ No newline at end of file diff --git a/examples/diffusers/quantization/qad/once_setup/teacher_latents.py b/examples/diffusers/quantization/qad/once_setup/teacher_latents.py new file mode 100644 index 000000000..3a440efaf --- /dev/null +++ b/examples/diffusers/quantization/qad/once_setup/teacher_latents.py @@ -0,0 +1,257 @@ +"""Generate teacher (input, target) latent pairs for QAD distillation. + +For each prompt, assigns a denoising step k (round-robin across NUM_STEPS), +runs the teacher transformer through steps 0..k, and saves: + + input_latent = noisy latent BEFORE step k (student input at train time) + target_latent = denoised latent AFTER step k (ground truth for student) + +Output format (saved as a dict): + input_latents : [N, 4096, 64] bf16 + target_latents : [N, 4096, 64] bf16 + step_indices : [N] int (which denoising step, 0..29) + timestep_values : [N] float (the actual timestep passed to model) + noise_seeds : [N] int (for reproducibility) + +Usage: + python teacher_latents.py +""" + +import os +import time +from typing import MutableSequence + +import numpy as np +import torch +from diffusers import DiffusionPipeline +from tqdm import tqdm + +# ─── CLI args for multi-GPU sharding ───────────────────────────────────────── +import argparse +_p = argparse.ArgumentParser() +_p.add_argument("--gpu", type=int, default=5) +_p.add_argument("--start", type=int, default=0, help="first sample index (inclusive)") +_p.add_argument("--end", type=int, default=-1, help="last sample index (exclusive), -1=all") +_p.add_argument("--tag", type=str, default="", help="output filename suffix") +_args = _p.parse_args() + +# ─── Config ────────────────────────────────────────────────────────────────── +DEVICE = f"cuda:{_args.gpu}" +NUM_STEPS = 30 +BATCH_SIZE = 30 +SAVE_DIR = os.path.join(os.path.dirname(os.path.abspath(__file__)), "..", "weights") + +PIXEL_H, PIXEL_W = 1024, 1024 +VAE_SCALE_FACTOR = 8 +PATCH_SIZE = 2 +IMG_H = PIXEL_H // VAE_SCALE_FACTOR // PATCH_SIZE # 64 +IMG_W = PIXEL_W // VAE_SCALE_FACTOR // PATCH_SIZE # 64 +SEQ_LEN = IMG_H * IMG_W # 4096 +IN_CHANNELS = 64 + +# img_shapes tells the transformer the spatial layout for RoPE computation. +# For 1024x1024: one frame, 64x64 patchified grid. +IMG_SHAPES_SINGLE = [(1, IMG_H, IMG_W)] + + +def log(msg): + print(f"[{time.strftime('%H:%M:%S')}] {msg}", flush=True) + + +def calculate_mu(seq_len, base_seq=256, max_seq=8192, base_shift=0.5, max_shift=0.9): + #returns, by default, 0.693548, for 1024x1024 + #m=y2-y1/x2-x1 + #b=y1-m*x1 + #return y=mx+b + y2,y1 = max_shift, base_shift + x, x2,x1 = seq_len, max_seq, base_seq + m = (y2 - y1) / (x2 - x1) + b = y1 - m * x1 + return m * x + b + + +# ─── Load model ────────────────────────────────────────────────────────────── +log("Loading teacher pipeline...") +pipe = DiffusionPipeline.from_pretrained( + "Qwen/Qwen-Image-2512", torch_dtype=torch.bfloat16, +) +transformer = pipe.transformer.to(DEVICE).eval() +scheduler = pipe.scheduler + +log(f"Transformer: {transformer.__class__.__name__}, " + f"params={sum(p.numel() for p in transformer.parameters()) / 1e9:.2f}B") + +# ─── Set up denoising schedule (matches SGLang QwenImagePipeline) ──────────── +raw_sigmas = np.linspace(1.0, 1.0 / NUM_STEPS, NUM_STEPS) +mu = calculate_mu(SEQ_LEN) +scheduler.set_timesteps(sigmas=raw_sigmas.tolist(), mu=mu, device=DEVICE) +timesteps = scheduler.timesteps.clone() +sigmas = scheduler.sigmas.clone() + +log(f"Schedule: {NUM_STEPS} steps, mu={mu:.4f}") +log(f"Timesteps: [{', '.join(f'{t:.1f}' for t in timesteps[:5].tolist())} ... " + f"{', '.join(f'{t:.1f}' for t in timesteps[-3:].tolist())}]") + + +# ─── Load prompt embeddings ────────────────────────────────────────────────── +# Format: dict {int_index: tensor [seq_len_i, 3584]} with variable-length prompts. +# We store them as a list and pad per-batch at runtime to avoid wasting memory. +def load_embed_list(data): + """Convert saved format into a list of [seq_len_i, dim] tensors.""" + if isinstance(data, dict) and isinstance(next(iter(data.keys())), int): + n = len(data) + return [data[i] for i in range(n)] + if isinstance(data, list): + return data + raise ValueError(f"Unknown embed format: {type(data)}, keys={list(data.keys())[:5]}") + + +def pad_batch(embed_list, max_seq=1024): + """Pad variable-length embeds into a batch with attention masks.""" + dim = embed_list[0].shape[-1] + lengths = [min(e.shape[0], max_seq) for e in embed_list] + max_len = max(lengths) + B = len(embed_list) + padded = torch.zeros(B, max_len, dim, dtype=embed_list[0].dtype) + mask = torch.zeros(B, max_len, dtype=torch.long) + for i, e in enumerate(embed_list): + L = lengths[i] + padded[i, :L] = e[:L] + mask[i, :L] = 1 + return padded, mask, lengths + + +# ─── Core: generate teacher latent pairs ───────────────────────────────────── +@torch.no_grad() +def generate_teacher_latents(embed_list, split_name): + N = len(embed_list) + log(f"[{split_name}] {N} samples, batch_size={BATCH_SIZE}, " + f"~{N // NUM_STEPS} samples/step") + + all_inputs = [] + all_targets = [] + all_step_idx = [] + all_ts_val = [] + all_seeds = [] + + for batch_start in tqdm(range(0, N, BATCH_SIZE), desc=split_name): + batch_end = min(batch_start + BATCH_SIZE, N) + B = batch_end - batch_start + + step_assignments = [(batch_start + j) % NUM_STEPS for j in range(B)] + max_step_needed = max(step_assignments) + + batch_embeds = embed_list[batch_start:batch_end] + prompt_embeds, prompt_mask, txt_lengths = pad_batch(batch_embeds) + prompt_embeds = prompt_embeds.to(device=DEVICE, dtype=torch.bfloat16) + prompt_mask = prompt_mask.to(device=DEVICE) + txt_seq_lens = txt_lengths + img_shapes = [IMG_SHAPES_SINGLE] * B + + # deterministic noise per sample (seed = global sample index) + g = torch.Generator(device="cpu") + noise = torch.stack([ + torch.randn(SEQ_LEN, IN_CHANNELS, generator=g.manual_seed(batch_start + j), + dtype=torch.bfloat16) + for j in range(B) + ]).to(DEVICE) + + latents = noise.clone() + saved_input = [None] * B + saved_target = [None] * B + + # Reset scheduler step counter for this fresh trajectory + scheduler.set_begin_index(0) + + for step_idx in range(max_step_needed + 1): + # Before the transformer runs: snapshot input for assigned samples + for j in range(B): + if step_assignments[j] == step_idx: + saved_input[j] = latents[j].cpu().clone() + + t = timesteps[step_idx] + # Pipeline divides by 1000 before passing to transformer + t_normed = (t / 1000).expand(B).to(latents.dtype) + + noise_pred = transformer( + hidden_states=latents, + timestep=t_normed, + encoder_hidden_states=prompt_embeds, + encoder_hidden_states_mask=prompt_mask, + img_shapes=img_shapes, + txt_seq_lens=txt_seq_lens, + return_dict=False, + )[0] + + latents = scheduler.step(noise_pred, t, latents, return_dict=False)[0] + + # After scheduler step: snapshot target for assigned samples + for j in range(B): + if step_assignments[j] == step_idx: + saved_target[j] = latents[j].cpu().clone() + + for j in range(B): + k = step_assignments[j] + all_inputs.append(saved_input[j]) + all_targets.append(saved_target[j]) + all_step_idx.append(k) + all_ts_val.append(timesteps[k].item()) + all_seeds.append(batch_start + j) + + result = { + "input_latents": torch.stack(all_inputs), + "target_latents": torch.stack(all_targets), + "step_indices": torch.tensor(all_step_idx, dtype=torch.long), + "timestep_values": torch.tensor(all_ts_val, dtype=torch.float32), + "noise_seeds": torch.tensor(all_seeds, dtype=torch.long), + } + dist = torch.bincount(result["step_indices"], minlength=NUM_STEPS) + log(f"[{split_name}] Done — {result['input_latents'].shape}, " + f"step distribution min={dist.min().item()} max={dist.max().item()}") + return result + + +# ─── Run ───────────────────────────────────────────────────────────────────── +if __name__ == "__main__": + os.makedirs(SAVE_DIR, exist_ok=True) + + for split, embed_file in [("train", "train_embeds.pt"), ("val", "val_embeds.pt")]: + path = os.path.join(SAVE_DIR, embed_file) + if not os.path.exists(path): + log(f"[{split}] {path} not found, skipping") + continue + + log(f"[{split}] Loading {path}...") + data = torch.load(path, weights_only=False) + embed_list = load_embed_list(data) + + # Apply shard range + total = len(embed_list) + s, e = _args.start, (_args.end if _args.end > 0 else total) + e = min(e, total) + embed_list = embed_list[s:e] + log(f"[{split}] shard [{s}:{e}) = {len(embed_list)} of {total} prompts") + + if len(embed_list) == 0: + log(f"[{split}] nothing in this shard range, skipping") + continue + + result = generate_teacher_latents(embed_list, split) + + # Global sample indices (so noise seeds stay consistent across shards) + result["noise_seeds"] = result["noise_seeds"] + s + + suffix = f"_{_args.tag}" if _args.tag else "" + out_path = os.path.join(SAVE_DIR, f"teacher_{split}_latents{suffix}.pt") + torch.save(result, out_path) + log(f"[{split}] Saved → {out_path}") + + log("All done!") + +# get scheduler +# calculate_mu +# pass mu to scheduler, let it adjust timesteps and sigmas based on that +# but apparently it doesn't +# we have to do so manually, adjust timesteps by passing sigmas and mus +# but why do i have to compute sigma myself? why can't the scheduler do it? +# then i \ No newline at end of file diff --git a/examples/diffusers/quantization/quantize.py b/examples/diffusers/quantization/quantize.py index 3396d5d3e..c2965bd30 100644 --- a/examples/diffusers/quantization/quantize.py +++ b/examples/diffusers/quantization/quantize.py @@ -134,7 +134,7 @@ def get_quant_config(self, n_steps: int, backbone: torch.nn.Module) -> Any: if self.model_config.model_type.value.startswith("flux"): quant_config = NVFP4_FP8_MHA_CONFIG else: - quant_config = NVFP4_ASYMMETRIC_CONFIG + quant_config = NVFP4_DEFAULT_CONFIG else: raise NotImplementedError(f"Unknown format {self.config.format}") if self.config.quantize_mha: