Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
27 commits
Select commit Hold shift + click to select a range
dae57f9
adding tests
Nov 14, 2025
7f2ed10
Merge remote-tracking branch 'origin/main' into huvu/mcore_wan_unit_t…
Nov 14, 2025
93f94bc
ruff lint
Nov 14, 2025
244ee74
ruff lint
Nov 14, 2025
550f283
ruff lint
Nov 14, 2025
926a951
Explicit mcore path override to use Megatron-Bridge's pinned submodul…
pablo-garay Nov 14, 2025
ecdef9e
Update Megatron-Bridge submodule to latest main with correct Megatron…
pablo-garay Nov 14, 2025
5038518
Add Mcore WAN pretrain mock test to CI/CD
pablo-garay Nov 14, 2025
c746d18
lintfix
pablo-garay Nov 14, 2025
697201d
Fix slow Docker build from Megatron-LM source
pablo-garay Nov 15, 2025
013ca6d
ci: Update gpu runners to use self-hosted-nemo (#48)
chtruong814 Nov 16, 2025
f240ccd
Reapply "Revert GHA changes"
pablo-garay Nov 16, 2025
0964c62
update path per request
pablo-garay Nov 16, 2025
d08b5af
lintfix
pablo-garay Nov 16, 2025
eebe731
update CONTRIBUTING.md
pablo-garay Nov 16, 2025
3dadf02
Merge branch 'main' into pablo-garay/mbridge-test-init
pablo-garay Nov 16, 2025
6685a54
lintfix
pablo-garay Nov 16, 2025
04d802e
Merge branch 'pablo-garay/mbridge-test-init' of https://github.com/NV…
pablo-garay Nov 16, 2025
1b8c2d1
Merge remote-tracking branch 'origin/main' into huvu/mcore_wan_unit_t…
Nov 16, 2025
77e37cd
Merge remote-tracking branch 'origin/pablo-garay/mbridge-test-init' i…
Nov 16, 2025
a5a109a
adding v run --group megatron-bridge
Nov 16, 2025
f43aea3
update test
Nov 16, 2025
d2d983f
ruff lint
Nov 16, 2025
23909e8
update main
Nov 17, 2025
166e809
restore Dockerfile.ci
Nov 17, 2025
c1bde61
update .github/workflows/cicd-main.yml
Nov 17, 2025
2de3124
Merge branch 'main' into huvu/mcore_wan_unit_tests
huvunvidia Nov 18, 2025
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 0 additions & 3 deletions .github/workflows/cicd-main.yml
Original file line number Diff line number Diff line change
Expand Up @@ -91,9 +91,6 @@ jobs:
fail-fast: false
matrix:
include:
- script: L2_Functional_Tests_GPU
runner: self-hosted-nemo
timeout: 30
- script: L2_Mcore_Mock_Tests_GPU
runner: self-hosted-nemo
timeout: 30
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -229,13 +229,13 @@ def forward_pp_step(
"""

pp_world_size = parallel_state.get_pipeline_model_parallel_world_size()
is_pp_first = parallel_state.is_pipeline_first_stage(ignore_virtual=True)
is_pp_last = parallel_state.is_pipeline_last_stage(ignore_virtual=True)

# PP=1: no pipeline parallelism
# PP=1: no pipeline parallelism (avoid touching PP groups which may be uninitialized in unit tests)
if pp_world_size == 1:
noise_pred_pp = self.model(latent_model_input, grid_sizes=grid_sizes, t=timestep, **arg_c)
return noise_pred_pp
# For PP>1, safe to query stage information
is_pp_first = parallel_state.is_pipeline_first_stage(ignore_virtual=True)
is_pp_last = parallel_state.is_pipeline_last_stage(ignore_virtual=True)

# PP>1: pipeline parallelism
hidden_size = self.model.config.hidden_size
Expand Down
2 changes: 1 addition & 1 deletion tests/unit_tests/L0_Unit_Tests_CPU.sh
Original file line number Diff line number Diff line change
Expand Up @@ -14,4 +14,4 @@

# Hide GPU from PyTorch by setting CUDA_VISIBLE_DEVICES to empty
# This makes torch.cuda.is_available() return False
CUDA_VISIBLE_DEVICES="" uv run coverage run -a --data-file=/opt/DFM/.coverage --source=/opt/DFM/ -m pytest tests/unit_tests -m "not pleasefixme" --with_downloads
CUDA_VISIBLE_DEVICES="" uv run --group megatron-bridge coverage run -a --data-file=/opt/DFM/.coverage --source=/opt/DFM/ -m pytest tests/unit_tests -m "not pleasefixme" --with_downloads
2 changes: 1 addition & 1 deletion tests/unit_tests/L0_Unit_Tests_GPU.sh
Original file line number Diff line number Diff line change
Expand Up @@ -11,4 +11,4 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
CUDA_VISIBLE_DEVICES="0,1" uv run coverage run -a --data-file=/opt/DFM/.coverage --source=/opt/DFM/ -m pytest tests/unit_tests -m "not pleasefixme" --with_downloads
CUDA_VISIBLE_DEVICES="0,1" uv run --group megatron-bridge coverage run -a --data-file=/opt/DFM/.coverage --source=/opt/DFM/ -m pytest tests/unit_tests -m "not pleasefixme" --with_downloads
67 changes: 67 additions & 0 deletions tests/unit_tests/megatron/data/wan/test_wan_energon_datamodule.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,67 @@
# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from dfm.src.megatron.data.wan import wan_energon_datamodule as wan_dm_mod
from dfm.src.megatron.data.wan.wan_taskencoder import WanTaskEncoder


class _FakeDiffusionDataModule:
    """Minimal stand-in for DiffusionDataModule used to verify WanDataModuleConfig wiring."""

    def __init__(
        self,
        *,
        path: str,
        seq_length: int,
        packing_buffer_size: int,
        task_encoder,
        micro_batch_size: int,
        global_batch_size: int,
        num_workers: int,
    ):
        # Record every constructor argument verbatim so the test can inspect
        # exactly what the config propagated into the dataset object.
        for attr_name, attr_value in (
            ("path", path),
            ("seq_length", seq_length),
            ("packing_buffer_size", packing_buffer_size),
            ("task_encoder", task_encoder),
            ("micro_batch_size", micro_batch_size),
            ("global_batch_size", global_batch_size),
            ("num_workers", num_workers),
        ):
            setattr(self, attr_name, attr_value)

    # mimic API used by WanDataModuleConfig.build_datasets
    def train_dataloader(self):
        # Sentinel value: build_datasets is expected to hand this back unchanged.
        return "train"


def test_wan_datamodule_config_initialization(monkeypatch):
    """WanDataModuleConfig.__post_init__ builds a dataset with a WanTaskEncoder and propagates sizes."""
    # Replace the DiffusionDataModule symbol *inside the module under test* so
    # no real energon dataset is constructed.
    monkeypatch.setattr(wan_dm_mod, "DiffusionDataModule", _FakeDiffusionDataModule)

    config = wan_dm_mod.WanDataModuleConfig(
        path="",
        seq_length=128,
        task_encoder_seq_length=128,
        packing_buffer_size=4,
        micro_batch_size=2,
        global_batch_size=8,
        num_workers=0,
    )

    # __post_init__ must have constructed the (fake) dataset with a
    # WanTaskEncoder and carried the sequence length through.
    dataset = config.dataset
    assert isinstance(dataset, _FakeDiffusionDataModule)
    assert config.sequence_length == 128
    assert dataset.seq_length == 128
    encoder = dataset.task_encoder
    assert isinstance(encoder, WanTaskEncoder)
    assert encoder.seq_length == 128
    assert encoder.packing_buffer_size == 4

    # build_datasets reuses the train loader for all three splits.
    train_dl, val_dl, test_dl = config.build_datasets(context=None)
    assert train_dl == "train"
    assert val_dl == "train"
    assert test_dl == "train"
65 changes: 65 additions & 0 deletions tests/unit_tests/megatron/data/wan/test_wan_mock_datamodule.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,65 @@
# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import torch
from torch.utils.data import DataLoader

from dfm.src.megatron.data.wan.wan_mock_datamodule import WanMockDataModuleConfig


def test_wan_mock_datamodule_build_and_batch_shapes():
    """The mock WAN datamodule yields one shared DataLoader whose batches carry the expected fields."""
    config = WanMockDataModuleConfig(
        path="",
        seq_length=128,
        packing_buffer_size=2,
        micro_batch_size=2,
        global_batch_size=8,
        num_workers=0,
        # Deliberately tiny latent/context dimensions to keep the test light-weight
        F_latents=4,
        H_latents=8,
        W_latents=6,
        patch_spatial=2,
        patch_temporal=1,
        number_packed_samples=2,
        context_seq_len=16,
        context_embeddings_dim=64,
    )

    train_loader, val_loader, test_loader = config.build_datasets(_context=None)
    assert isinstance(train_loader, DataLoader)
    # The mock config hands back a single loader for train/val/test alike.
    assert train_loader is val_loader
    assert val_loader is test_loader

    sample_batch = next(iter(train_loader))
    required_keys = (
        "video_latents",
        "context_embeddings",
        "loss_mask",
        "seq_len_q",
        "seq_len_q_padded",
        "seq_len_kv",
        "seq_len_kv_padded",
        "grid_sizes",
        "video_metadata",
    )
    for key in required_keys:
        assert key in sample_batch

    # Basic sanity checks on shapes/dtypes: tensors are [seq, batch(=1), ...]
    # and all sequence-length / grid metadata is int32.
    for tensor_name in ("video_latents", "context_embeddings"):
        assert sample_batch[tensor_name].dim() == 3
        assert sample_batch[tensor_name].shape[1] == 1
    assert sample_batch["loss_mask"].dim() == 2
    assert sample_batch["loss_mask"].shape[1] == 1
    for tensor_name in ("seq_len_q", "seq_len_q_padded", "seq_len_kv", "seq_len_kv_padded", "grid_sizes"):
        assert sample_batch[tensor_name].dtype == torch.int32
154 changes: 154 additions & 0 deletions tests/unit_tests/megatron/data/wan/test_wan_taskencoder.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,154 @@
# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import torch

from dfm.src.megatron.data.wan.wan_taskencoder import WanTaskEncoder, cook, parallel_state


def test_cook_extracts_expected_fields():
    """cook() carries through json/pth/pickle payloads plus the energon bookkeeping keys."""
    pth_tensor = torch.randn(1, 2, 2, 2)
    pickle_tensor = torch.randn(3, 4)
    raw_sample = {
        "__key__": "k",
        "__restore_key__": "rk",
        "__subflavors__": [],
        "json": {"meta": 1},
        "pth": pth_tensor,
        "pickle": pickle_tensor,
        "unused": 123,
    }

    cooked = cook(raw_sample)

    # Payload fields are preserved (json by identity, tensors by value).
    assert cooked["json"] is raw_sample["json"]
    assert torch.equal(cooked["pth"], pth_tensor)
    assert torch.equal(cooked["pickle"], pickle_tensor)
    # Bookkeeping keys from the sample survive via basic_sample_keys().
    for meta_key in ("__key__", "__restore_key__", "__subflavors__"):
        assert cooked[meta_key] == raw_sample[meta_key]


def test_encode_sample_no_context_parallel(monkeypatch):
    """encode_sample patchifies the latent video and emits consistent seq-length metadata (CP=1)."""
    # Pin context-parallel world size to 1 so the extra CP padding branch is skipped.
    monkeypatch.setattr(parallel_state, "get_context_parallel_world_size", lambda: 1, raising=False)

    # The seeded task-encoder wrapper needs an active worker config; install a stub.
    from megatron.energon.task_encoder.base import WorkerConfig

    class _StubWorkerConfig:
        active_worker_sample_index = 0

        def worker_seed(self):
            return 123

    monkeypatch.setattr(WorkerConfig, "active_worker_config", _StubWorkerConfig(), raising=False)

    # Build a minimal, self-consistent sample. The latent video has shape
    # [channels, F, H, W]; grid sizes (patch counts) are the latent dims
    # divided by the temporal/spatial patch sizes.
    channels = 8
    frames, height, width = 4, 8, 6
    p_temporal, p_spatial = 1, 2
    latent_video = torch.randn(channels, frames, height, width)
    sample = {
        "__key__": "k",
        "__restore_key__": "rk",
        "__subflavors__": [],
        "json": {"meta": 1},
        "pth": latent_video,
        "pickle": torch.randn(256, 64),
    }

    encoder = WanTaskEncoder(
        seq_length=1024, patch_temporal=p_temporal, patch_spatial=p_spatial, packing_buffer_size=None
    )
    encoded = encoder.encode_sample(sample)

    # Expected patch counts along each axis and flattened patch vector size.
    grid_f = frames // p_temporal
    grid_h = height // p_spatial
    grid_w = width // p_spatial
    total_patches = grid_f * grid_h * grid_w
    patch_dim = channels * p_temporal * p_spatial * p_spatial

    assert encoded.video.shape == (total_patches, patch_dim)
    assert encoded.latent_shape.dtype == torch.int32
    assert torch.equal(encoded.latent_shape, torch.tensor([grid_f, grid_h, grid_w], dtype=torch.int32))

    # Loss mask covers every patch; sequence lengths are per-sample int32 tensors.
    assert encoded.loss_mask.dtype == torch.bfloat16
    assert encoded.loss_mask.shape[0] == total_patches
    assert torch.equal(encoded.seq_len_q, torch.tensor([total_patches], dtype=torch.int32))
    # Context embeddings are padded to a fixed length of 512 inside encode_sample.
    assert torch.equal(encoded.seq_len_kv, torch.tensor([512], dtype=torch.int32))
    assert torch.equal(encoded.seq_len_q_padded, encoded.seq_len_q)
    assert torch.equal(encoded.seq_len_kv_padded, encoded.seq_len_kv)

    # Metadata and energon bookkeeping fields pass straight through.
    assert encoded.video_metadata == sample["json"]
    assert encoded.__key__ == sample["__key__"]
    assert encoded.__restore_key__ == sample["__restore_key__"]
    assert encoded.__subflavors__ == sample["__subflavors__"]


def test_batch_with_packing_buffer_size(monkeypatch):
    """With a packing buffer configured, batch() packs encoded samples into [S, 1, ...] tensors."""
    # Keep context-parallel world size at 1 so sequence lengths match the raw sample.
    monkeypatch.setattr(parallel_state, "get_context_parallel_world_size", lambda: 1, raising=False)

    # The seeded wrapper requires an active worker config; install a stub.
    from megatron.energon.task_encoder.base import WorkerConfig

    class _StubWorkerConfig:
        active_worker_sample_index = 0

        def worker_seed(self):
            return 456

    monkeypatch.setattr(WorkerConfig, "active_worker_config", _StubWorkerConfig(), raising=False)

    channels = 4
    grid_f, grid_h, grid_w = 2, 4, 4
    p_temporal, p_spatial = 1, 2
    # Latent video sized so the patch grid comes out to (grid_f, grid_h, grid_w).
    latent_video = torch.randn(channels, grid_f * p_temporal, grid_h * p_spatial, grid_w * p_spatial)
    sample = {
        "__key__": "k",
        "__restore_key__": "rk",
        "__subflavors__": [],
        "json": {"meta": 1},
        "pth": latent_video,
        "pickle": torch.randn(32, 128),
    }

    encoder = WanTaskEncoder(
        seq_length=256, patch_temporal=p_temporal, patch_spatial=p_spatial, packing_buffer_size=3
    )
    packed = encoder.batch([encoder.encode_sample(sample)])

    assert isinstance(packed, dict)
    required_keys = (
        "video_latents",
        "context_embeddings",
        "loss_mask",
        "seq_len_q",
        "seq_len_q_padded",
        "seq_len_kv",
        "seq_len_kv_padded",
        "grid_sizes",
        "video_metadata",
    )
    for key in required_keys:
        assert key in packed

    # Packed tensors carry a singleton batch dimension at axis 1; with CP world
    # size 1 the sequence axis equals the sample's own sequence length.
    assert packed["video_latents"].shape[1] == 1
    assert packed["context_embeddings"].shape[1] == 1
    assert packed["loss_mask"].shape[1] == 1
Loading