2 changes: 1 addition & 1 deletion 3rdparty/Megatron-Bridge
Submodule Megatron-Bridge updated 298 files
187 changes: 187 additions & 0 deletions dfm/src/megatron/model/wan/conversion/wan_bridge.py
@@ -0,0 +1,187 @@
# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import torch
from diffusers import WanTransformer3DModel
from megatron.bridge.models.conversion.mapping_registry import MegatronMappingRegistry
from megatron.bridge.models.conversion.model_bridge import MegatronModelBridge
from megatron.bridge.models.conversion.param_mapping import (
AutoMapping,
KVMapping,
QKVMapping,
ReplicatedMapping,
)
from megatron.bridge.models.conversion.utils import get_module_and_param_from_name

from dfm.src.megatron.model.wan.conversion.wan_hf_pretrained import PreTrainedWAN
from dfm.src.megatron.model.wan.wan_model import WanModel
from dfm.src.megatron.model.wan.wan_provider import WanModelProvider


@MegatronModelBridge.register_bridge(source=WanTransformer3DModel, target=WanModel)
class WanBridge(MegatronModelBridge):
"""
Megatron Bridge for the WAN model.

As a user, you would not use this bridge directly; it is selected automatically through `AutoBridge`.

Example:
>>> from megatron.bridge import AutoBridge
>>> bridge = AutoBridge.from_hf_pretrained("WAN-3D-1.3B-v1")
>>> provider = bridge.to_megatron_provider()
"""

def provider_bridge(self, hf_pretrained: PreTrainedWAN) -> WanModelProvider:
hf_config = hf_pretrained.config

cls = WanModelProvider

provider = cls(
num_layers=hf_config.num_layers,
hidden_size=hf_config.num_attention_heads * hf_config.attention_head_dim,
kv_channels=hf_config.attention_head_dim,
num_query_groups=hf_config.num_attention_heads,
crossattn_emb_size=hf_config.num_attention_heads * hf_config.attention_head_dim,
ffn_hidden_size=hf_config.ffn_dim,
num_attention_heads=hf_config.num_attention_heads,
in_channels=hf_config.in_channels,
out_channels=hf_config.out_channels,
text_dim=hf_config.text_dim,
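# hf_config.patch_size lists the temporal patch size first (index 0) and the spatial patch size second (index 1)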
patch_spatial=hf_config.patch_size[1],
patch_temporal=hf_config.patch_size[0],
layernorm_epsilon=hf_config.eps,
hidden_dropout=0,
attention_dropout=0,
use_cpu_initialization=True,
freq_dim=hf_config.freq_dim,
bf16=False,
params_dtype=torch.float32,
)

return provider

def mapping_registry(self) -> MegatronMappingRegistry:
"""Return MegatronMappingRegistry containing parameter mappings from HF to Megatron format.

Returns:
MegatronMappingRegistry: Registry of parameter mappings
"""
# Dictionary maps HF parameter names -> Megatron parameter names
# Supports wildcard (*) patterns for layer-specific parameters
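# e.g., "blocks.*.attn1.norm_q.weight" matches "blocks.0.attn1.norm_q.weight"
# and maps it to "decoder.layers.0.full_self_attention.q_layernorm.weight"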
param_mappings = {
"scale_shift_table": "head.modulation",
"patch_embedding.weight": "patch_embedding.weight",
"patch_embedding.bias": "patch_embedding.bias",
"condition_embedder.time_embedder.linear_1.weight": "time_embedder.linear_1.weight",
"condition_embedder.time_embedder.linear_1.bias": "time_embedder.linear_1.bias",
"condition_embedder.time_embedder.linear_2.weight": "time_embedder.linear_2.weight",
"condition_embedder.time_embedder.linear_2.bias": "time_embedder.linear_2.bias",
"condition_embedder.time_proj.weight": "time_proj.weight",
"condition_embedder.time_proj.bias": "time_proj.bias",
"condition_embedder.text_embedder.linear_1.weight": "text_embedding.0.weight",
"condition_embedder.text_embedder.linear_1.bias": "text_embedding.0.bias",
"condition_embedder.text_embedder.linear_2.weight": "text_embedding.2.weight",
"condition_embedder.text_embedder.linear_2.bias": "text_embedding.2.bias",
"blocks.*.scale_shift_table": "decoder.layers.*.adaLN.modulation",
"blocks.*.attn1.to_out.0.weight": "decoder.layers.*.full_self_attention.linear_proj.weight",
"blocks.*.attn1.to_out.0.bias": "decoder.layers.*.full_self_attention.linear_proj.bias",
"blocks.*.attn1.norm_q.weight": "decoder.layers.*.full_self_attention.q_layernorm.weight",
"blocks.*.attn1.norm_k.weight": "decoder.layers.*.full_self_attention.k_layernorm.weight",
"blocks.*.attn2.to_q.weight": "decoder.layers.*.cross_attention.linear_q.weight",
"blocks.*.attn2.to_q.bias": "decoder.layers.*.cross_attention.linear_q.bias",
"blocks.*.attn2.to_out.0.weight": "decoder.layers.*.cross_attention.linear_proj.weight",
"blocks.*.attn2.to_out.0.bias": "decoder.layers.*.cross_attention.linear_proj.bias",
"blocks.*.attn2.norm_q.weight": "decoder.layers.*.cross_attention.q_layernorm.weight",
"blocks.*.attn2.norm_k.weight": "decoder.layers.*.cross_attention.k_layernorm.weight",
"blocks.*.norm2.weight": "decoder.layers.*.norm3.weight",
"blocks.*.norm2.bias": "decoder.layers.*.norm3.bias",
"blocks.*.ffn.net.0.proj.weight": "decoder.layers.*.mlp.linear_fc1.weight",
"blocks.*.ffn.net.0.proj.bias": "decoder.layers.*.mlp.linear_fc1.bias",
"blocks.*.ffn.net.2.weight": "decoder.layers.*.mlp.linear_fc2.weight",
"blocks.*.ffn.net.2.bias": "decoder.layers.*.mlp.linear_fc2.bias",
"proj_out.weight": "head.head.weight",
"proj_out.bias": "head.head.bias",
}

# Custom WAN mapping to safely handle replicated params whose owning module
# does not expose a top-level `.weight` (e.g., Head.modulation)
class _ReplicatedByParamNameMapping(ReplicatedMapping):
def hf_to_megatron(self, hf_weights, megatron_module):
normalized_param = self._normalize_expert_param_name(self.megatron_param)
_, target_param = get_module_and_param_from_name(megatron_module, normalized_param)

target_device = target_param.device
target_dtype = target_param.dtype

hf_weights = hf_weights.to(device=target_device, dtype=target_dtype)
if self.tp_size == 1:
return hf_weights

if target_device.type == "cuda" and torch.cuda.is_available():
if target_device.index != torch.cuda.current_device():
hf_weights = hf_weights.to(torch.cuda.current_device())

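# Non-source TP ranks only need a correctly shaped/typed buffer; the broadcast below fills it from rank 0.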
if self.tp_rank > 0:
hf_weights = torch.empty_like(hf_weights)

return self.broadcast_tensor_to_tp_ranks(hf_weights, src_rank=0)

mapping_list = []
# Convert each dictionary entry to a mapping object (AutoMapping, or the WAN-specific replicated mapping defined above)
for hf_param, megatron_param in param_mappings.items():
if hf_param in {"scale_shift_table", "blocks.*.scale_shift_table", "proj_out.weight", "proj_out.bias"}:
# Use WAN-specific replicated mapping that resolves the exact param
mapping_list.append(_ReplicatedByParamNameMapping(hf_param=hf_param, megatron_param=megatron_param))
else:
mapping_list.append(AutoMapping(hf_param=hf_param, megatron_param=megatron_param))

# Register custom module types so AutoMapping treats their parameters as replicated
AutoMapping.register_module_type("Linear", "replicated")
AutoMapping.register_module_type("Conv3d", "replicated")
AutoMapping.register_module_type("WanAdaLN", "replicated")
AutoMapping.register_module_type("Head", "replicated")

# Add special mappings that require parameter concatenation/transformation
mapping_list.extend(
[
# QKV: Combine separate Q, K, V matrices into single QKV matrix
QKVMapping(
q="blocks.*.attn1.to_q.weight",
k="blocks.*.attn1.to_k.weight",
v="blocks.*.attn1.to_v.weight",
megatron_param="decoder.layers.*.full_self_attention.linear_qkv.weight",
),
# QKV bias: Combine separate Q, K, V bias into single QKV bias
QKVMapping(
q="blocks.*.attn1.to_q.bias",
k="blocks.*.attn1.to_k.bias",
v="blocks.*.attn1.to_v.bias",
megatron_param="decoder.layers.*.full_self_attention.linear_qkv.bias",
),
# K, V: Combine separate K, V matrices into single KV matrix
KVMapping(
k="blocks.*.attn2.to_k.weight",
v="blocks.*.attn2.to_v.weight",
megatron_param="decoder.layers.*.cross_attention.linear_kv.weight",
),
# K, V bias: Combine separate K, V bias into single KV bias
KVMapping(
k="blocks.*.attn2.to_k.bias",
v="blocks.*.attn2.to_v.bias",
megatron_param="decoder.layers.*.cross_attention.linear_kv.bias",
),
]
)

return MegatronMappingRegistry(*mapping_list)
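
For intuition on what QKVMapping contributes beyond the plain name mappings above, here is a minimal, self-contained sketch of fusing separate Q/K/V projection weights into one matrix. The per-head interleaving shown is illustrative only; the actual fused layout is defined by Megatron Bridge's QKVMapping, and the sizes (12 heads of dim 128 over a 1536-wide hidden state) are example numbers, not values read from any checkpoint.

import torch

num_heads, head_dim, hidden = 12, 128, 1536  # illustrative sizes only
q = torch.randn(num_heads * head_dim, hidden)
k = torch.randn(num_heads * head_dim, hidden)
v = torch.randn(num_heads * head_dim, hidden)

# Group rows by head, stack [q_h, k_h, v_h] per head, then flatten back to 2D.
qh = q.view(num_heads, head_dim, hidden)
kh = k.view(num_heads, head_dim, hidden)
vh = v.view(num_heads, head_dim, hidden)
fused = torch.cat([qh, kh, vh], dim=1).reshape(3 * num_heads * head_dim, hidden)
assert fused.shape == (3 * num_heads * head_dim, hidden)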
120 changes: 120 additions & 0 deletions dfm/src/megatron/model/wan/conversion/wan_hf_pretrained.py
@@ -0,0 +1,120 @@
# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import json
import shutil
from pathlib import Path
from typing import Union

from diffusers import WanTransformer3DModel
from megatron.bridge.models.hf_pretrained.base import PreTrainedBase
from megatron.bridge.models.hf_pretrained.state import SafeTensorsStateSource, StateDict, StateSource
from transformers import AutoConfig


class WanSafeTensorsStateSource(SafeTensorsStateSource):
"""
WAN-specific state source that writes exported HF shards under 'transformer/'.
"""

def save_generator(self, generator, output_path, strict: bool = True):
# Ensure shards are written under transformer/
target_dir = Path(output_path) / "transformer"
return super().save_generator(generator, target_dir, strict=strict)


class PreTrainedWAN(PreTrainedBase):
"""
Lightweight pretrained wrapper for Diffusers WAN models.

Provides access to WAN config and state through the common PreTrainedBase API
so bridges can consume `.config` and `.state` uniformly.

NOTE: WAN ships via HF's Diffusers library, whose checkpoint directory layout differs from that of HF's
Transformers library, so we need a wrapper that loads the model weights and config from the correct
subdirectory (e.g., ./transformer). The Diffusers layout bundles every component of the diffusion pipeline
(VAE, text encoders, etc.), while the transformer weights themselves live in ./transformer. We therefore
adjust the input and output paths accordingly, and override save_artifacts so the relevant config files
are written to the matching directory.
"""

def __init__(self, model_name_or_path: Union[str, Path], **kwargs):
self._model_name_or_path = str(model_name_or_path)
super().__init__(**kwargs)

@property
def model_name_or_path(self) -> str:
return self._model_name_or_path

# Model loading is optional for conversion; implemented for completeness
def _load_model(self) -> WanTransformer3DModel:
return WanTransformer3DModel.from_pretrained(self.model_name_or_path)

# Config is required by the WAN bridge
def _load_config(self) -> AutoConfig:
# WanTransformer3DModel exposes a config-like object with the required fields

print(f"Loading config from {self.model_name_or_path}")

return WanTransformer3DModel.from_pretrained(self.model_name_or_path, subfolder="transformer").config

@property
def state(self) -> StateDict:
"""
WAN-specific StateDict that reads safetensors from the fixed 'transformer/' subfolder.
"""
if getattr(self, "_state_dict_accessor", None) is None:
source: StateSource | None = None
if hasattr(self, "_model") and self._model is not None:
# If model is loaded, use its in-memory state_dict
source = self.model.state_dict()
else:
# Always load from 'transformer/' subfolder for WAN
source = WanSafeTensorsStateSource(Path(self.model_name_or_path) / "transformer")
self._state_dict_accessor = StateDict(source)
return self._state_dict_accessor

def save_artifacts(self, save_directory: Union[str, Path]):
"""
Save WAN artifacts alongside exported weights.
Writes transformer/config.json (and, if present, the safetensors index) into the destination.
"""
save_path = Path(save_directory)
save_path.mkdir(parents=True, exist_ok=True)

# Ensure transformer subdir exists at destination
dest_transformer = save_path / "transformer"
dest_transformer.mkdir(parents=True, exist_ok=True)

# 1) If the source checkpoint has a config.json under transformer/, copy it (and the safetensors index, if present)
src_config = Path(self.model_name_or_path) / "transformer" / "config.json"
src_index = Path(self.model_name_or_path) / "transformer" / "diffusion_pytorch_model.safetensors.index.json"
if src_config.exists():
shutil.copyfile(src_config, dest_transformer / "config.json")
if src_index.exists():
shutil.copyfile(src_index, dest_transformer / "diffusion_pytorch_model.safetensors.index.json")
return

# 2) Otherwise, try to export config from the HF model instance
try:
model = WanTransformer3DModel.from_pretrained(self.model_name_or_path, subfolder="transformer")
cfg = getattr(model, "config", None)
if cfg is not None:
# Prefer to_dict if available
cfg_dict = cfg.to_dict() if hasattr(cfg, "to_dict") else dict(cfg)
with open(dest_transformer / "config.json", "w") as f:
json.dump(cfg_dict, f, indent=2)
except Exception:
# Best-effort: if config cannot be produced, leave only weights
pass
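
A brief usage sketch of the wrapper above; the checkpoint and export paths are hypothetical, and only the directory layout described in the class docstring is assumed.

from pathlib import Path

from dfm.src.megatron.model.wan.conversion.wan_hf_pretrained import PreTrainedWAN

ckpt = Path("/checkpoints/wan-1.3b")  # hypothetical local Diffusers-style checkpoint
# Expected layout: <ckpt>/transformer/config.json and <ckpt>/transformer/*.safetensors
pretrained = PreTrainedWAN(ckpt)
cfg = pretrained.config           # loaded from the transformer/ subfolder
weights = pretrained.state        # StateDict backed by WanSafeTensorsStateSource
pretrained.save_artifacts("/exports/wan-hf")  # writes /exports/wan-hf/transformer/config.json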
1 change: 0 additions & 1 deletion dfm/src/megatron/model/wan/wan_layer_spec.py
@@ -297,7 +297,6 @@ def get_wan_block_with_transformer_engine_spec() -> ModuleSpec:
module=MLP,
submodules=MLPSubmodules(
linear_fc1=TEColumnParallelLinear,
# by default, activation_func is openai_gelu, which is equivalent to nn.GELU(approximate='tanh')
linear_fc2=TERowParallelLinear,
),
),
3 changes: 3 additions & 0 deletions dfm/src/megatron/model/wan/wan_provider.py
@@ -15,8 +15,10 @@

import logging
from dataclasses import dataclass
from typing import Callable

import torch
import torch.nn.functional as F
from megatron.bridge.models.model_provider import ModelProviderMixin
from megatron.bridge.models.transformer_config import TransformerConfig
from megatron.core import parallel_state
@@ -44,6 +46,7 @@ class WanModelProvider(TransformerConfig, ModelProviderMixin[VisionModule]):
layernorm_across_heads: bool = True
add_qkv_bias: bool = True
rotary_interleaved: bool = True
activation_func: Callable = F.gelu
hidden_dropout: float = 0
attention_dropout: float = 0
fp16_lm_cross_entropy: bool = False
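
The new `activation_func: Callable = F.gelu` default uses PyTorch's exact (erf-based) GELU, whereas the comment removed from wan_layer_spec.py above noted that the previous default, openai_gelu, matches the tanh approximation. A quick, illustrative check of the (small) numerical gap between the two:

import torch
import torch.nn.functional as F

x = torch.linspace(-4.0, 4.0, steps=9)
exact = F.gelu(x)                        # erf-based GELU (the new provider default)
approx = F.gelu(x, approximate="tanh")   # tanh approximation, i.e. what openai_gelu matches
print((exact - approx).abs().max())      # nonzero but small difference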