2 changes: 1 addition & 1 deletion 3rdparty/Megatron-Bridge-workspace/Megatron-Bridge
Submodule Megatron-Bridge updated 169 files
2 changes: 1 addition & 1 deletion 3rdparty/Megatron-LM-workspace/Megatron-LM
Submodule Megatron-LM updated 330 files
11 changes: 5 additions & 6 deletions 3rdparty/Megatron-LM-workspace/setup.py
@@ -49,8 +49,8 @@
# Dev dependencies from pyproject.toml
"nvidia-modelopt[torch]; sys_platform != 'darwin'",
# TODO(https://github.com/NVIDIA-NeMo/RL/issues/2111): upgrade to core_cu13 when we move to CUDA 13 base container
"transformer-engine[pytorch,core_cu12]",
# VCS dependency - must match pyproject.toml [tool.uv.sources]
"transformer-engine[pytorch,core_cu12]>=2.9.0a0,<2.12.0",
# VCS dependency - must match Megatron-LM/pyproject.toml [tool.uv.sources]
"nvidia-resiliency-ext @ git+https://github.com/NVIDIA/nvidia-resiliency-ext.git@15a851565a4ce846c04431ecb0cf09903ab4837e",
"tqdm",
"einops~=0.8",
@@ -61,19 +61,18 @@
"mamba-ssm~=2.2",
"causal-conv1d~=1.5",
"flash-linear-attention~=0.4.0",
"nv-grouped-gemm~=1.1",
"megatron-energon[av_decode]~=6.0",
"av",
"flashinfer-python~=0.5.0",
"wget",
"onnxscript",
"fastapi~=0.50",
"datasets",
# VCS dependency - must match pyproject.toml [tool.uv.sources]
"emerging_optimizers @ git+https://github.com/NVIDIA-NeMo/Emerging-Optimizers.git@v0.2.0",
"hypercorn",
"emerging_optimizers; python_version >= '3.12'",
"quart",
"hypercorn",
"openai[aiohttp]",
"orjson",
]


2 changes: 1 addition & 1 deletion nemo_rl/distributed/model_utils.py
@@ -1216,7 +1216,7 @@ def _get_tokens_on_this_cp_rank(

for ind in shard_inds:
slices[seq_dim] = slice(ind * shard_size, (ind + 1) * shard_size)
ids_chunks.append(input_ids[slices])
ids_chunks.append(input_ids[tuple(slices)])

ids = torch.cat(ids_chunks, dim=seq_dim)
return ids
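Note on the tuple(slices) fix above: indexing a tensor with a plain Python list of slice objects is treated as advanced indexing, which is deprecated and can error in newer PyTorch releases, while a tuple performs ordinary multi-dimensional slicing. A minimal sketch of the intended behaviour, with illustrative shapes not taken from the PR:

import torch

x = torch.arange(24).reshape(2, 3, 4)
slices = [slice(None)] * x.dim()
slices[1] = slice(0, 2)              # take a shard along the sequence dim
shard = x[tuple(slices)]             # same as x[:, 0:2, :]
assert shard.shape == (2, 2, 4)      # list indexing here would warn or fail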
2 changes: 2 additions & 0 deletions nemo_rl/models/megatron/community_import.py
@@ -78,6 +78,8 @@ def import_model_from_hf_name(
model_provider.gradient_accumulation_fusion = megatron_config[
"gradient_accumulation_fusion"
]
if megatron_config["context_parallel_size"] > 1:
model_provider.calculate_per_token_loss = True
model_provider.finalize()

from megatron.core import parallel_state
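Context on the calculate_per_token_loss change above: with context parallelism each rank holds only a shard of the tokens, so the loss has to be accumulated per token and normalized by the global token count across the CP group rather than averaged per rank. A rough illustration of that reduction; the names and the exact reduction used by Megatron/NeMo-RL are assumptions, not taken from this PR:

import torch
import torch.distributed as dist

def cp_token_mean_loss(per_token_loss, token_mask, cp_group):
    # per_token_loss / token_mask cover only this rank's token shard
    local_sum = (per_token_loss * token_mask).sum()
    local_count = token_mask.sum().to(local_sum.dtype)
    # sum losses and token counts over the context-parallel group, then divide
    dist.all_reduce(local_sum, group=cp_group)
    dist.all_reduce(local_count, group=cp_group)
    return local_sum / local_count.clamp(min=1)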
201 changes: 181 additions & 20 deletions nemo_rl/models/megatron/data.py
@@ -96,6 +96,9 @@ def make_processed_microbatch_iterator(
ProcessedMicrobatch objects containing processed tensors ready for model forward
"""
pack_sequences = cfg["sequence_packing"]["enabled"]
delegate_pack_to_model = cfg["sequence_packing"].get(
"delegate_pack_to_model", False
)

for data_dict in raw_iterator:
# Move to GPU
@@ -109,6 +112,7 @@
pad_packed_seq_to_multiple_of=pad_packed_seq_to_multiple_of,
pad_full_seq_to=pad_full_seq_to,
pack_sequences=pack_sequences,
delegate_pack_to_model=delegate_pack_to_model,
straggler_timer=straggler_timer,
)
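For reference, the new key is read with .get(..., False), so existing configs keep the old behaviour. A hypothetical config fragment; the keys mirror this hunk, the values are illustrative:

cfg = {
    "sequence_packing": {
        "enabled": True,
        # Optional: let the model (e.g. an mbridge VLM wrapper) do its own
        # packing and CP sharding instead of NeMo-RL pre-packing the batch.
        "delegate_pack_to_model": True,
    },
}
pack_sequences = cfg["sequence_packing"]["enabled"]
delegate_pack_to_model = cfg["sequence_packing"].get("delegate_pack_to_model", False)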

Expand Down Expand Up @@ -213,6 +217,7 @@ def process_microbatch(
pad_packed_seq_to_multiple_of: int = 1,
pad_full_seq_to: Optional[int] = None,
pack_sequences: bool = False,
delegate_pack_to_model: bool = False,
straggler_timer: Optional[StragglerDetector] = None,
) -> tuple[
torch.Tensor,
@@ -248,27 +253,48 @@
# Get sequence lengths and context parallel size
seq_lengths = data_dict[seq_length_key]

# Pack sequences
(
input_ids,
input_ids_cp_sharded,
packed_seq_params,
cu_seqlens,
cu_seqlens_padded,
) = _pack_sequences_for_megatron(
input_ids,
seq_lengths,
pad_individual_seqs_to_multiple_of,
pad_packed_seq_to_multiple_of,
pad_full_seq_to,
cp_rank=get_context_parallel_rank(),
cp_size=get_context_parallel_world_size(),
)
if delegate_pack_to_model:
# VLM path: model (e.g. mbridge Qwen3VL) does its own
# preprocess_packed_seqs; NeMo-RL must NOT pre-pack + CP-shard,
# or the double-processing produces shape mismatches downstream
# (GDN/RoPE/MoE). We only pad each sequence individually and
# hand the model [B, max_seq] + bool attention_mask + cu_seqlens.
(
input_ids,
input_ids_cp_sharded,
attention_mask,
packed_seq_params,
cu_seqlens,
cu_seqlens_padded,
) = _prepare_vlm_batch_for_megatron(
input_ids,
seq_lengths,
pad_individual_seqs_to_multiple_of,
pad_full_seq_to=pad_full_seq_to,
)
position_ids = None
else:
# Pack sequences
(
input_ids,
input_ids_cp_sharded,
packed_seq_params,
cu_seqlens,
cu_seqlens_padded,
) = _pack_sequences_for_megatron(
input_ids,
seq_lengths,
pad_individual_seqs_to_multiple_of,
pad_packed_seq_to_multiple_of,
pad_full_seq_to,
cp_rank=get_context_parallel_rank(),
cp_size=get_context_parallel_world_size(),
)

# For packed sequences, position_ids and attention_mask are typically None
# The PackedSeqParams handles all necessary sequence information
position_ids = None
attention_mask = None
# For packed sequences, position_ids and attention_mask are typically None
# The PackedSeqParams handles all necessary sequence information
position_ids = None
attention_mask = None
else:
input_ids_cp_sharded = input_ids
attention_mask, _, position_ids = get_ltor_masks_and_position_ids(
@@ -343,6 +369,141 @@ def process_global_batch(
}


def _prepare_vlm_batch_for_megatron(
input_ids: torch.Tensor,
seq_lengths: torch.Tensor,
pad_individual_seqs_to_multiple_of: int,
pad_full_seq_to: Optional[int] = None,
) -> tuple[
torch.Tensor,
torch.Tensor,
torch.Tensor,
PackedSeqParams,
torch.Tensor,
torch.Tensor,
]:
"""Prepare a [B, max_seq] batch for a model that does its own packing + CP sharding.

Used with mbridge VLM wrappers (e.g. Qwen3VL). The model's forward calls
preprocess_packed_seqs internally, which re-packs + CP-shards from
attention_mask. So NeMo-RL must NOT pre-pack / CP-shard; it only:
* pads each sequence (along dim 1) to pad_individual_seqs_to_multiple_of,
* builds a bool attention_mask describing real token validity,
* builds cu_seqlens_padded describing full (pre-shard) packed layout,
* hands everything to the model as [B, max_seq].

When ``pad_full_seq_to`` is set (PP>1 requires a constant total packed
length across microbatches), the last sequence's effective length is
extended so ``sum(padded_lens) == pad_full_seq_to``. These extra positions
are treated as "valid" by the model (so mbridge's internal packing stays
consistent) but should be masked out at the loss layer via token_mask.

Returns:
- input_ids: packed [1, T] view for downstream logprob/loss target slicing
- input_ids_cp_sharded: [B, padded_max_seq] for the model forward
- attention_mask: [B, padded_max_seq] bool (True for valid tokens)
- packed_seq_params: PackedSeqParams(qkv_format="thd", cu_seqlens_*=padded)
- cu_seqlens: None (unpadded cu_seqlens unused in this path)
- cu_seqlens_padded: [B+1] int32 matching packed_seq_params
"""
batch_size, _ = input_ids.shape
device = input_ids.device
align = max(1, pad_individual_seqs_to_multiple_of)

# One CPU-GPU sync per call via .tolist(); per-sequence arithmetic then runs on
# CPU ints instead of calling .item() in a loop, which would sync once per sequence.
if torch.is_tensor(seq_lengths):
lengths_list = seq_lengths.tolist()
else:
lengths_list = list(seq_lengths)
padded_lens = [((L + align - 1) // align) * align for L in lengths_list]

# PP>1: force sum(padded_lens) to a fixed value so every microbatch produces
# the same decoder-side packed length. We mirror _pack_sequences_for_megatron
# by absorbing the deficit into the LAST sequence's effective length. The
# extra positions look valid to the model but are zero-ed out at the loss
# layer via token_mask (consistent with the non-VLM path).
if pad_full_seq_to is not None and batch_size > 0:
natural_sum = sum(padded_lens)
deficit = pad_full_seq_to - natural_sum
assert deficit >= 0, (
f"pad_full_seq_to ({pad_full_seq_to}) < natural padded sum "
f"({natural_sum}); increase pad_full_seq_to."
)
assert deficit % align == 0, (
f"pad_full_seq_to deficit ({deficit}) must be a multiple of "
f"pad_individual_seqs_to_multiple_of ({align})."
)
if deficit > 0:
lengths_list[-1] += deficit
padded_lens[-1] += deficit

padded_max = max(padded_lens) if padded_lens else 0

# Row-pad input_ids to padded_max so all sequences live in one rectangular tensor.
if input_ids.shape[1] < padded_max:
pad_amt = padded_max - input_ids.shape[1]
input_ids_2d = torch.nn.functional.pad(input_ids, (0, pad_amt), value=0)
elif input_ids.shape[1] > padded_max:
input_ids_2d = input_ids[:, :padded_max].contiguous()
else:
input_ids_2d = input_ids

# Vectorised attention_mask: positions < padded length, broadcast over batch.
# We use padded_lens (not raw lengths) so mbridge's preprocess_packed_seqs,
# which recomputes seqlens from attention_mask.sum, sees the same packed
# total as our cu_seqlens_padded. Otherwise a mismatch between raw length
# and align-padded length leads to GDN's cu_seqlens vs total_seq_len check
# firing. Tokens in the padded tail are masked out at the loss layer.
padded_lens_tensor = torch.tensor(padded_lens, dtype=torch.long, device=device)
positions = torch.arange(padded_max, device=device)
attention_mask = positions.unsqueeze(0) < padded_lens_tensor.unsqueeze(1)

# Build cu_seqlens on CPU then H2D once.
cu_vals = [0]
for p in padded_lens:
cu_vals.append(cu_vals[-1] + p)
cu_seqlens_padded = torch.tensor(cu_vals, dtype=torch.int32, device=device)

packed_seq_params = PackedSeqParams(
qkv_format="thd",
cu_seqlens_q=cu_seqlens_padded,
cu_seqlens_kv=cu_seqlens_padded,
cu_seqlens_q_padded=cu_seqlens_padded,
cu_seqlens_kv_padded=cu_seqlens_padded,
max_seqlen_q=padded_max,
max_seqlen_kv=padded_max,
)

# Packed (unsharded) view for downstream logprob / loss code that slices
# per-sequence targets via cu_seqlens_padded. If all sequences are already
# padded to the same length (common case: input_ids_2d columns == padded_max
# and every padded_len == padded_max), we can reshape instead of a Python
# loop + cat, which avoids B separate GPU slice ops.
if padded_lens and all(p == padded_max for p in padded_lens):
packed_input_ids = input_ids_2d.reshape(1, -1)
else:
packed_segments = [input_ids_2d[i, :p] for i, p in enumerate(padded_lens)]
packed_input_ids = (
torch.cat(packed_segments, dim=0).unsqueeze(0)
if packed_segments
else input_ids_2d.new_zeros((1, 0))
)

# input_ids_cp_sharded keeps the [B, max_seq] layout: the model (mbridge
# Qwen3VL) runs its own preprocess_packed_seqs to pack + CP-shard.
# input_ids is the packed (but not CP-sharded) view for target/logprob
# post-processing, which uses cu_seqlens_padded to slice per sequence.
return (
packed_input_ids,
input_ids_2d,
attention_mask,
packed_seq_params,
None,
cu_seqlens_padded,
)


def _pack_sequences_for_megatron(
input_ids: torch.Tensor,
seq_lengths: torch.Tensor,
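To make the padding and cu_seqlens bookkeeping in _prepare_vlm_batch_for_megatron concrete, here is a small worked example of the quantities it is described as computing; the lengths and alignment value are made up for illustration:

import torch

seq_lengths = [5, 9]                       # raw per-sequence lengths
align = 4                                  # pad_individual_seqs_to_multiple_of
padded_lens = [((L + align - 1) // align) * align for L in seq_lengths]
assert padded_lens == [8, 12]

# cu_seqlens_padded describes the full (pre-CP-shard) packed layout
cu = [0]
for p in padded_lens:
    cu.append(cu[-1] + p)
assert cu == [0, 8, 20]

# attention_mask marks positions < padded length, broadcast over the batch,
# so the model's own packing recomputes the same totals from mask.sum()
padded_max = max(padded_lens)
positions = torch.arange(padded_max)
mask = positions.unsqueeze(0) < torch.tensor(padded_lens).unsqueeze(1)
assert mask.shape == (2, 12)
assert mask[0].sum().item() == 8 and mask[1].sum().item() == 12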
8 changes: 7 additions & 1 deletion nemo_rl/models/megatron/setup.py
@@ -386,8 +386,14 @@ def _apply_parallelism_config(model_cfg: Any, config: PolicyConfig) -> None:
model_cfg.context_parallel_size = config["megatron_cfg"]["context_parallel_size"]

if model_cfg.context_parallel_size > 1:
# Either NeMo-RL does the packing+CP-sharding itself (classic mcore
# GPTModel path) OR the model does it internally (mbridge VLM wrappers
# with delegate_pack_to_model=True). Both paths require cu_seqlens to
# flow via PackedSeqParams, so sequence_packing.enabled must be on.
assert config["sequence_packing"]["enabled"], (
"Sequence Packing must be enabled to use Context Parallelism with MCore"
"Sequence Packing must be enabled to use Context Parallelism with MCore "
"(set delegate_pack_to_model: true under sequence_packing for VLM models "
"where the model does its own packing)."
)
assert not config["megatron_cfg"].get("use_linear_ce_fusion_loss", False), (
"Context Parallelism is not supported with linear CE fusion loss, please set use_linear_ce_fusion_loss to false"
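Putting the constraints in this hunk together, a config that uses context parallelism with a VLM that packs internally would look roughly like the following; this is a hedged sketch where the key names follow the asserts above and the values are illustrative:

policy_config = {
    "megatron_cfg": {
        "context_parallel_size": 2,          # CP > 1 triggers the checks above
        "use_linear_ce_fusion_loss": False,  # must be off when CP > 1
    },
    "sequence_packing": {
        "enabled": True,                     # required for CP with MCore
        # VLM models that run their own preprocess_packed_seqs:
        "delegate_pack_to_model": True,
    },
}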