
Sequence packing with dynamic batching disabled fails on 8B #648

@parthchadha

Description


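Describe the bug

With sequence packing enabled and dynamic batching disabled, the logprob computation fails in make_microbatch_iterator: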

▶ Computing logprobs...
                ^^^^^^^^^^^
  File "/lustre/fs1/portfolios/coreai/users/pchadha/gh/NeMo-Reinforcer/nemo_rl/distributed/batched_data_dict.py", line 729, in make_microbatch_iterator
    assert bsize % microbatch_size == 0, (
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^
AssertionError: Data dict size (249) is not a multiple of the provided microbatch size (2)
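The assertion rejects any batch whose size is not an exact multiple of the microbatch size. Here 249 % 2 == 1, so one sample would be left over after 124 microbatches of 2. The sketch below reproduces only that divisibility check in plain Python; make_microbatches is a hypothetical stand-in that slices a plain list, not the NeMo RL make_microbatch_iterator or its BatchedDataDict input.

```python
from typing import Iterator, List, Sequence, TypeVar

T = TypeVar("T")


def make_microbatches(samples: Sequence[T], microbatch_size: int) -> Iterator[List[T]]:
    """Split samples into equal microbatches, rejecting sizes that do not divide evenly.

    Hypothetical illustration of the divisibility check only; this is not the
    NeMo RL make_microbatch_iterator.
    """
    bsize = len(samples)
    # Same condition and message format as the assertion in the traceback.
    assert bsize % microbatch_size == 0, (
        f"Data dict size ({bsize}) is not a multiple of the provided "
        f"microbatch size ({microbatch_size})"
    )
    # The check runs eagerly; only the slicing is deferred to iteration.
    return (
        list(samples[start : start + microbatch_size])
        for start in range(0, bsize, microbatch_size)
    )


if __name__ == "__main__":
    print(len(list(make_microbatches(list(range(248)), 2))))  # 124
    try:
        make_microbatches(list(range(249)), 2)
    except AssertionError as exc:
        print(exc)
```

Validating before returning the generator means a bad size fails at the call site rather than on the first iteration.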

Steps/Code to reproduce bug

Check out #300, then apply the following diff:

diff --git a/examples/configs/grpo_math_1B.yaml b/examples/configs/grpo_math_1B.yaml
index af748aad..e1f577b7 100644
--- a/examples/configs/grpo_math_1B.yaml
+++ b/examples/configs/grpo_math_1B.yaml
@@ -52,6 +52,9 @@ policy:
     context_parallel_size: 1
     custom_parallel_plan: null
 
+  megatron_cfg:
+    enabled: false
+
   # dynamic_batching improves performance by ensuring logprob and training microbatches
   # have a sufficent number of tokens to maximize GPU utilization. Specifically, variable length
   # responses are sorted by sequence length and bucketed into microbatches with a total
@@ -64,7 +67,7 @@ policy:
     sequence_length_round: 64
 
   sequence_packing:
-    enabled: False
+    enabled: True
     train_mb_tokens: ${mul:${policy.max_total_sequence_length}, ${policy.train_micro_batch_size}}
     logprob_mb_tokens: ${mul:${policy.max_total_sequence_length}, ${policy.logprob_batch_size}}
     algorithm: "concatenative"
diff --git a/examples/configs/recipes/llm/grpo-qwen2.5-32b-16n8g-fsdp2tp8sp-actckpt.v3.yaml b/examples/configs/recipes/llm/grpo-qwen2.5-32b-16n8g-fsdp2tp8sp-actckpt.v3.yaml
index 06ae6b46..10072b21 100644
--- a/examples/configs/recipes/llm/grpo-qwen2.5-32b-16n8g-fsdp2tp8sp-actckpt.v3.yaml
+++ b/examples/configs/recipes/llm/grpo-qwen2.5-32b-16n8g-fsdp2tp8sp-actckpt.v3.yaml
@@ -49,6 +49,12 @@ policy:
     train_mb_tokens: ${mul:${policy.max_total_sequence_length}, ${policy.train_micro_batch_size}}
     logprob_mb_tokens: ${mul:${policy.max_total_sequence_length}, ${policy.logprob_batch_size}}
     sequence_length_round: 64
+  sequence_packing:
+    enabled: True
+    train_mb_tokens: ${mul:${policy.max_total_sequence_length}, ${policy.train_micro_batch_size}}
+    logprob_mb_tokens: ${mul:${policy.max_total_sequence_length}, ${policy.logprob_batch_size}}
+    algorithm: "concatenative"
+    sequence_length_round: 64
   make_sequence_length_divisible_by: 8
   max_grad_norm: 1
   optimizer:
@@ -76,7 +82,7 @@ policy:
         - 13
   generation:
     backend: vllm
-    max_new_tokens: 16384
+    max_new_tokens: ${policy.max_total_sequence_length}
     temperature: 1
     top_p: 1
     top_k: null
@@ -89,7 +95,7 @@ policy:
       tensor_parallel_size: 4
       pipeline_parallel_size: 1
       gpu_memory_utilization: 0.6
-      max_model_len: 16384
+      max_model_len: ${policy.max_total_sequence_length}
       enforce_eager: False
     colocated:
       enabled: true
@@ -97,7 +103,7 @@ policy:
         gpus_per_node: null
         num_nodes: null
 data:
-  max_input_seq_length: 16384
+  max_input_seq_length: ${policy.max_total_sequence_length}
   prompt_file: examples/prompts/cot.txt
   system_prompt_file: null
   dataset_name: OpenMathInstruct-2
diff --git a/nemo_rl/models/policy/lm_policy.py b/nemo_rl/models/policy/lm_policy.py
index 14f1e6a5..814a5900 100644
--- a/nemo_rl/models/policy/lm_policy.py
+++ b/nemo_rl/models/policy/lm_policy.py
@@ -152,8 +152,8 @@ class Policy(ColocatablePolicyInterface, GenerationInterface):
                 "algorithm": config["sequence_packing"]["algorithm"],
                 "input_key": "input_ids",
                 "input_lengths_key": "input_lengths",
-                "sequence_length_pad_multiple": (self.cp_size * 2 * tp_size)
-                if self.cp_size > 1
+                "sequence_length_pad_multiple": (cp_size * 2 * tp_size)
+                if cp_size > 1
                 else tp_size,
             }
         else:

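For reference, the packing values touched by the diff can be derived as below. This is a sketch under stated assumptions: it assumes the ${mul:...} interpolation multiplies its two arguments, it copies the sequence_length_pad_multiple expression from the patched lm_policy.py hunk, and round_up_to_multiple is a hypothetical helper showing what padding a single sequence length to that multiple would mean. The input values are illustrative, not taken from the 8B config.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class PackingParams:
    train_mb_tokens: int
    logprob_mb_tokens: int
    sequence_length_pad_multiple: int


def derive_packing_params(
    max_total_sequence_length: int,
    train_micro_batch_size: int,
    logprob_batch_size: int,
    tp_size: int,
    cp_size: int,
) -> PackingParams:
    # Assumes ${mul:a, b} evaluates to a * b.
    train_mb_tokens = max_total_sequence_length * train_micro_batch_size
    logprob_mb_tokens = max_total_sequence_length * logprob_batch_size
    # Same expression as the patched lm_policy.py hunk (local cp_size, not self.cp_size).
    pad_multiple = cp_size * 2 * tp_size if cp_size > 1 else tp_size
    return PackingParams(train_mb_tokens, logprob_mb_tokens, pad_multiple)


def round_up_to_multiple(length: int, multiple: int) -> int:
    """Hypothetical helper: smallest value >= length that is divisible by multiple."""
    return -(-length // multiple) * multiple


if __name__ == "__main__":
    params = derive_packing_params(4096, 1, 2, tp_size=4, cp_size=2)
    # PackingParams(train_mb_tokens=4096, logprob_mb_tokens=8192, sequence_length_pad_multiple=16)
    print(params)
    print(round_up_to_multiple(1000, params.sequence_length_pad_multiple))  # 1008
```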
Then run:

uv run ./examples/run_grpo_math.py --config=examples/configs/grpo_math_8B.yaml cluster.num_nodes=1 cluster.gpus_per_node=8 policy.dynamic_batching.enabled=False checkpointing.enabled=false
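The arguments after --config appear to be dotted key=value overrides layered on top of the YAML file, so policy.dynamic_batching.enabled=False turns dynamic batching off while the diff above turns sequence packing on. The sketch assumes simple OmegaConf-style dotlist semantics with minimal bool/int parsing; apply_overrides is hypothetical and is not the parser used by run_grpo_math.py.

```python
from typing import Any, Dict, List


def _parse_scalar(text: str) -> Any:
    # Hypothetical minimal parsing: booleans (case-insensitive), then ints, else the raw string.
    lowered = text.lower()
    if lowered in ("true", "false"):
        return lowered == "true"
    try:
        return int(text)
    except ValueError:
        return text


def apply_overrides(config: Dict[str, Any], overrides: List[str]) -> Dict[str, Any]:
    """Apply dotted key=value overrides to a nested dict in place and return it."""
    for item in overrides:
        dotted_key, sep, raw_value = item.partition("=")
        if not sep:
            raise ValueError(f"Override must look like key=value: {item!r}")
        *parents, leaf = dotted_key.split(".")
        node = config
        for key in parents:
            # Create intermediate sections that the base config does not define yet.
            node = node.setdefault(key, {})
        node[leaf] = _parse_scalar(raw_value)
    return config


if __name__ == "__main__":
    base = {"cluster": {"num_nodes": 2}, "policy": {"dynamic_batching": {"enabled": True}}}
    overrides = [
        "cluster.num_nodes=1",
        "cluster.gpus_per_node=8",
        "policy.dynamic_batching.enabled=False",
        "checkpointing.enabled=false",
    ]
    print(apply_overrides(base, overrides))
```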


