Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -59,3 +59,6 @@ venv/
**.pickle
**.tar.gz
**.nemo

# Ignore experiment run
examples/diffusers/quantization/experiment_run
22 changes: 22 additions & 0 deletions examples/diffusers/quantization/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,28 @@
"algorithm": "max",
}

# NVFP4 quantization config with an additional asymmetric bias term on both
# weights and activations (compare the symmetric NVFP4 configs in this file).
# num_bits=(2, 1) with 16-element blocks and (4, 3) scale bits follows the
# NVFP4 per-block scaled format — presumably E2M1 values with E4M3 block
# scales; confirm against the modelopt quantization docs.
NVFP4_ASYMMETRIC_CONFIG = {
    "quant_cfg": {
        "*weight_quantizer": {
            "num_bits": (2, 1),
            "block_sizes": {-1: 16, "type": "dynamic", "scale_bits": (4, 3)},
            "axis": None,
            "enable": True,
            # Weight bias can be precomputed once (static) via the mean method.
            "bias": {-1: None, "type": "static", "method": "mean"},
        },
        "*input_quantizer": {
            "num_bits": (2, 1),
            "block_sizes": {-1: 16, "type": "dynamic", "scale_bits": (4, 3)},
            "axis": None,
            "enable": True,
            # Bias must be dynamic: activation statistics change per input,
            # so a static precomputed bias would not track them.
            "bias": {-1: None, "type": "dynamic", "method": "mean"},
        },
        # Outputs are left unquantized; everything not matched above is disabled.
        "*output_quantizer": {"enable": False},
        "default": {"enable": False},
    },
    "algorithm": "max",
}

NVFP4_FP8_MHA_CONFIG = {
"quant_cfg": {
"**weight_quantizer": {
Expand Down
15 changes: 15 additions & 0 deletions examples/diffusers/quantization/models_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
filter_func_default,
filter_func_flux_dev,
filter_func_ltx_video,
filter_func_qwen_image,
filter_func_wan_video,
)

Expand All @@ -46,6 +47,7 @@ class ModelType(str, Enum):
LTX2 = "ltx-2"
WAN22_T2V_14b = "wan2.2-t2v-14b"
WAN22_T2V_5b = "wan2.2-t2v-5b"
QWEN_IMAGE_2512 = "qwen-image-2512"


def get_model_filter_func(model_type: ModelType) -> Callable[[str], bool]:
Expand All @@ -69,6 +71,7 @@ def get_model_filter_func(model_type: ModelType) -> Callable[[str], bool]:
ModelType.LTX2: filter_func_ltx_video,
ModelType.WAN22_T2V_14b: filter_func_wan_video,
ModelType.WAN22_T2V_5b: filter_func_wan_video,
ModelType.QWEN_IMAGE_2512: filter_func_qwen_image,
}

return filter_func_map.get(model_type, filter_func_default)
Expand All @@ -86,6 +89,7 @@ def get_model_filter_func(model_type: ModelType) -> Callable[[str], bool]:
ModelType.LTX2: "Lightricks/LTX-2",
ModelType.WAN22_T2V_14b: "Wan-AI/Wan2.2-T2V-A14B-Diffusers",
ModelType.WAN22_T2V_5b: "Wan-AI/Wan2.2-TI2V-5B-Diffusers",
ModelType.QWEN_IMAGE_2512: "Qwen/Qwen-Image-2512",
}

MODEL_PIPELINE: dict[ModelType, type[DiffusionPipeline] | None] = {
Expand All @@ -99,6 +103,7 @@ def get_model_filter_func(model_type: ModelType) -> Callable[[str], bool]:
ModelType.LTX2: None,
ModelType.WAN22_T2V_14b: WanPipeline,
ModelType.WAN22_T2V_5b: WanPipeline,
ModelType.QWEN_IMAGE_2512: DiffusionPipeline,
}

# Shared dataset configurations
Expand Down Expand Up @@ -208,6 +213,16 @@ def get_model_filter_func(model_type: ModelType) -> Callable[[str], bool]:
),
},
},
ModelType.QWEN_IMAGE_2512: {
"backbone": "transformer",
"dataset": _SD_PROMPTS_DATASET,
"inference_extra_args": {
"height": 1024,
"width": 1024,
"guidance_scale": 4.0,
"negative_prompt": "低分辨率,低画质,肢体畸形,手指畸形,画面过饱和,蜡像感,人脸无细节,过度光滑,画面具有AI感。构图混乱。文字模糊,扭曲。", # noqa: RUF001
},
},
}


Expand Down
123 changes: 123 additions & 0 deletions examples/diffusers/quantization/qad/once_setup/encode_train_val.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,123 @@
import torch
from diffusers import DiffusionPipeline
from datasets import load_dataset
import time

def log(msg):
    """Write *msg* to stdout, prefixed with the current wall-clock time."""
    timestamp = time.strftime("%H:%M:%S")
    print(f"[{timestamp}] {msg}", flush=True)

