From c49a745e808262d035896a43251f06726a1cc25b Mon Sep 17 00:00:00 2001 From: Zhaopeng Qiu Date: Wed, 22 Apr 2026 08:24:46 -0700 Subject: [PATCH 1/3] chore: bump Megatron-Bridge to latest main and Megatron-LM to dev MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - Megatron-Bridge: 95e5f38f → 53f4c398 (latest main) - Megatron-LM: d30c3ae5 → 546a448b (dev) The dev-branch Megatron-LM is required for Qwen3.5 VL's GDN + context parallelism support (NVIDIA/Megatron-LM#2642, NVIDIA/Megatron-LM#2644). Sync the proxy setup.py CACHED_DEPENDENCIES to match the updated dev pyproject.toml (nv-grouped-gemm, flash_mla, nvidia-resiliency-ext VCS pin, emerging_optimizers python-version marker, etc.) and regenerate uv.lock. Signed-off-by: Zhaopeng Qiu --- .../Megatron-Bridge-workspace/Megatron-Bridge | 2 +- 3rdparty/Megatron-LM-workspace/Megatron-LM | 2 +- 3rdparty/Megatron-LM-workspace/setup.py | 11 +++++----- uv.lock | 21 ++++++++++--------- 4 files changed, 18 insertions(+), 18 deletions(-) diff --git a/3rdparty/Megatron-Bridge-workspace/Megatron-Bridge b/3rdparty/Megatron-Bridge-workspace/Megatron-Bridge index 95e5f38f87..53f4c398f1 160000 --- a/3rdparty/Megatron-Bridge-workspace/Megatron-Bridge +++ b/3rdparty/Megatron-Bridge-workspace/Megatron-Bridge @@ -1 +1 @@ -Subproject commit 95e5f38f8727c4ab30830559c68939f35f4e52f6 +Subproject commit 53f4c398f16f36e8afd4cc6676b1f87faed74be0 diff --git a/3rdparty/Megatron-LM-workspace/Megatron-LM b/3rdparty/Megatron-LM-workspace/Megatron-LM index d30c3ae546..546a448b4b 160000 --- a/3rdparty/Megatron-LM-workspace/Megatron-LM +++ b/3rdparty/Megatron-LM-workspace/Megatron-LM @@ -1 +1 @@ -Subproject commit d30c3ae5469fe3f6a64d4fd2e63b6e7f7844ea81 +Subproject commit 546a448b4bf0987966826f55be21142a0ed7dd74 diff --git a/3rdparty/Megatron-LM-workspace/setup.py b/3rdparty/Megatron-LM-workspace/setup.py index 75b5831fb4..fc7f6b09fd 100644 --- a/3rdparty/Megatron-LM-workspace/setup.py +++ b/3rdparty/Megatron-LM-workspace/setup.py @@ -49,8 +49,8 @@ # Dev dependencies from pyproject.toml "nvidia-modelopt[torch]; sys_platform != 'darwin'", # TODO(https://github.com/NVIDIA-NeMo/RL/issues/2111): upgrade to core_cu13 when we move to CUDA 13 base container - "transformer-engine[pytorch,core_cu12]", - # VCS dependency - must match pyproject.toml [tool.uv.sources] + "transformer-engine[pytorch,core_cu12]>=2.9.0a0,<2.12.0", + # VCS dependency - must match Megatron-LM/pyproject.toml [tool.uv.sources] "nvidia-resiliency-ext @ git+https://github.com/NVIDIA/nvidia-resiliency-ext.git@15a851565a4ce846c04431ecb0cf09903ab4837e", "tqdm", "einops~=0.8", @@ -61,6 +61,7 @@ "mamba-ssm~=2.2", "causal-conv1d~=1.5", "flash-linear-attention~=0.4.0", + "nv-grouped-gemm~=1.1", "megatron-energon[av_decode]~=6.0", "av", "flashinfer-python~=0.5.0", @@ -68,12 +69,10 @@ "onnxscript", "fastapi~=0.50", "datasets", - # VCS dependency - must match pyproject.toml [tool.uv.sources] - "emerging_optimizers @ git+https://github.com/NVIDIA-NeMo/Emerging-Optimizers.git@v0.2.0", - "hypercorn", + "emerging_optimizers; python_version >= '3.12'", "quart", + "hypercorn", "openai[aiohttp]", - "orjson", ] diff --git a/uv.lock b/uv.lock index 85da57b397..d2e4a455b1 100644 --- a/uv.lock +++ b/uv.lock @@ -234,11 +234,11 @@ wheels = [ [[package]] name = "aiofiles" -version = "24.1.0" +version = "25.1.0" source = { registry = "https://pypi.org/simple" } -sdist = { url = 
"https://files.pythonhosted.org/packages/0b/03/a88171e277e8caa88a4c77808c20ebb04ba74cc4681bf1e9416c862de237/aiofiles-24.1.0.tar.gz", hash = "sha256:22a075c9e5a3810f0c2e48f3008c94d68c65d763b9b03857924c99e57355166c", size = 30247, upload-time = "2024-06-24T11:02:03.584Z" } +sdist = { url = "https://files.pythonhosted.org/packages/41/c3/534eac40372d8ee36ef40df62ec129bee4fdb5ad9706e58a29be53b2c970/aiofiles-25.1.0.tar.gz", hash = "sha256:a8d728f0a29de45dc521f18f07297428d56992a742f0cd2701ba86e44d23d5b2", size = 46354, upload-time = "2025-10-09T20:51:04.358Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/a5/45/30bb92d442636f570cb5651bc661f52b610e2eec3f891a5dc3a4c3667db0/aiofiles-24.1.0-py3-none-any.whl", hash = "sha256:b4ec55f4195e3eb5d7abd1bf7e061763e864dd4954231fb8539a0ef8bb8260e5", size = 15896, upload-time = "2024-06-24T11:02:01.529Z" }, + { url = "https://files.pythonhosted.org/packages/bc/8a/340a1555ae33d7354dbca4faa54948d76d89a27ceef032c8c3bc661d003e/aiofiles-25.1.0-py3-none-any.whl", hash = "sha256:abe311e527c862958650f9438e859c1fa7568a141b22abcd015e120e86a85695", size = 14668, upload-time = "2025-10-09T20:51:03.174Z" }, ] [[package]] @@ -3348,13 +3348,13 @@ dependencies = [ { name = "multi-storage-client" }, { name = "numpy", version = "2.2.6", source = { registry = "https://pypi.org/simple" }, marker = "(extra == 'extra-7-nemo-rl-automodel' and extra == 'extra-7-nemo-rl-fsdp') or (extra == 'extra-7-nemo-rl-automodel' and extra == 'extra-7-nemo-rl-mcore') or (extra == 'extra-7-nemo-rl-automodel' and extra == 'extra-7-nemo-rl-sglang') or (extra == 'extra-7-nemo-rl-fsdp' and extra != 'extra-7-nemo-rl-mcore') or (extra == 'extra-7-nemo-rl-mcore' and extra == 'extra-7-nemo-rl-sglang') or (extra == 'extra-7-nemo-rl-mcore' and extra == 'extra-7-nemo-rl-vllm') or (extra == 'extra-7-nemo-rl-sglang' and extra == 'extra-7-nemo-rl-vllm') or (extra != 'extra-7-nemo-rl-automodel' and extra != 'extra-7-nemo-rl-mcore' and extra != 'extra-7-nemo-rl-sglang') or (extra != 'extra-7-nemo-rl-fsdp' and extra != 'extra-7-nemo-rl-mcore' and extra != 'extra-7-nemo-rl-sglang' and extra == 'extra-7-nemo-rl-vllm')" }, { name = "numpy", version = "2.4.2", source = { registry = "https://pypi.org/simple" }, marker = "extra == 'extra-7-nemo-rl-automodel' or extra == 'extra-7-nemo-rl-mcore' or extra == 'extra-7-nemo-rl-sglang'" }, + { name = "nv-grouped-gemm" }, { name = "nvidia-modelopt" }, { name = "nvidia-resiliency-ext" }, { name = "nvtx" }, { name = "onnxscript" }, { name = "openai", extra = ["aiohttp"] }, { name = "opentelemetry-api" }, - { name = "orjson" }, { name = "packaging" }, { name = "quart" }, { name = "tensorstore" }, @@ -3371,7 +3371,7 @@ requires-dist = [ { name = "causal-conv1d", git = "https://github.com/Dao-AILab/causal-conv1d?rev=67e0a9dfe1518fc0036444e9ab5fe06ab78299e0" }, { name = "datasets" }, { name = "einops", specifier = "~=0.8" }, - { name = "emerging-optimizers", git = "https://github.com/NVIDIA-NeMo/Emerging-Optimizers.git?rev=v0.2.0" }, + { name = "emerging-optimizers", marker = "python_full_version >= '3.12'", git = "https://github.com/NVIDIA-NeMo/Emerging-Optimizers.git?rev=v0.2.0" }, { name = "fastapi", specifier = "~=0.50" }, { name = "flash-linear-attention", specifier = "~=0.4.0" }, { name = "flashinfer-python", specifier = "~=0.5.0" }, @@ -3380,20 +3380,20 @@ requires-dist = [ { name = "megatron-energon", extras = ["av-decode"], specifier = "~=6.0" }, { name = "multi-storage-client", specifier = "~=0.27" }, { name = "numpy" }, + { name = "nv-grouped-gemm", git = 
"https://github.com/fanshiqing/grouped_gemm?tag=v1.1.4.post7" }, { name = "nvidia-modelopt", extras = ["torch"], marker = "sys_platform != 'darwin'" }, { name = "nvidia-resiliency-ext", git = "https://github.com/NVIDIA/nvidia-resiliency-ext.git?rev=15a851565a4ce846c04431ecb0cf09903ab4837e" }, { name = "nvtx", specifier = "~=0.2" }, { name = "onnxscript" }, { name = "openai", extras = ["aiohttp"] }, { name = "opentelemetry-api", specifier = "~=1.33.1" }, - { name = "orjson" }, { name = "packaging", specifier = ">=24.2" }, { name = "quart" }, { name = "tensorstore", specifier = "~=0.1,!=0.1.46,!=0.1.72" }, { name = "torch", marker = "sys_platform != 'darwin'", specifier = ">=2.6.0", index = "https://download.pytorch.org/whl/cu129" }, { name = "torch", marker = "sys_platform == 'darwin'", specifier = ">=2.6.0", index = "https://pypi.org/simple" }, { name = "tqdm" }, - { name = "transformer-engine", extras = ["core-cu12", "pytorch"] }, + { name = "transformer-engine", extras = ["core-cu12", "pytorch"], specifier = ">=2.9.0a0,<2.12.0" }, { name = "wget" }, ] @@ -4568,10 +4568,11 @@ name = "nv-grouped-gemm" version = "1.1.4.post7" source = { git = "https://github.com/fanshiqing/grouped_gemm?tag=v1.1.4.post7#6dfaf60e6112166b8b82e9210b51c7f557956f0a" } dependencies = [ - { name = "numpy", version = "2.4.2", source = { registry = "https://pypi.org/simple" } }, + { name = "numpy", version = "2.2.6", source = { registry = "https://pypi.org/simple" }, marker = "(extra == 'extra-7-nemo-rl-automodel' and extra == 'extra-7-nemo-rl-fsdp') or (extra == 'extra-7-nemo-rl-automodel' and extra == 'extra-7-nemo-rl-mcore') or (extra == 'extra-7-nemo-rl-automodel' and extra == 'extra-7-nemo-rl-sglang') or (extra == 'extra-7-nemo-rl-fsdp' and extra != 'extra-7-nemo-rl-mcore') or (extra == 'extra-7-nemo-rl-mcore' and extra == 'extra-7-nemo-rl-sglang') or (extra == 'extra-7-nemo-rl-mcore' and extra == 'extra-7-nemo-rl-vllm') or (extra == 'extra-7-nemo-rl-sglang' and extra == 'extra-7-nemo-rl-vllm') or (extra != 'extra-7-nemo-rl-automodel' and extra != 'extra-7-nemo-rl-mcore' and extra != 'extra-7-nemo-rl-sglang') or (extra != 'extra-7-nemo-rl-fsdp' and extra != 'extra-7-nemo-rl-mcore' and extra != 'extra-7-nemo-rl-sglang' and extra == 'extra-7-nemo-rl-vllm')" }, + { name = "numpy", version = "2.4.2", source = { registry = "https://pypi.org/simple" }, marker = "extra == 'extra-7-nemo-rl-automodel' or extra == 'extra-7-nemo-rl-mcore' or extra == 'extra-7-nemo-rl-sglang'" }, { name = "setuptools" }, - { name = "torch", version = "2.10.0", source = { registry = "https://pypi.org/simple" }, marker = "(sys_platform == 'darwin' and extra == 'extra-7-nemo-rl-automodel') or (extra == 'extra-7-nemo-rl-automodel' and extra == 'extra-7-nemo-rl-fsdp') or (extra == 'extra-7-nemo-rl-automodel' and extra == 'extra-7-nemo-rl-mcore') or (extra == 'extra-7-nemo-rl-automodel' and extra == 'extra-7-nemo-rl-sglang') or (extra == 'extra-7-nemo-rl-automodel' and extra == 'extra-7-nemo-rl-vllm') or (extra == 'extra-7-nemo-rl-fsdp' and extra == 'extra-7-nemo-rl-sglang') or (extra == 'extra-7-nemo-rl-mcore' and extra == 'extra-7-nemo-rl-sglang') or (extra == 'extra-7-nemo-rl-mcore' and extra == 'extra-7-nemo-rl-vllm') or (extra == 'extra-7-nemo-rl-sglang' and extra == 'extra-7-nemo-rl-vllm')" }, - { name = "torch", version = "2.10.0+cu129", source = { registry = "https://download.pytorch.org/whl/cu129" }, marker = "(sys_platform != 'darwin' and extra == 'extra-7-nemo-rl-automodel') or (extra == 'extra-7-nemo-rl-automodel' and extra == 
'extra-7-nemo-rl-fsdp') or (extra == 'extra-7-nemo-rl-automodel' and extra == 'extra-7-nemo-rl-mcore') or (extra == 'extra-7-nemo-rl-automodel' and extra == 'extra-7-nemo-rl-sglang') or (extra == 'extra-7-nemo-rl-automodel' and extra == 'extra-7-nemo-rl-vllm') or (extra == 'extra-7-nemo-rl-fsdp' and extra == 'extra-7-nemo-rl-sglang') or (extra == 'extra-7-nemo-rl-mcore' and extra == 'extra-7-nemo-rl-sglang') or (extra == 'extra-7-nemo-rl-mcore' and extra == 'extra-7-nemo-rl-vllm') or (extra == 'extra-7-nemo-rl-sglang' and extra == 'extra-7-nemo-rl-vllm')" }, + { name = "torch", version = "2.10.0", source = { registry = "https://pypi.org/simple" }, marker = "sys_platform == 'darwin' or (extra == 'extra-7-nemo-rl-automodel' and extra == 'extra-7-nemo-rl-fsdp') or (extra == 'extra-7-nemo-rl-automodel' and extra == 'extra-7-nemo-rl-mcore') or (extra == 'extra-7-nemo-rl-automodel' and extra == 'extra-7-nemo-rl-sglang') or (extra == 'extra-7-nemo-rl-automodel' and extra == 'extra-7-nemo-rl-vllm') or (extra == 'extra-7-nemo-rl-fsdp' and extra == 'extra-7-nemo-rl-sglang') or (extra == 'extra-7-nemo-rl-mcore' and extra == 'extra-7-nemo-rl-sglang') or (extra == 'extra-7-nemo-rl-mcore' and extra == 'extra-7-nemo-rl-vllm') or (extra == 'extra-7-nemo-rl-sglang' and extra == 'extra-7-nemo-rl-vllm')" }, + { name = "torch", version = "2.10.0+cu129", source = { registry = "https://download.pytorch.org/whl/cu129" }, marker = "sys_platform != 'darwin' or (extra == 'extra-7-nemo-rl-automodel' and extra == 'extra-7-nemo-rl-fsdp') or (extra == 'extra-7-nemo-rl-automodel' and extra == 'extra-7-nemo-rl-mcore') or (extra == 'extra-7-nemo-rl-automodel' and extra == 'extra-7-nemo-rl-sglang') or (extra == 'extra-7-nemo-rl-automodel' and extra == 'extra-7-nemo-rl-vllm') or (extra == 'extra-7-nemo-rl-fsdp' and extra == 'extra-7-nemo-rl-sglang') or (extra == 'extra-7-nemo-rl-mcore' and extra == 'extra-7-nemo-rl-sglang') or (extra == 'extra-7-nemo-rl-mcore' and extra == 'extra-7-nemo-rl-vllm') or (extra == 'extra-7-nemo-rl-sglang' and extra == 'extra-7-nemo-rl-vllm')" }, { name = "wheel" }, ] From 2c2c6e4a897c73f4f7b7541c04bee374fa2b4717 Mon Sep 17 00:00:00 2001 From: Zhaopeng Qiu Date: Wed, 22 Apr 2026 08:28:04 -0700 Subject: [PATCH 2/3] feat(megatron): delegate packing to the model for VLM + CP (Qwen3.5 VL) Qwen3.5 VL's mbridge wrapper (Qwen3VLModel) always runs its own preprocess_packed_seqs internally to pack + CP-shard from a [B, max_seq] input + attention_mask. NeMo-RL's existing path pre-packs and CP-shards before the forward, which collides with mbridge's preprocessing and produces shape mismatches at GDN / RoPE / MoE when CP > 1. Add a sequence_packing.delegate_pack_to_model flag. When true, _prepare_vlm_batch_for_megatron keeps the batch in [B, max_seq] layout, builds a bool attention_mask from the (aligned) padded lengths, and hands the model a PackedSeqParams whose cu_seqlens_padded matches what mbridge will derive internally. The model owns packing and CP-sharding from there. For the target-side path (logprob / loss post-processing), we also produce a packed [1, T] view of input_ids; downstream code already slices per-sequence via cu_seqlens_padded. PP > 1 is supported by absorbing the pad_full_seq_to deficit into the last sequence (same technique as _pack_sequences_for_megatron), so the decoder-side packed length is constant across microbatches. Additional fixes needed for this path: - community_import.py: set calculate_per_token_loss when CP > 1, which Qwen3VLModel asserts. 
- setup.py: clarify the 'CP > 1 requires sequence_packing' error message to mention delegate_pack_to_model for VLM models. Signed-off-by: Zhaopeng Qiu --- nemo_rl/models/megatron/community_import.py | 2 + nemo_rl/models/megatron/data.py | 201 ++++++++++++++++++-- nemo_rl/models/megatron/setup.py | 8 +- 3 files changed, 190 insertions(+), 21 deletions(-) diff --git a/nemo_rl/models/megatron/community_import.py b/nemo_rl/models/megatron/community_import.py index c91adcd606..5a8aae0b32 100644 --- a/nemo_rl/models/megatron/community_import.py +++ b/nemo_rl/models/megatron/community_import.py @@ -78,6 +78,8 @@ def import_model_from_hf_name( model_provider.gradient_accumulation_fusion = megatron_config[ "gradient_accumulation_fusion" ] + if megatron_config["context_parallel_size"] > 1: + model_provider.calculate_per_token_loss = True model_provider.finalize() from megatron.core import parallel_state diff --git a/nemo_rl/models/megatron/data.py b/nemo_rl/models/megatron/data.py index fd68ce2fbc..aadf74098d 100644 --- a/nemo_rl/models/megatron/data.py +++ b/nemo_rl/models/megatron/data.py @@ -96,6 +96,9 @@ def make_processed_microbatch_iterator( ProcessedMicrobatch objects containing processed tensors ready for model forward """ pack_sequences = cfg["sequence_packing"]["enabled"] + delegate_pack_to_model = cfg["sequence_packing"].get( + "delegate_pack_to_model", False + ) for data_dict in raw_iterator: # Move to GPU @@ -109,6 +112,7 @@ def make_processed_microbatch_iterator( pad_packed_seq_to_multiple_of=pad_packed_seq_to_multiple_of, pad_full_seq_to=pad_full_seq_to, pack_sequences=pack_sequences, + delegate_pack_to_model=delegate_pack_to_model, straggler_timer=straggler_timer, ) @@ -213,6 +217,7 @@ def process_microbatch( pad_packed_seq_to_multiple_of: int = 1, pad_full_seq_to: Optional[int] = None, pack_sequences: bool = False, + delegate_pack_to_model: bool = False, straggler_timer: Optional[StragglerDetector] = None, ) -> tuple[ torch.Tensor, @@ -248,27 +253,48 @@ def process_microbatch( # Get sequence lengths and context parallel size seq_lengths = data_dict[seq_length_key] - # Pack sequences - ( - input_ids, - input_ids_cp_sharded, - packed_seq_params, - cu_seqlens, - cu_seqlens_padded, - ) = _pack_sequences_for_megatron( - input_ids, - seq_lengths, - pad_individual_seqs_to_multiple_of, - pad_packed_seq_to_multiple_of, - pad_full_seq_to, - cp_rank=get_context_parallel_rank(), - cp_size=get_context_parallel_world_size(), - ) + if delegate_pack_to_model: + # VLM path: model (e.g. mbridge Qwen3VL) does its own + # preprocess_packed_seqs; NeMo-RL must NOT pre-pack + CP-shard, + # or the double-processing produces shape mismatches downstream + # (GDN/RoPE/MoE). We only pad each sequence individually and + # hand the model [B, max_seq] + bool attention_mask + cu_seqlens. 
+            (
+                input_ids,
+                input_ids_cp_sharded,
+                attention_mask,
+                packed_seq_params,
+                cu_seqlens,
+                cu_seqlens_padded,
+            ) = _prepare_vlm_batch_for_megatron(
+                input_ids,
+                seq_lengths,
+                pad_individual_seqs_to_multiple_of,
+                pad_full_seq_to=pad_full_seq_to,
+            )
+            position_ids = None
+        else:
+            # Pack sequences
+            (
+                input_ids,
+                input_ids_cp_sharded,
+                packed_seq_params,
+                cu_seqlens,
+                cu_seqlens_padded,
+            ) = _pack_sequences_for_megatron(
+                input_ids,
+                seq_lengths,
+                pad_individual_seqs_to_multiple_of,
+                pad_packed_seq_to_multiple_of,
+                pad_full_seq_to,
+                cp_rank=get_context_parallel_rank(),
+                cp_size=get_context_parallel_world_size(),
+            )
 
-        # For packed sequences, position_ids and attention_mask are typically None
-        # The PackedSeqParams handles all necessary sequence information
-        position_ids = None
-        attention_mask = None
+            # For packed sequences, position_ids and attention_mask are typically None
+            # The PackedSeqParams handles all necessary sequence information
+            position_ids = None
+            attention_mask = None
     else:
         input_ids_cp_sharded = input_ids
         attention_mask, _, position_ids = get_ltor_masks_and_position_ids(
@@ -343,6 +369,141 @@ def process_global_batch(
     }
 
 
+def _prepare_vlm_batch_for_megatron(
+    input_ids: torch.Tensor,
+    seq_lengths: torch.Tensor,
+    pad_individual_seqs_to_multiple_of: int,
+    pad_full_seq_to: Optional[int] = None,
+) -> tuple[
+    torch.Tensor,
+    torch.Tensor,
+    torch.Tensor,
+    PackedSeqParams,
+    None,
+    torch.Tensor,
+]:
+    """Prepare a [B, max_seq] batch for a model that does its own packing + CP sharding.
+
+    Used with mbridge VLM wrappers (e.g. Qwen3VL). The model's forward calls
+    preprocess_packed_seqs internally, which re-packs + CP-shards from
+    attention_mask. So NeMo-RL must NOT pre-pack / CP-shard; it only:
+      * pads each sequence (along dim 1) to pad_individual_seqs_to_multiple_of,
+      * builds a bool attention_mask describing real token validity,
+      * builds cu_seqlens_padded describing the full (pre-shard) packed layout,
+      * hands everything to the model as [B, max_seq].
+
+    When ``pad_full_seq_to`` is set (PP>1 requires a constant total packed
+    length across microbatches), the last sequence's effective length is
+    extended so ``sum(padded_lens) == pad_full_seq_to``. These extra positions
+    are treated as "valid" by the model (so mbridge's internal packing stays
+    consistent) but should be masked out at the loss layer via token_mask.
+
+    Returns:
+        - input_ids: packed [1, T] view for downstream logprob/loss target slicing
+        - input_ids_cp_sharded: [B, padded_max_seq] for the model forward
+        - attention_mask: [B, padded_max_seq] bool (True for valid tokens)
+        - packed_seq_params: PackedSeqParams(qkv_format="thd", cu_seqlens_*=padded)
+        - cu_seqlens: None (unpadded cu_seqlens unused in this path)
+        - cu_seqlens_padded: [B+1] int32 matching packed_seq_params
+    """
+    batch_size, _ = input_ids.shape
+    device = input_ids.device
+    align = max(1, pad_individual_seqs_to_multiple_of)
+
+    # One CPU-GPU sync per call via .tolist(); per-seq arithmetic runs on CPU
+    # ints (fast) instead of .item() in a loop (which synced once per sequence).
+    if torch.is_tensor(seq_lengths):
+        lengths_list = seq_lengths.tolist()
+    else:
+        lengths_list = list(seq_lengths)
+    padded_lens = [((L + align - 1) // align) * align for L in lengths_list]
+
+    # PP>1: force sum(padded_lens) to a fixed value so every microbatch produces
+    # the same decoder-side packed length. We mirror _pack_sequences_for_megatron
+    # by absorbing the deficit into the LAST sequence's effective length. The
+    # extra positions look valid to the model but are zeroed out at the loss
+    # layer via token_mask (consistent with the non-VLM path).
+    if pad_full_seq_to is not None and batch_size > 0:
+        natural_sum = sum(padded_lens)
+        deficit = pad_full_seq_to - natural_sum
+        assert deficit >= 0, (
+            f"pad_full_seq_to ({pad_full_seq_to}) < natural padded sum "
+            f"({natural_sum}); increase pad_full_seq_to."
+        )
+        assert deficit % align == 0, (
+            f"pad_full_seq_to deficit ({deficit}) must be a multiple of "
+            f"pad_individual_seqs_to_multiple_of ({align})."
+        )
+        if deficit > 0:
+            lengths_list[-1] += deficit
+            padded_lens[-1] += deficit
+
+    padded_max = max(padded_lens) if padded_lens else 0
+
+    # Row-pad input_ids to padded_max so all sequences live in one rectangular tensor.
+    if input_ids.shape[1] < padded_max:
+        pad_amt = padded_max - input_ids.shape[1]
+        input_ids_2d = torch.nn.functional.pad(input_ids, (0, pad_amt), value=0)
+    elif input_ids.shape[1] > padded_max:
+        input_ids_2d = input_ids[:, :padded_max].contiguous()
+    else:
+        input_ids_2d = input_ids
+
+    # Vectorised attention_mask: positions < padded length, broadcast over batch.
+    # We use padded_lens (not raw lengths) so mbridge's preprocess_packed_seqs,
+    # which recomputes seqlens from attention_mask.sum, sees the same packed
+    # total as our cu_seqlens_padded. Otherwise a mismatch between the raw and
+    # align-padded lengths trips GDN's cu_seqlens vs total_seq_len check.
+    # Tokens in the padded tail are masked out at the loss layer.
+    padded_lens_tensor = torch.tensor(padded_lens, dtype=torch.long, device=device)
+    positions = torch.arange(padded_max, device=device)
+    attention_mask = positions.unsqueeze(0) < padded_lens_tensor.unsqueeze(1)
+
+    # Build cu_seqlens on CPU then H2D once.
+    cu_vals = [0]
+    for p in padded_lens:
+        cu_vals.append(cu_vals[-1] + p)
+    cu_seqlens_padded = torch.tensor(cu_vals, dtype=torch.int32, device=device)
+
+    packed_seq_params = PackedSeqParams(
+        qkv_format="thd",
+        cu_seqlens_q=cu_seqlens_padded,
+        cu_seqlens_kv=cu_seqlens_padded,
+        cu_seqlens_q_padded=cu_seqlens_padded,
+        cu_seqlens_kv_padded=cu_seqlens_padded,
+        max_seqlen_q=padded_max,
+        max_seqlen_kv=padded_max,
+    )
+
+    # Packed (unsharded) view for downstream logprob / loss code that slices
+    # per-sequence targets via cu_seqlens_padded. If all sequences are already
+    # padded to the same length (common case: input_ids_2d columns == padded_max
+    # and every padded_len == padded_max), we can reshape instead of a Python
+    # loop + cat, which avoids B separate GPU slice ops.
+    if padded_lens and all(p == padded_max for p in padded_lens):
+        packed_input_ids = input_ids_2d.reshape(1, -1)
+    else:
+        packed_segments = [input_ids_2d[i, :p] for i, p in enumerate(padded_lens)]
+        packed_input_ids = (
+            torch.cat(packed_segments, dim=0).unsqueeze(0)
+            if packed_segments
+            else input_ids_2d.new_zeros((1, 0))
+        )
+
+    # input_ids_cp_sharded keeps the [B, max_seq] layout: the model (mbridge
+    # Qwen3VL) runs its own preprocess_packed_seqs to pack + CP-shard.
+    # input_ids is the packed (but not CP-sharded) view for target/logprob
+    # post-processing, which uses cu_seqlens_padded to slice per sequence.
+ return ( + packed_input_ids, + input_ids_2d, + attention_mask, + packed_seq_params, + None, + cu_seqlens_padded, + ) + + def _pack_sequences_for_megatron( input_ids: torch.Tensor, seq_lengths: torch.Tensor, diff --git a/nemo_rl/models/megatron/setup.py b/nemo_rl/models/megatron/setup.py index fc5c6c44fa..6ec3ebdf30 100644 --- a/nemo_rl/models/megatron/setup.py +++ b/nemo_rl/models/megatron/setup.py @@ -386,8 +386,14 @@ def _apply_parallelism_config(model_cfg: Any, config: PolicyConfig) -> None: model_cfg.context_parallel_size = config["megatron_cfg"]["context_parallel_size"] if model_cfg.context_parallel_size > 1: + # Either NeMo-RL does the packing+CP-sharding itself (classic mcore + # GPTModel path) OR the model does it internally (mbridge VLM wrappers + # with delegate_pack_to_model=True). Both paths require cu_seqlens to + # flow via PackedSeqParams, so sequence_packing.enabled must be on. assert config["sequence_packing"]["enabled"], ( - "Sequence Packing must be enabled to use Context Parallelism with MCore" + "Sequence Packing must be enabled to use Context Parallelism with MCore " + "(set delegate_pack_to_model: true under sequence_packing for VLM models " + "where the model does its own packing)." ) assert not config["megatron_cfg"].get("use_linear_ce_fusion_loss", False), ( "Context Parallelism is not supported with linear CE fusion loss, please set use_linear_ce_fusion_loss to false" From 16ae3635d3b8d2521a5e4c2db1c7965834e02a60 Mon Sep 17 00:00:00 2001 From: Zhaopeng Qiu Date: Wed, 22 Apr 2026 08:28:18 -0700 Subject: [PATCH 3/3] perf: use tuple indexing in CP zigzag slicing to silence PyTorch warning _get_tokens_on_this_cp_rank slices a tensor with a list of slices, which PyTorch 2.9 deprecates ('Using a non-tuple sequence for multidimensional indexing is deprecated'). Every GDN/attention layer triggers this on the packed-CP path, flooding the worker logs. Casting to tuple matches the recommended API and is functionally identical. Signed-off-by: Zhaopeng Qiu --- nemo_rl/distributed/model_utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/nemo_rl/distributed/model_utils.py b/nemo_rl/distributed/model_utils.py index dd46ff7a27..3b777d1ab0 100644 --- a/nemo_rl/distributed/model_utils.py +++ b/nemo_rl/distributed/model_utils.py @@ -1216,7 +1216,7 @@ def _get_tokens_on_this_cp_rank( for ind in shard_inds: slices[seq_dim] = slice(ind * shard_size, (ind + 1) * shard_size) - ids_chunks.append(input_ids[slices]) + ids_chunks.append(input_ids[tuple(slices)]) ids = torch.cat(ids_chunks, dim=seq_dim) return ids
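
Note for reviewers (illustrative, not part of the patches): the delegate-pack path
in PATCH 2 hinges on one invariant -- the bool attention_mask is built from the
*padded* per-sequence lengths, so a model that re-derives seqlens via
attention_mask.sum() (as mbridge's preprocess_packed_seqs does) recovers exactly
the cu_seqlens_padded that NeMo-RL later uses to slice per-sequence targets out
of the packed [1, T] view. The sketch below demonstrates that invariant with
made-up lengths; it is standalone Python, not NeMo-RL code.

    import torch

    # Made-up per-sequence lengths, already aligned to the packing multiple.
    padded_lens = [8, 4, 12]
    B, max_seq = len(padded_lens), max(padded_lens)

    # [B, max_seq] bool mask: True for positions < padded length,
    # mirroring how _prepare_vlm_batch_for_megatron builds it.
    positions = torch.arange(max_seq)
    mask = positions.unsqueeze(0) < torch.tensor(padded_lens).unsqueeze(1)

    # cu_seqlens_padded: exclusive prefix sum over padded lengths -> [0, 8, 12, 24].
    cu = torch.zeros(B + 1, dtype=torch.int32)
    cu[1:] = torch.cumsum(torch.tensor(padded_lens, dtype=torch.int32), dim=0)

    # A model that re-packs from the mask recovers identical sequence boundaries.
    assert torch.equal(mask.sum(dim=1).to(torch.int32), cu[1:] - cu[:-1])

    # Downstream logprob/loss code slices sequence i out of the packed [1, T]
    # view with the same cu_seqlens; here T == int(cu[-1]) == 24.
    packed = torch.arange(int(cu[-1])).unsqueeze(0)  # stand-in for packed input_ids
    seq_1 = packed[0, int(cu[1]) : int(cu[2])]
    assert seq_1.numel() == padded_lens[1]

Had the mask been built from the raw (unaligned) lengths instead, mask.sum() and
cu_seqlens_padded would disagree, which is exactly the GDN cu_seqlens vs
total_seq_len check that the comment in _prepare_vlm_batch_for_megatron warns about.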