-
Notifications
You must be signed in to change notification settings - Fork 282
Add support for export ComfyUI compatible checkpoint for diffusion model(e.g., LTX-2) #911
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
1363878
6e922ae
09ede35
7e74e91
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change | ||||||||||||||||||||||||||||||||||
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
|
|
@@ -34,8 +34,10 @@ | |||||||||||||||||||||||||||||||||||
| import diffusers | ||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||
| from .diffusers_utils import ( | ||||||||||||||||||||||||||||||||||||
| DIFFUSION_MERGE_FUNCTIONS, | ||||||||||||||||||||||||||||||||||||
| generate_diffusion_dummy_forward_fn, | ||||||||||||||||||||||||||||||||||||
| get_diffusion_components, | ||||||||||||||||||||||||||||||||||||
| get_diffusion_model_type, | ||||||||||||||||||||||||||||||||||||
| get_qkv_group_key, | ||||||||||||||||||||||||||||||||||||
| hide_quantizers_from_state_dict, | ||||||||||||||||||||||||||||||||||||
| infer_dtype_from_model, | ||||||||||||||||||||||||||||||||||||
|
|
@@ -112,19 +114,62 @@ def _is_enabled_quantizer(quantizer): | |||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||
| def _save_component_state_dict_safetensors( | ||||||||||||||||||||||||||||||||||||
| component: nn.Module, component_export_dir: Path | ||||||||||||||||||||||||||||||||||||
| component: nn.Module, | ||||||||||||||||||||||||||||||||||||
| component_export_dir: Path, | ||||||||||||||||||||||||||||||||||||
| merged_base_safetensor_path: str | None = None, | ||||||||||||||||||||||||||||||||||||
| hf_quant_config: dict | None = None, | ||||||||||||||||||||||||||||||||||||
| model_type: str | None = None, | ||||||||||||||||||||||||||||||||||||
| ) -> None: | ||||||||||||||||||||||||||||||||||||
| """Save component state dict as safetensors with optional base checkpoint merge. | ||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||
| Args: | ||||||||||||||||||||||||||||||||||||
| component: The nn.Module to save. | ||||||||||||||||||||||||||||||||||||
| component_export_dir: Directory to save model.safetensors and config.json. | ||||||||||||||||||||||||||||||||||||
| merged_base_safetensor_path: If provided, merge with non-transformer components | ||||||||||||||||||||||||||||||||||||
| from this base safetensors file. | ||||||||||||||||||||||||||||||||||||
| hf_quant_config: If provided, embed quantization config in safetensors metadata | ||||||||||||||||||||||||||||||||||||
| and per-layer _quantization_metadata for ComfyUI. | ||||||||||||||||||||||||||||||||||||
| model_type: Key into ``DIFFUSION_MERGE_FUNCTIONS`` for the model-specific merge. | ||||||||||||||||||||||||||||||||||||
| Required when ``merged_base_safetensor_path`` is not None. | ||||||||||||||||||||||||||||||||||||
| """ | ||||||||||||||||||||||||||||||||||||
| cpu_state_dict = {k: v.detach().contiguous().cpu() for k, v in component.state_dict().items()} | ||||||||||||||||||||||||||||||||||||
| save_file(cpu_state_dict, str(component_export_dir / "model.safetensors")) | ||||||||||||||||||||||||||||||||||||
| metadata: dict[str, str] = {} | ||||||||||||||||||||||||||||||||||||
| metadata_full: dict[str, str] = {} | ||||||||||||||||||||||||||||||||||||
| if merged_base_safetensor_path is not None and model_type is not None: | ||||||||||||||||||||||||||||||||||||
| merge_fn = DIFFUSION_MERGE_FUNCTIONS[model_type] | ||||||||||||||||||||||||||||||||||||
| cpu_state_dict, metadata_full = merge_fn(cpu_state_dict, merged_base_safetensor_path) | ||||||||||||||||||||||||||||||||||||
| if hf_quant_config is not None: | ||||||||||||||||||||||||||||||||||||
| metadata_full["quantization_config"] = json.dumps(hf_quant_config) | ||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||
| # Build per-layer _quantization_metadata for ComfyUI | ||||||||||||||||||||||||||||||||||||
| quant_algo = hf_quant_config.get("quant_algo", "unknown").lower() | ||||||||||||||||||||||||||||||||||||
| layer_metadata = {} | ||||||||||||||||||||||||||||||||||||
| for k in cpu_state_dict: | ||||||||||||||||||||||||||||||||||||
| if k.endswith((".weight_scale", ".weight_scale_2")): | ||||||||||||||||||||||||||||||||||||
| layer_name = k.rsplit(".", 1)[0] | ||||||||||||||||||||||||||||||||||||
| if layer_name.endswith(".weight"): | ||||||||||||||||||||||||||||||||||||
| layer_name = layer_name.rsplit(".", 1)[0] | ||||||||||||||||||||||||||||||||||||
| if layer_name not in layer_metadata: | ||||||||||||||||||||||||||||||||||||
| layer_metadata[layer_name] = {"format": quant_algo} | ||||||||||||||||||||||||||||||||||||
| metadata_full["_quantization_metadata"] = json.dumps( | ||||||||||||||||||||||||||||||||||||
| { | ||||||||||||||||||||||||||||||||||||
| "format_version": "1.0", | ||||||||||||||||||||||||||||||||||||
| "layers": layer_metadata, | ||||||||||||||||||||||||||||||||||||
| } | ||||||||||||||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||
| metadata["_export_format"] = "safetensors_state_dict" | ||||||||||||||||||||||||||||||||||||
| metadata["_class_name"] = type(component).__name__ | ||||||||||||||||||||||||||||||||||||
| metadata_full.update(metadata) | ||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||
| save_file( | ||||||||||||||||||||||||||||||||||||
| cpu_state_dict, | ||||||||||||||||||||||||||||||||||||
| str(component_export_dir / "model.safetensors"), | ||||||||||||||||||||||||||||||||||||
| metadata=metadata_full if merged_base_safetensor_path is not None else None, | ||||||||||||||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||
| with open(component_export_dir / "config.json", "w") as f: | ||||||||||||||||||||||||||||||||||||
| json.dump( | ||||||||||||||||||||||||||||||||||||
| { | ||||||||||||||||||||||||||||||||||||
| "_class_name": type(component).__name__, | ||||||||||||||||||||||||||||||||||||
| "_export_format": "safetensors_state_dict", | ||||||||||||||||||||||||||||||||||||
| }, | ||||||||||||||||||||||||||||||||||||
| f, | ||||||||||||||||||||||||||||||||||||
| indent=4, | ||||||||||||||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||||||||||||||
| json.dump(metadata, f, indent=4) | ||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||
| def _collect_shared_input_modules( | ||||||||||||||||||||||||||||||||||||
|
|
@@ -807,6 +852,7 @@ def _export_diffusers_checkpoint( | |||||||||||||||||||||||||||||||||||
| dtype: torch.dtype | None, | ||||||||||||||||||||||||||||||||||||
| export_dir: Path, | ||||||||||||||||||||||||||||||||||||
| components: list[str] | None, | ||||||||||||||||||||||||||||||||||||
| merged_base_safetensor_path: str | None = None, | ||||||||||||||||||||||||||||||||||||
| max_shard_size: int | str = "10GB", | ||||||||||||||||||||||||||||||||||||
| ) -> None: | ||||||||||||||||||||||||||||||||||||
| """Internal: Export diffusion(-like) model/pipeline checkpoint. | ||||||||||||||||||||||||||||||||||||
|
|
@@ -821,6 +867,8 @@ def _export_diffusers_checkpoint( | |||||||||||||||||||||||||||||||||||
| export_dir: The directory to save the exported checkpoint. | ||||||||||||||||||||||||||||||||||||
| components: Optional list of component names to export. Only used for pipelines. | ||||||||||||||||||||||||||||||||||||
| If None, all components are exported. | ||||||||||||||||||||||||||||||||||||
| merged_base_safetensor_path: If provided, merge the exported transformer with | ||||||||||||||||||||||||||||||||||||
| non-transformer components from this base safetensors file. | ||||||||||||||||||||||||||||||||||||
| max_shard_size: Maximum size of each shard file. If the model exceeds this size, | ||||||||||||||||||||||||||||||||||||
| it will be sharded into multiple files and a .safetensors.index.json will be | ||||||||||||||||||||||||||||||||||||
| created. Use smaller values like "5GB" or "2GB" to force sharding. | ||||||||||||||||||||||||||||||||||||
|
|
@@ -834,6 +882,9 @@ def _export_diffusers_checkpoint( | |||||||||||||||||||||||||||||||||||
| warnings.warn("No exportable components found in the model.") | ||||||||||||||||||||||||||||||||||||
| return | ||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||
| # Resolve model type once (only needed when merging with a base checkpoint) | ||||||||||||||||||||||||||||||||||||
| model_type = get_diffusion_model_type(pipe) if merged_base_safetensor_path else None | ||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||
|
Comment on lines
+885
to
+887
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
If a user passes 🤖 Prompt for AI Agents |
||||||||||||||||||||||||||||||||||||
| # Separate nn.Module components for quantization-aware export | ||||||||||||||||||||||||||||||||||||
| module_components = { | ||||||||||||||||||||||||||||||||||||
| name: comp for name, comp in all_components.items() if isinstance(comp, nn.Module) | ||||||||||||||||||||||||||||||||||||
|
|
@@ -879,6 +930,7 @@ def _export_diffusers_checkpoint( | |||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||
| # Step 5: Build quantization config | ||||||||||||||||||||||||||||||||||||
| quant_config = get_quant_config(component, is_modelopt_qlora=False) | ||||||||||||||||||||||||||||||||||||
| hf_quant_config = convert_hf_quant_config_format(quant_config) if quant_config else None | ||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||
| # Step 6: Save the component | ||||||||||||||||||||||||||||||||||||
| # - diffusers ModelMixin.save_pretrained does NOT accept state_dict parameter | ||||||||||||||||||||||||||||||||||||
|
|
@@ -888,12 +940,15 @@ def _export_diffusers_checkpoint( | |||||||||||||||||||||||||||||||||||
| component.save_pretrained(component_export_dir, max_shard_size=max_shard_size) | ||||||||||||||||||||||||||||||||||||
| else: | ||||||||||||||||||||||||||||||||||||
| with hide_quantizers_from_state_dict(component): | ||||||||||||||||||||||||||||||||||||
| _save_component_state_dict_safetensors(component, component_export_dir) | ||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||
| _save_component_state_dict_safetensors( | ||||||||||||||||||||||||||||||||||||
| component, | ||||||||||||||||||||||||||||||||||||
| component_export_dir, | ||||||||||||||||||||||||||||||||||||
| merged_base_safetensor_path, | ||||||||||||||||||||||||||||||||||||
| hf_quant_config, | ||||||||||||||||||||||||||||||||||||
| model_type, | ||||||||||||||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||||||||||||||
| # Step 7: Update config.json with quantization info | ||||||||||||||||||||||||||||||||||||
| if quant_config is not None: | ||||||||||||||||||||||||||||||||||||
| hf_quant_config = convert_hf_quant_config_format(quant_config) | ||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||
| if hf_quant_config is not None: | ||||||||||||||||||||||||||||||||||||
| config_path = component_export_dir / "config.json" | ||||||||||||||||||||||||||||||||||||
| if config_path.exists(): | ||||||||||||||||||||||||||||||||||||
| with open(config_path) as file: | ||||||||||||||||||||||||||||||||||||
|
|
@@ -905,7 +960,12 @@ def _export_diffusers_checkpoint( | |||||||||||||||||||||||||||||||||||
| elif hasattr(component, "save_pretrained"): | ||||||||||||||||||||||||||||||||||||
| component.save_pretrained(component_export_dir, max_shard_size=max_shard_size) | ||||||||||||||||||||||||||||||||||||
| else: | ||||||||||||||||||||||||||||||||||||
| _save_component_state_dict_safetensors(component, component_export_dir) | ||||||||||||||||||||||||||||||||||||
| _save_component_state_dict_safetensors( | ||||||||||||||||||||||||||||||||||||
| component, | ||||||||||||||||||||||||||||||||||||
| component_export_dir, | ||||||||||||||||||||||||||||||||||||
| merged_base_safetensor_path, | ||||||||||||||||||||||||||||||||||||
| model_type=model_type, | ||||||||||||||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||||||||||||||
|
Comment on lines
960
to
+968
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Non-quantized components also receive When a non-quantized component falls through to In the current LTX-2 flow there's only one component, so this is harmless. But for future model types with multiple components, this would produce incorrect merged checkpoints for non-quantized components. Consider guarding by only passing Proposed safeguard else:
_save_component_state_dict_safetensors(
component,
component_export_dir,
- merged_base_safetensor_path,
- model_type=model_type,
)📝 Committable suggestion
Suggested change
🤖 Prompt for AI Agents |
||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||
| print(f" Saved to: {component_export_dir}") | ||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||
|
|
@@ -985,6 +1045,7 @@ def export_hf_checkpoint( | |||||||||||||||||||||||||||||||||||
| save_modelopt_state: bool = False, | ||||||||||||||||||||||||||||||||||||
| components: list[str] | None = None, | ||||||||||||||||||||||||||||||||||||
| extra_state_dict: dict[str, torch.Tensor] | None = None, | ||||||||||||||||||||||||||||||||||||
| merged_base_safetensor_path: str | None = None, | ||||||||||||||||||||||||||||||||||||
| ): | ||||||||||||||||||||||||||||||||||||
| """Export quantized HuggingFace model checkpoint (transformers or diffusers). | ||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||
|
|
@@ -1002,6 +1063,9 @@ def export_hf_checkpoint( | |||||||||||||||||||||||||||||||||||
| components: Only used for diffusers pipelines. Optional list of component names | ||||||||||||||||||||||||||||||||||||
| to export. If None, all quantized components are exported. | ||||||||||||||||||||||||||||||||||||
| extra_state_dict: Extra state dictionary to add to the exported model. | ||||||||||||||||||||||||||||||||||||
| merged_base_safetensor_path: If provided, merge the exported diffusion transformer | ||||||||||||||||||||||||||||||||||||
| with non-transformer components (VAE, vocoder, etc.) from this base safetensors | ||||||||||||||||||||||||||||||||||||
| file. Only used for diffusion model exports (e.g., LTX-2). | ||||||||||||||||||||||||||||||||||||
| """ | ||||||||||||||||||||||||||||||||||||
| export_dir = Path(export_dir) | ||||||||||||||||||||||||||||||||||||
| export_dir.mkdir(parents=True, exist_ok=True) | ||||||||||||||||||||||||||||||||||||
|
|
@@ -1010,7 +1074,9 @@ def export_hf_checkpoint( | |||||||||||||||||||||||||||||||||||
| if HAS_DIFFUSERS: | ||||||||||||||||||||||||||||||||||||
| is_diffusers_obj = is_diffusers_object(model) | ||||||||||||||||||||||||||||||||||||
| if is_diffusers_obj: | ||||||||||||||||||||||||||||||||||||
| _export_diffusers_checkpoint(model, dtype, export_dir, components) | ||||||||||||||||||||||||||||||||||||
| _export_diffusers_checkpoint( | ||||||||||||||||||||||||||||||||||||
| model, dtype, export_dir, components, merged_base_safetensor_path | ||||||||||||||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||||||||||||||
| return | ||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||
| # Transformers model export | ||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Metadata is discarded for non-merge exports.
On line 168, metadata is only passed to
save_filewhenmerged_base_safetensor_path is not None. For non-merge exports through this function, the_export_formatand_class_namemetadata (lines 161-162) are computed but thrown away —save_fileis called withmetadata=None.If metadata should always be attached (even for non-merge exports), pass it unconditionally:
Proposed fix
save_file( cpu_state_dict, str(component_export_dir / "model.safetensors"), - metadata=metadata_full if merged_base_safetensor_path is not None else None, + metadata=metadata_full, )📝 Committable suggestion
🤖 Prompt for AI Agents