def _encode_prompts(
    prompts,
    tokenizer,
    text_encoder,
    *,
    template,
    drop_idx,
    max_length,
    device,
    batch_size=8,
):
    """Encode *prompts* with *text_encoder*, one CPU tensor per prompt.

    Each prompt is wrapped in *template*, tokenized (padded/truncated to
    ``max_length + drop_idx`` so the user payload still gets up to
    ``max_length`` tokens after the system prompt is removed), and run through
    the encoder. Padding positions are stripped via the attention mask and the
    first ``drop_idx`` tokens (the templated system prompt) are dropped.

    Args:
        prompts: list of raw prompt strings.
        tokenizer: the pipeline tokenizer.
        text_encoder: the pipeline text encoder (already on *device*, eval mode).
        template: chat template with one ``{}`` slot for the prompt.
        drop_idx: number of leading system-prompt tokens to drop.
        max_length: max user-prompt token budget (pre-template).
        device: device the tokenized batch is moved to.
        batch_size: prompts encoded per forward pass.

    Returns:
        list[torch.Tensor]: per-prompt last-hidden-state tensors on CPU.
    """
    embeds = []
    for start in range(0, len(prompts), batch_size):
        txt = [template.format(p) for p in prompts[start : start + batch_size]]

        txt_tokens = tokenizer(
            txt,
            max_length=max_length + drop_idx,
            padding=True,
            truncation=True,
            return_tensors="pt",
        ).to(device)

        with torch.no_grad():
            encoder_out = text_encoder(
                input_ids=txt_tokens.input_ids,
                attention_mask=txt_tokens.attention_mask,
                output_hidden_states=True,
            )

        hidden_states = encoder_out.hidden_states[-1]
        bool_mask = txt_tokens.attention_mask.bool()
        valid_lengths = bool_mask.sum(dim=1)
        # Flatten out the padding, then split back into per-prompt sequences.
        split_hidden = torch.split(hidden_states[bool_mask], valid_lengths.tolist(), dim=0)
        # Drop the system prompt tokens (first drop_idx tokens) and move to CPU.
        embeds.extend(e[drop_idx:].cpu() for e in split_hidden)
    return embeds


def main():
    """Precompute text-encoder embeddings for train/val prompts and save them.

    Loads the Qwen-Image-2512 pipeline for its tokenizer and text encoder,
    encodes the first TRAIN_SAMPLES prompts of the Stable-Diffusion-Prompts
    dataset as the training split and the following NUM_VAL_ENCODE prompts as
    validation, then writes ``train_embeds.pt`` / ``val_embeds.pt``
    (index -> embedding tensor) to the working directory.
    """
    DEVICE = "cuda"

    teacher_pipe = DiffusionPipeline.from_pretrained(
        "Qwen/Qwen-Image-2512",
        torch_dtype=torch.bfloat16,
    )
    TRAIN_SAMPLES = 40_000
    NUM_VAL_ENCODE = 100
    PROMPT_TEMPLATE_DROP_IDX = 34  # token length of the templated system prompt
    TOKENIZER_MAX_LENGTH = 1024  # (was assigned twice in the original; defined once here)
    PROMPT_TEMPLATE = (
        "<|im_start|>system\n"
        "Describe the image by detailing the color, shape, size, texture, "
        "quantity, text, spatial relationships of the objects and background:"
        "<|im_end|>\n<|im_start|>user\n{}<|im_end|>\n<|im_start|>assistant\n"
    )

    tokenizer = teacher_pipe.tokenizer
    text_encoder = teacher_pipe.text_encoder.to(DEVICE).eval()

    # Load dataset and split into training / validation prompt lists.
    ds = load_dataset("Gustavosta/Stable-Diffusion-Prompts")
    all_prompts = ds["train"]["Prompt"]
    total_prompts = len(all_prompts)
    log(f"Total prompts in dataset: {total_prompts:,}")

    train_prompts = all_prompts[:TRAIN_SAMPLES]
    val_prompts = all_prompts[TRAIN_SAMPLES : TRAIN_SAMPLES + NUM_VAL_ENCODE]

    # Both splits go through the identical encoding path (the original
    # duplicated this loop verbatim for train and val).
    encode_kwargs = dict(
        template=PROMPT_TEMPLATE,
        drop_idx=PROMPT_TEMPLATE_DROP_IDX,
        max_length=TOKENIZER_MAX_LENGTH,
        device=DEVICE,
        batch_size=8,
    )
    train_embeds = _encode_prompts(train_prompts, tokenizer, text_encoder, **encode_kwargs)
    val_embeds = _encode_prompts(val_prompts, tokenizer, text_encoder, **encode_kwargs)

    torch.save(dict(enumerate(train_embeds)), "train_embeds.pt")
    torch.save(dict(enumerate(val_embeds)), "val_embeds.pt")


if __name__ == "__main__":
    main()

# === Truncation analysis (raw prompts, no template) ===
# Train (70000 prompts):
# min=2, max=441, mean=61.2
# Truncated (raw tokens > 1024): 0 / 70000 (0.00%)
# Val (100 prompts):
# min=11, max=124, mean=60.1
# Truncated (raw tokens > 1024): 0 / 100 (0.00%)

# Length distribution (train):
# ( 0, 128]: 68976 (98.5%)
# ( 128, 256]: 933 (1.3%)
# ( 256, 512]: 91 (0.1%)
# ( 512, 768]: 0 (0.0%)
# ( 768, 1024]: 0 (0.0%)
# (1024, 1280]: 0 (0.0%)
# (1280, 1752]: 0 (0.0%)
Loading