2 changes: 1 addition & 1 deletion 3rdparty/Megatron-Bridge
Submodule Megatron-Bridge updated 77 files
+1 −5 .github/workflows/cache-hf-model.yml
+1 −2 .github/workflows/cicd-main.yml
+1 −1 3rdparty/Megatron-LM
+1 −1 CONTRIBUTING.md
+0 −1 docs/models/llm/index.md
+0 −183 docs/models/llm/olmoe.md
+0 −1 docs/models/vlm/index.md
+0 −143 docs/models/vlm/qwen2.5-vl.md
+0 −7 docs/training/checkpointing.md
+1 −0 docs/training/peft.md
+0 −674 examples/conversion/compare_text_generation.py
+1 −1 examples/quantization/ptq_generate.py
+0 −48 examples/recipes/qwen3_next/conf/qwen3_next_80b_a3b_finetune_override_example.yaml
+0 −147 examples/recipes/qwen3_next/finetune_qwen3_next_80b_a3b.py
+2 −9 examples/recipes/qwen_vl/finetune_qwen25_vl.py
+31 −16 scripts/performance/argument_parser.py
+12 −22 scripts/performance/configs/deepseek/deepseek_llm_pretrain.py
+8 −8 scripts/performance/configs/gpt_oss/gpt_oss_llm_pretrain.py
+22 −38 scripts/performance/configs/llama3/llama3_llm_pretrain.py
+5 −2 scripts/performance/configs/llama3/workload_base_configs.py
+11 −27 scripts/performance/configs/llama31/llama31_llm_pretrain.py
+6 −6 scripts/performance/configs/nemotronh/nemotronh_llm_pretrain.py
+23 −24 scripts/performance/configs/qwen3/qwen3_llm_pretrain.py
+25 −12 scripts/performance/perf_plugins.py
+8 −11 scripts/performance/setup_experiment.py
+6 −9 scripts/performance/utils/executors.py
+25 −34 scripts/performance/utils/helpers.py
+6 −2 scripts/performance/utils/utils.py
+0 −105 src/megatron/bridge/data/iterator_utils.py
+1 −1 src/megatron/bridge/data/loaders.py
+1 −39 src/megatron/bridge/models/conversion/auto_bridge.py
+4 −112 src/megatron/bridge/models/conversion/model_bridge.py
+8 −8 src/megatron/bridge/models/conversion/param_mapping.py
+2 −2 src/megatron/bridge/models/conversion/utils.py
+0 −1 src/megatron/bridge/models/gemma/gemma3_provider.py
+0 −3 src/megatron/bridge/models/gpt_provider.py
+0 −2 src/megatron/bridge/models/mamba/mamba_provider.py
+1 −1 src/megatron/bridge/models/olmoe/olmoe_provider.py
+7 −31 src/megatron/bridge/peft/lora.py
+0 −2 src/megatron/bridge/recipes/gemma/__init__.py
+1 −245 src/megatron/bridge/recipes/gemma/gemma3.py
+5 −5 src/megatron/bridge/recipes/gpt_oss/gpt_oss.py
+0 −22 src/megatron/bridge/recipes/llama/__init__.py
+22 −534 src/megatron/bridge/recipes/llama/llama3.py
+0 −24 src/megatron/bridge/recipes/olmoe/__init__.py
+0 −682 src/megatron/bridge/recipes/olmoe/olmoe_7b.py
+0 −6 src/megatron/bridge/recipes/qwen/__init__.py
+1 −294 src/megatron/bridge/recipes/qwen/qwen3_moe.py
+49 −292 src/megatron/bridge/recipes/qwen/qwen3_next.py
+9 −73 src/megatron/bridge/recipes/qwen_vl/qwen25_vl.py
+21 −37 src/megatron/bridge/training/checkpointing.py
+1 −22 src/megatron/bridge/training/config.py
+4 −44 src/megatron/bridge/training/eval.py
+1 −0 src/megatron/bridge/training/pretrain.py
+0 −6 src/megatron/bridge/training/setup.py
+68 −125 src/megatron/bridge/training/train.py
+38 −154 src/megatron/bridge/training/utils/checkpoint_utils.py
+121 −184 tests/end_to_end_tests/evaluate_recipe_training.py
+0 −28 tests/functional_tests/L2_Launch_post_training_quantization.sh
+0 −1 tests/functional_tests/L2_Launch_quantization.sh
+10 −81 tests/functional_tests/quantization/test_qat_workflow.py
+10 −27 tests/functional_tests/quantization/test_quantization_workflow.py
+0 −1 tests/functional_tests/training/test_finetune_lora.py
+0 −2 tests/functional_tests/training/test_pretrain_resume.py
+1 −1 tests/functional_tests/utils.py
+0 −143 tests/unit_tests/data/test_finetuning.py
+0 −109 tests/unit_tests/data/test_iterator_utils.py
+0 −40 tests/unit_tests/models/test_auto_bridge.py
+0 −119 tests/unit_tests/models/test_model_bridge_lora.py
+14 −262 tests/unit_tests/recipes/test_gemma3_recipes.py
+24 −356 tests/unit_tests/recipes/test_llama_recipes.py
+0 −457 tests/unit_tests/recipes/test_olmoe_recipes.py
+17 −216 tests/unit_tests/recipes/test_qwen_recipes.py
+7 −26 tests/unit_tests/training/test_checkpointing.py
+0 −233 tests/unit_tests/training/test_train.py
+0 −46 tests/unit_tests/training/utils/test_checkpoint_utils.py
+164 −104 uv.lock
74 changes: 72 additions & 2 deletions dfm/src/automodel/_diffusers/auto_diffusion_pipeline.py
@@ -12,18 +12,20 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+import copy
 import logging
 import os
 from typing import Any, Dict, Iterable, Optional, Tuple
 
 import torch
 import torch.nn as nn
-from Automodel.distributed.dfm_parallelizer import WanParallelizationStrategy
-from diffusers import DiffusionPipeline
+from diffusers import DiffusionPipeline, WanPipeline
 from nemo_automodel.components.distributed import parallelizer
 from nemo_automodel.components.distributed.fsdp2 import FSDP2Manager
 from nemo_automodel.shared.utils import dtype_from_str
 
+from dfm.src.automodel.distributed.dfm_parallelizer import WanParallelizationStrategy
+
 
 logger = logging.getLogger(__name__)
 
@@ -154,3 +156,71 @@ def from_pretrained(
             parallel_module = manager.parallelize(comp_module)
             setattr(pipe, comp_name, parallel_module)
         return pipe, created_managers
+
+
+class NeMoWanPipeline:
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+
+    @classmethod
+    def from_pretrained(cls, *args, **kwargs):
+        return NeMoAutoDiffusionPipeline.from_pretrained(*args, **kwargs)
+
+    @classmethod
+    def from_config(
+        cls,
+        model_id,
+        torch_dtype: torch.dtype = torch.bfloat16,
+        config: dict = None,
+        parallel_scheme: Optional[Dict[str, Dict[str, Any]]] = None,
+        device: Optional[torch.device] = None,
+        move_to_device: bool = True,
+        components_to_load: Optional[Iterable[str]] = None,
+    ):
+        # Load just the config
+        from diffusers import WanTransformer3DModel
+
+        if config is None:
+            transformer = WanTransformer3DModel.from_pretrained(
+                model_id,
+                subfolder="transformer",
+                torch_dtype=torch_dtype,
+            )
+
+            # Get config and reinitialize with random weights
+            config = copy.deepcopy(transformer.config)
+            del transformer
+
+        # Initialize with random weights
+        transformer = WanTransformer3DModel.from_config(config)
+
+        # Load pipeline with random transformer
+        pipe = WanPipeline.from_pretrained(
+            model_id,
+            transformer=transformer,
+            torch_dtype=torch_dtype,
+        )
+        # Decide device
+        dev = _choose_device(device)
+
+        # Move modules to device/dtype first (helps avoid initial OOM during sharding)
+        if move_to_device:
+            for name, module in _iter_pipeline_modules(pipe):
+                if not components_to_load or name in components_to_load:
+                    logger.info("[INFO] Moving module: %s to device/dtype", name)
+                    _move_module_to_device(module, dev, torch_dtype)
+
+        # Use per-component FSDP2Manager init-args to parallelize components
+        created_managers: Dict[str, FSDP2Manager] = {}
+        if parallel_scheme is not None:
+            assert torch.distributed.is_initialized(), "Expect distributed environment to be initialized"
+            _init_parallelizer()
+            for comp_name, comp_module in _iter_pipeline_modules(pipe):
+                manager_args = parallel_scheme.get(comp_name)
+                if manager_args is None:
+                    continue
+                manager = FSDP2Manager(**manager_args)
+                created_managers[comp_name] = manager
+                parallel_module = manager.parallelize(comp_module)
+                setattr(pipe, comp_name, parallel_module)
+        return pipe, created_managers
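
A minimal usage sketch of the two NeMoWanPipeline entry points introduced above: from_pretrained loads released weights (finetune), while from_config builds the same pipeline around a randomly initialized WanTransformer3DModel (pretrain). The model id and parallel sizes below are illustrative placeholders, not values taken from this diff.

import torch

from dfm.src.automodel._diffusers.auto_diffusion_pipeline import NeMoWanPipeline

model_id = "Wan-AI/Wan2.1-T2V-1.3B-Diffusers"  # hypothetical checkpoint id
parallel_scheme = {
    # FSDP2Manager init-args for the transformer component (illustrative sizes)
    "transformer": {"dp_size": 8, "tp_size": 1, "cp_size": 1, "pp_size": 1, "backend": "nccl", "world_size": 8},
}

# Finetune: start from pretrained weights.
pipe, managers = NeMoWanPipeline.from_pretrained(
    model_id,
    torch_dtype=torch.bfloat16,
    parallel_scheme=parallel_scheme,
    components_to_load=["transformer"],
    load_for_training=True,
)

# Pretrain: identical call shape, but the transformer is re-initialized from its config.
pipe, managers = NeMoWanPipeline.from_config(
    model_id,
    torch_dtype=torch.bfloat16,
    parallel_scheme=parallel_scheme,
    components_to_load=["transformer"],
)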
2 changes: 1 addition & 1 deletion dfm/src/automodel/datasets/__init__.py
@@ -12,7 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from Automodel.datasets.wan21 import (
+from dfm.src.automodel.datasets.wan21 import (
     MetaFilesDataset,
     build_node_parallel_sampler,
     build_wan21_dataloader,
32 changes: 22 additions & 10 deletions dfm/src/automodel/flow_matching/training_step_t2v.py
@@ -19,7 +19,8 @@
 from typing import Dict, Tuple
 
 import torch
-from Automodel.flow_matching.time_shift_utils import (
+
+from dfm.src.automodel.flow_matching.time_shift_utils import (
     compute_density_for_timestep_sampling,
 )
 
@@ -28,8 +29,8 @@
 
 
 def step_fsdp_transformer_t2v(
-    pipe,
-    model_map: Dict,
+    scheduler,
+    model,
     batch,
     device,
     bf16,
@@ -40,6 +41,8 @@ def step_fsdp_transformer_t2v(
     logit_std: float = 1.0,
     flow_shift: float = 3.0,
     mix_uniform_ratio: float = 0.1,
+    sigma_min: float = 0.0,  # Default: no clamping (pretrain)
+    sigma_max: float = 1.0,  # Default: no clamping (pretrain)
     global_step: int = 0,
 ) -> Tuple[torch.Tensor, Dict]:
     """
@@ -74,7 +77,7 @@ def step_fsdp_transformer_t2v(
     # Flow Matching Timestep Sampling
     # ========================================================================
 
-    num_train_timesteps = pipe.scheduler.config.num_train_timesteps
+    num_train_timesteps = scheduler.config.num_train_timesteps
 
     if use_sigma_noise:
         use_uniform = torch.rand(1).item() < mix_uniform_ratio
@@ -96,12 +99,23 @@ def step_fsdp_transformer_t2v(
             # Apply flow shift: σ = shift/(shift + (1/u - 1))
             u_clamped = torch.clamp(u, min=1e-5)  # Avoid division by zero
             sigma = flow_shift / (flow_shift + (1.0 / u_clamped - 1.0))
-            sigma = torch.clamp(sigma, 0.0, 1.0)
 
+            # Clamp sigma (only if not full range [0,1])
+            # Pretrain uses [0, 1], finetune uses [0.02, 0.55]
+            if sigma_min > 0.0 or sigma_max < 1.0:
+                sigma = torch.clamp(sigma, sigma_min, sigma_max)
+            else:
+                sigma = torch.clamp(sigma, 0.0, 1.0)
+
     else:
         # Simple uniform without shift
         u = torch.rand(size=(batch_size,), device=device)
-        sigma = u
+
+        # Clamp sigma (only if not full range [0,1])
+        if sigma_min > 0.0 or sigma_max < 1.0:
+            sigma = torch.clamp(u, sigma_min, sigma_max)
+        else:
+            sigma = u
         sampling_method = "uniform_no_shift"
 
     # ========================================================================
@@ -186,10 +200,8 @@ def step_fsdp_transformer_t2v(
     # Forward Pass
     # ========================================================================
 
-    fsdp_model = model_map["transformer"]["fsdp_transformer"]
-
    try:
-        model_pred = fsdp_model(
+        model_pred = model(
             hidden_states=noisy_latents,
             timestep=timesteps_for_model,
             encoder_hidden_states=text_embeddings,
@@ -243,7 +255,7 @@ def step_fsdp_transformer_t2v(
         logger.info(f"[STEP {global_step}] LOSS DEBUG")
         logger.info("=" * 80)
         logger.info("[TARGET] Flow matching: v = ε - x_0")
-        logger.info(f"[PREDICTION] Scheduler type (inference only): {type(pipe.scheduler).__name__}")
+        logger.info(f"[PREDICTION] Scheduler type (inference only): {type(scheduler).__name__}")
         logger.info("")
         logger.info(f"[RANGES] Model pred: [{model_pred.min():.4f}, {model_pred.max():.4f}]")
         logger.info(f"[RANGES] Target (v): [{target.min():.4f}, {target.max():.4f}]")
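
For intuition, a standalone sketch of the sigma schedule this file now implements: draw u, map it through the flow shift σ = shift/(shift + (1/u − 1)), then apply the new sigma_min/sigma_max bounds (pretrain keeps the full [0, 1] range; the comments above use [0.02, 0.55] for finetune). The values below are illustrative.

import torch

flow_shift, sigma_min, sigma_max = 3.0, 0.02, 0.55  # finetune-style clamping (illustrative)

u = torch.rand(4)             # the real code may instead sample u from a logit-normal
u = torch.clamp(u, min=1e-5)  # avoid division by zero
sigma = flow_shift / (flow_shift + (1.0 / u - 1.0))  # shift > 1 biases sigma toward 1 (more noise)
if sigma_min > 0.0 or sigma_max < 1.0:
    sigma = torch.clamp(sigma, sigma_min, sigma_max)

# Example: u = 0.5 gives sigma = 3 / (3 + 1) = 0.75, clamped to 0.55 with these bounds.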
@@ -22,10 +22,6 @@
 import torch
 import torch.distributed as dist
 import wandb
-from Automodel._diffusers.auto_diffusion_pipeline import NeMoAutoDiffusionPipeline
-from Automodel.flow_matching.training_step_t2v import (
-    step_fsdp_transformer_t2v,
-)
 from nemo_automodel.components.checkpoint.checkpointing import Checkpointer, CheckpointingConfig
 from nemo_automodel.components.loggers.log_utils import setup_logging
 from nemo_automodel.components.loggers.wandb_utils import suppress_wandb_log_messages
@@ -36,68 +32,71 @@
 from torch.distributed.fsdp import MixedPrecisionPolicy
 from transformers.utils.hub import TRANSFORMERS_CACHE
 
+from dfm.src.automodel._diffusers.auto_diffusion_pipeline import NeMoWanPipeline
+from dfm.src.automodel.flow_matching.training_step_t2v import (
+    step_fsdp_transformer_t2v,
+)
+
 
 def build_model_and_optimizer(
     *,
     model_id: str,
+    finetune_mode: bool,
     learning_rate: float,
     device: torch.device,
-    bf16_dtype: torch.dtype,
+    dtype: torch.dtype,
     cpu_offload: bool = False,
-    tp_size: int = 1,
-    cp_size: int = 1,
-    pp_size: int = 1,
-    dp_size: Optional[int] = None,
-    dp_replicate_size: Optional[int] = None,
-    use_hf_tp_plan: bool = False,
+    fsdp_cfg: Dict[str, Any] = {},
     optimizer_cfg: Optional[Dict[str, Any]] = None,
-) -> tuple[NeMoAutoDiffusionPipeline, dict[str, Dict[str, Any]], torch.optim.Optimizer, Any]:
+) -> tuple[NeMoWanPipeline, torch.optim.Optimizer, Any]:
     """Build the WAN 2.1 diffusion model, parallel scheme, and optimizer."""
 
-    logging.info("[INFO] Building NeMoAutoDiffusionPipeline with transformer parallel scheme...")
+    logging.info("[INFO] Building NeMoWanPipeline with transformer parallel scheme...")
 
     if not dist.is_initialized():
         logging.info("[WARN] torch.distributed not initialized; proceeding in single-process mode")
 
     world_size = dist.get_world_size() if dist.is_initialized() else 1
 
-    if dp_size is None:
-        denom = max(1, tp_size * cp_size * pp_size)
-        dp_size = max(1, world_size // denom)
+    if fsdp_cfg.get("dp_size", None) is None:
+        denom = max(1, fsdp_cfg.get("tp_size", 1) * fsdp_cfg.get("cp_size", 1) * fsdp_cfg.get("pp_size", 1))
+        fsdp_cfg["dp_size"] = max(1, world_size // denom)
 
     manager_args: Dict[str, Any] = {
-        "dp_size": dp_size,
-        "dp_replicate_size": dp_replicate_size,
-        "tp_size": tp_size,
-        "cp_size": cp_size,
-        "pp_size": pp_size,
+        "dp_size": fsdp_cfg.get("dp_size", None),
+        "dp_replicate_size": fsdp_cfg.get("dp_replicate_size", None),
+        "tp_size": fsdp_cfg.get("tp_size", 1),
+        "cp_size": fsdp_cfg.get("cp_size", 1),
+        "pp_size": fsdp_cfg.get("pp_size", 1),
         "backend": "nccl",
         "world_size": world_size,
-        "use_hf_tp_plan": use_hf_tp_plan,
+        "use_hf_tp_plan": fsdp_cfg.get("use_hf_tp_plan", False),
         "activation_checkpointing": True,
         "mp_policy": MixedPrecisionPolicy(
-            param_dtype=bf16_dtype,
-            reduce_dtype=bf16_dtype,
-            output_dtype=bf16_dtype,
+            param_dtype=dtype,
+            reduce_dtype=dtype,
+            output_dtype=dtype,
         ),
     }
 
     parallel_scheme = {"transformer": manager_args}
 
-    pipe, created_managers = NeMoAutoDiffusionPipeline.from_pretrained(
+    kwargs = {}
+    if finetune_mode:
+        kwargs["load_for_training"] = True
+        kwargs["low_cpu_mem_usage"] = True
+    init_fn = NeMoWanPipeline.from_pretrained if finetune_mode else NeMoWanPipeline.from_config
+
+    pipe, created_managers = init_fn(
         model_id,
-        torch_dtype=bf16_dtype,
+        torch_dtype=dtype,
         device=device,
         parallel_scheme=parallel_scheme,
-        load_for_training=True,
         components_to_load=["transformer"],
+        **kwargs,
     )
     fsdp2_manager = created_managers["transformer"]
-    transformer_module = getattr(pipe, "transformer", None)
-    if transformer_module is None:
-        raise RuntimeError("transformer not found in pipeline after parallelization")
-
-    model_map: dict[str, Dict[str, Any]] = {"transformer": {"fsdp_transformer": transformer_module}}
+    transformer_module = pipe.transformer
 
     trainable_params = [p for p in transformer_module.parameters() if p.requires_grad]
     if not trainable_params:
Expand All @@ -121,7 +120,7 @@ def build_model_and_optimizer(

logging.info("[INFO] NeMoAutoDiffusion setup complete (pipeline + optimizer)")

return pipe, model_map, optimizer, fsdp2_manager.device_mesh
return pipe, optimizer, getattr(fsdp2_manager, "device_mesh", None)


def build_lr_scheduler(
@@ -198,36 +197,27 @@ def setup(self):
         self.logit_std = fm_cfg.get("logit_std", 1.0)
         self.flow_shift = fm_cfg.get("flow_shift", 3.0)
         self.mix_uniform_ratio = fm_cfg.get("mix_uniform_ratio", 0.1)
+        self.sigma_min = fm_cfg.get("sigma_min", 0.0)
+        self.sigma_max = fm_cfg.get("sigma_max", 1.0)
 
         logging.info(f"[INFO] Flow matching: {'ENABLED' if self.use_sigma_noise else 'DISABLED'}")
         if self.use_sigma_noise:
             logging.info(f"[INFO] - Timestep sampling: {self.timestep_sampling}")
             logging.info(f"[INFO] - Flow shift: {self.flow_shift}")
             logging.info(f"[INFO] - Mix uniform ratio: {self.mix_uniform_ratio}")
 
-        tp_size = fsdp_cfg.get("tp_size", 1)
-        cp_size = fsdp_cfg.get("cp_size", 1)
-        pp_size = fsdp_cfg.get("pp_size", 1)
-        dp_size = fsdp_cfg.get("dp_size", None)
-        dp_replicate_size = fsdp_cfg.get("dp_replicate_size", None)
-        use_hf_tp_plan = fsdp_cfg.get("use_hf_tp_plan", False)
-
-        (self.pipe, self.model_map, self.optimizer, self.device_mesh) = build_model_and_optimizer(
+        (self.pipe, self.optimizer, self.device_mesh) = build_model_and_optimizer(
             model_id=self.model_id,
+            finetune_mode=self.cfg.get("model.mode", "finetune").lower() == "finetune",
             learning_rate=self.learning_rate,
             device=self.device,
-            bf16_dtype=self.bf16,
+            dtype=self.bf16,
             cpu_offload=self.cpu_offload,
-            tp_size=tp_size,
-            cp_size=cp_size,
-            pp_size=pp_size,
-            dp_size=dp_size,
-            dp_replicate_size=dp_replicate_size,
-            use_hf_tp_plan=use_hf_tp_plan,
+            fsdp_cfg=fsdp_cfg,
             optimizer_cfg=self.cfg.get("optim.optimizer", {}),
         )
 
-        self.model = self.model_map["transformer"]["fsdp_transformer"]
+        self.model = self.pipe.transformer
         self.peft_config = None
 
         batch_cfg = self.cfg.get("batch", {})
@@ -283,6 +273,9 @@ def setup(self):
             raise RuntimeError("Training dataloader is empty; cannot proceed with training")
 
         # Derive DP size consistent with model parallel config
+        tp_size = fsdp_cfg.get("tp_size", 1)
+        cp_size = fsdp_cfg.get("cp_size", 1)
+        pp_size = fsdp_cfg.get("pp_size", 1)
         denom = max(1, tp_size * cp_size * pp_size)
         self.dp_size = fsdp_cfg.get("dp_size", None)
         if self.dp_size is None:
@@ -356,8 +349,8 @@ def run_train_validation_loop(self):
                 for micro_batch in batch_group:
                     try:
                         loss, _ = step_fsdp_transformer_t2v(
-                            pipe=self.pipe,
-                            model_map=self.model_map,
+                            scheduler=self.pipe.scheduler,
+                            model=self.model,
                             batch=micro_batch,
                             device=self.device,
                             bf16=self.bf16,
@@ -367,6 +360,8 @@ def run_train_validation_loop(self):
                             logit_std=self.logit_std,
                             flow_shift=self.flow_shift,
                             mix_uniform_ratio=self.mix_uniform_ratio,
+                            sigma_min=self.sigma_min,
+                            sigma_max=self.sigma_max,
                             global_step=global_step,
                         )
                     except Exception as exc:
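
The refactor in this file collapses the individual tp/cp/pp/dp keyword arguments into one fsdp_cfg mapping. A sketch of that mapping as the new code reads it (every key is fetched via fsdp_cfg.get), with placeholder sizes; dp_size is derived as world_size // (tp_size * cp_size * pp_size) when left unset:

fsdp_cfg = {
    "tp_size": 2,                # tensor parallel (placeholder size)
    "cp_size": 1,                # context parallel
    "pp_size": 1,                # pipeline parallel
    "dp_size": None,             # filled in by build_model_and_optimizer
    "dp_replicate_size": None,
    "use_hf_tp_plan": False,
}

# New flow-matching knobs read in setup(); sigma bounds default to the full [0, 1] range.
fm_cfg = {
    "flow_shift": 3.0,
    "mix_uniform_ratio": 0.1,
    "sigma_min": 0.02,  # finetune-style bounds, per the comments in training_step_t2v.py
    "sigma_max": 0.55,
}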
@@ -14,9 +14,10 @@
 
 from __future__ import annotations
 
-from Automodel.recipes.finetune import TrainWan21DiffusionRecipe
 from nemo_automodel.components.config._arg_parser import parse_args_and_load_config
 
+from dfm.src.automodel.recipes.train import TrainWan21DiffusionRecipe
+
 
 def main(default_config_path="/opt/DFM/dfm/examples/Automodel/finetune/wan2_1_t2v_flow.yaml"):
     cfg = parse_args_and_load_config(default_config_path)