@@ -24,11 +24,11 @@ MODEL_PATH=/home/TestData/HF_HOME/hub/models--meta-llama--Llama-3.2-1B/snapshots
 # override with a smaller model meta-llama/Llama-3.2-1B for testing
 TRANSFORMERS_OFFLINE=1 python -m torch.distributed.run --nproc_per_node=2 --nnodes=1 -m coverage run \
 nemo_automodel/recipes/llm/benchmark.py \
-  --config examples/llm_finetune/llama3_3/custom_llama3_3_70b_instruct_peft_benchmark.yaml \
+  --config examples/llm_benchmark/llama3_3/custom_llama3_3_70b_instruct_peft_benchmark.yaml \
   --model.pretrained_model_name_or_path=${MODEL_PATH} \
   --model.num_hidden_layers=2 \
   --distributed.tp_size=2 \
   --distributed.pp_size=1 \
   --distributed_config.sequence_parallel=False \
   --benchmark.warmup_steps=2 \
-  --step_scheduler.max_steps=4
+  --step_scheduler.max_steps=4
@@ -24,7 +24,7 @@ MODEL_PATH=/home/TestData/HF_HOME/hub/models--Qwen--Qwen2.5-1.5B/snapshots/8faed
 # override with a smaller model Qwen/Qwen2.5-1.5B for testing
 TRANSFORMERS_OFFLINE=1 python -m torch.distributed.run --master-port=29513 --nproc_per_node=2 --nnodes=1 -m coverage run \
 nemo_automodel/recipes/llm/benchmark.py \
-  --config examples/llm_finetune/qwen/custom_qwen2_5_32b_peft_benchmark.yaml \
+  --config examples/llm_benchmark/qwen/custom_qwen2_5_32b_peft_benchmark.yaml \
   --model.pretrained_model_name_or_path=${MODEL_PATH} \
   --model.num_hidden_layers=2 \
   --distributed.tp_size=2 \
@@ -34,4 +34,4 @@ nemo_automodel/recipes/llm/benchmark.py \
   --step_scheduler.max_steps=4 \
   --step_scheduler.global_batch_size=2 \
   --step_scheduler.local_batch_size=1 \
-  --dataset.seq_len=256
+  --dataset.seq_len=256
12 changes: 7 additions & 5 deletions tests/unit_tests/distributed/test_mesh_utils.py
@@ -361,7 +361,7 @@ def test_no_replicate_no_cp_returns_dp_shard_slice(self):

         result = get_fsdp_dp_mesh(mesh)

-        mesh.__getitem__.assert_called_once_with("dp_shard")
+        mesh.__getitem__.assert_any_call("dp_shard")
         assert result._key == "dp_shard"
         # Returned mesh is a slice of the original, not a freshly built one.
         assert result._mesh is mesh
@@ -376,7 +376,7 @@ def test_replicate_only_returns_replicate_shard_tuple_slice(self):

         result = get_fsdp_dp_mesh(mesh)

-        mesh.__getitem__.assert_called_once_with(("dp_replicate", "dp_shard"))
+        mesh.__getitem__.assert_any_call(("dp_replicate", "dp_shard"))
         assert result._key == ("dp_replicate", "dp_shard")
         assert result._mesh is mesh

@@ -390,7 +390,7 @@ def test_cp_only_returns_shard_cp_tuple_slice(self):

         result = get_fsdp_dp_mesh(mesh)

-        mesh.__getitem__.assert_called_once_with(("dp_shard", "cp"))
+        mesh.__getitem__.assert_any_call(("dp_shard", "cp"))
         assert result._key == ("dp_shard", "cp")
         assert result._mesh is mesh

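Note: the three hunks above all make the same relaxation. get_fsdp_dp_mesh may now call __getitem__ more than once (for example, to probe dimension sizes before slicing), so the tests assert that the expected slice call happened at all, rather than that it was the only call. A minimal standalone sketch of the difference, using plain unittest.mock (the names here are illustrative, not from the repo):

from unittest.mock import MagicMock

mesh = MagicMock()
mesh["dp_shard"].size()  # size probe: first __getitem__ call
_ = mesh["dp_shard"]     # the slice under test: second call

# assert_called_once_with fails the moment there is more than one call:
#   mesh.__getitem__.assert_called_once_with("dp_shard")  -> AssertionError
# assert_any_call passes as long as at least one call matched:
mesh.__getitem__.assert_any_call("dp_shard")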
@@ -411,8 +411,10 @@ def test_replicate_and_cp_falls_back_to_get_submesh(self):

         mock_get_submesh.assert_called_once_with(mesh, ("dp_replicate", "dp_shard_cp"))
         assert result is submesh_sentinel
-        # __getitem__ must NOT have been called directly.
-        mesh.__getitem__.assert_not_called()
+        # The returned mesh must come from get_submesh, not from a direct
+        # __getitem__ slice. Size probes via __getitem__ are allowed.
+        direct_slice_calls = [c for c in mesh.__getitem__.call_args_list if isinstance(c.args[0], tuple)]
+        assert direct_slice_calls == []

     # ------------------------------------------------------------------
     # Branch 5 – native dims not available → get_submesh fallback
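The replacement assertion still has to rule out a *direct* tuple slice while tolerating size probes, so it filters __getitem__.call_args_list by argument type; each entry is a call object whose positional arguments are exposed as .args. A self-contained sketch of the pattern (mock names are illustrative):

from unittest.mock import MagicMock

mesh = MagicMock()
_ = mesh["dp_shard"]  # a probe with a plain str key: allowed

# Keep only __getitem__ calls whose first positional argument is a tuple,
# i.e. direct multi-dim slices like mesh[("dp_replicate", "dp_shard_cp")].
direct_slice_calls = [
    c for c in mesh.__getitem__.call_args_list if isinstance(c.args[0], tuple)
]
assert direct_slice_calls == []  # no direct tuple slice happened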
@@ -583,12 +583,12 @@ def mesh_tp2(self):
         return mesh, dp_mesh, tp_mesh

     def _mock_env(self, monkeypatch, dp_mesh_sentinel=None):
-        # Mock get_submesh to return the known dp_mesh sentinel from the fixture,
+        # Mock get_fsdp_dp_mesh to return the known dp_mesh sentinel from the fixture,
         # so we can assert the correct mesh is forwarded to apply_fsdp.
         if dp_mesh_sentinel is not None:
             monkeypatch.setattr(
-                "nemo_automodel.components.distributed.parallelizer.get_submesh",
-                lambda mesh, names: dp_mesh_sentinel,
+                "nemo_automodel.components.distributed.parallelizer.get_fsdp_dp_mesh",
+                lambda mesh, *a, **kw: dp_mesh_sentinel,
             )

         fully_shard_mock = MagicMock(side_effect=lambda model, **kwargs: model)
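Widening the stub from lambda mesh, names to lambda mesh, *a, **kw keeps the patch working if get_fsdp_dp_mesh later grows extra parameters. A minimal standalone sketch of the pattern (the module and call sites below are placeholders, not the repo's):

import types

SENTINEL = object()

# Stand-in for the module being patched.
fake_mod = types.SimpleNamespace(
    get_fsdp_dp_mesh=lambda mesh, hsdp=False: ("real", mesh),
)

# A signature-tolerant stub: extra positional/keyword args are absorbed,
# so callers can evolve without breaking the test.
fake_mod.get_fsdp_dp_mesh = lambda mesh, *a, **kw: SENTINEL

assert fake_mod.get_fsdp_dp_mesh("mesh") is SENTINEL
assert fake_mod.get_fsdp_dp_mesh("mesh", True, cp_size=2) is SENTINEL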
28 changes: 15 additions & 13 deletions tests/unit_tests/distributed/test_parallelizer.py
@@ -727,6 +727,7 @@ def test_apply_fsdp_sharding_module_list(
         self, mock_fully_shard, mock_module_list, mock_mesh, mock_mp_policy, mock_offload_policy
     ):
         """Test apply_fsdp2_sharding_recursively with a ModuleList."""
+
         # Set up mock return values - add FSDP2 prefetch methods that fully_shard normally provides
         def mock_shard(x, **kwargs):
             x.set_modules_to_forward_prefetch = MagicMock()
@@ -761,6 +762,7 @@ def test_apply_fsdp_sharding_module_list_without_offload_policy(
         self, mock_fully_shard, mock_module_list, mock_mesh, mock_mp_policy
     ):
         """Test apply_fsdp2_sharding_recursively with a ModuleList and no offload policy."""
+
         # Set up mock return values - add FSDP2 prefetch methods that fully_shard normally provides
         def mock_shard(x, **kwargs):
             x.set_modules_to_forward_prefetch = MagicMock()
@@ -969,11 +971,15 @@ class via type(...) that preserves __module__ and __qualname__ from the original

         # Simulate _get_mixin_wrapped_class: create a *new* class object that copies
         # __module__ and __qualname__ from the original (same qualname, different object)
-        WrappedCls = type(original_cls.__name__, (nn.Module,), {
-            "forward": lambda self, x: x,
-            "__module__": original_cls.__module__,
-            "__qualname__": original_cls.__qualname__,
-        })
+        WrappedCls = type(
+            original_cls.__name__,
+            (nn.Module,),
+            {
+                "forward": lambda self, x: x,
+                "__module__": original_cls.__module__,
+                "__qualname__": original_cls.__qualname__,
+            },
+        )
         assert WrappedCls is not original_cls
         assert _get_class_qualname(WrappedCls) == _get_class_qualname(original_cls)

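For context: the three-argument form type(name, bases, namespace) builds a brand-new class object, and copying __module__ and __qualname__ into the namespace is what keeps qualname-based identity checks working across the wrap. A standalone sketch, with a hypothetical _get_class_qualname approximating the helper the test exercises:

class Original:
    def forward(self, x):
        return x

def _get_class_qualname(cls):
    # Hypothetical stand-in: identify a class by where it is defined,
    # not by object identity.
    return f"{cls.__module__}.{cls.__qualname__}"

Wrapped = type(
    Original.__name__,
    (Original,),
    {"__module__": Original.__module__, "__qualname__": Original.__qualname__},
)

assert Wrapped is not Original  # different class objects
assert _get_class_qualname(Wrapped) == _get_class_qualname(Original)  # same identity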
@@ -1276,8 +1282,8 @@ def forward(self, x):
             lambda *a, **kw: None,
         )
         monkeypatch.setattr(
-            "nemo_automodel.components.distributed.parallelizer.get_submesh",
-            lambda mesh, names: MagicMock(),
+            "nemo_automodel.components.distributed.parallelizer.get_fsdp_dp_mesh",
+            lambda mesh, *a, **kw: MagicMock(),
         )

     def _run_parallelize(self, model, activation_checkpointing=True):
@@ -1313,17 +1319,13 @@ def test_use_cache_disabled_without_kv_sharing(self):

     def test_use_cache_preserved_flat_config(self):
         """KV-sharing detected through a flat config (no text_config nesting)."""
-        model = _make_model_for_ac(
-            use_cache=True, num_kv_shared_layers=10, text_config_nested=False
-        )
+        model = _make_model_for_ac(use_cache=True, num_kv_shared_layers=10, text_config_nested=False)
         self._run_parallelize(model)
         assert model.config.use_cache is True

     def test_use_cache_disabled_flat_config_no_sharing(self):
         """Flat config without KV sharing still disables cache."""
-        model = _make_model_for_ac(
-            use_cache=True, num_kv_shared_layers=0, text_config_nested=False
-        )
+        model = _make_model_for_ac(use_cache=True, num_kv_shared_layers=0, text_config_nested=False)
         self._run_parallelize(model)
         assert model.config.use_cache is False

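Taken together, these tests pin down one rule: use_cache survives parallelization only when the model shares KV caches across layers, whether that flag lives on a nested text_config or on the flat config. A minimal sketch of the rule the tests encode (the attribute names follow the tests; the helper itself is hypothetical, not the repo's implementation):

def should_preserve_use_cache(config) -> bool:
    # Look through an optional text_config nesting, then keep the cache
    # only if some layers share KV.
    text_cfg = getattr(config, "text_config", None) or config
    return getattr(text_cfg, "num_kv_shared_layers", 0) > 0

class FlatCfg:
    num_kv_shared_layers = 10

assert should_preserve_use_cache(FlatCfg()) is True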