diff --git a/examples/configs/grpo_math_1B.yaml b/examples/configs/grpo_math_1B.yaml index 62325d7a03..422e869f56 100644 --- a/examples/configs/grpo_math_1B.yaml +++ b/examples/configs/grpo_math_1B.yaml @@ -27,7 +27,7 @@ policy: model_name: "meta-llama/Llama-3.2-1B-Instruct" train_global_batch_size: 512 train_micro_batch_size: 4 - generation_batch_size: 32 + generation_batch_size: 32 # Only used when generating with the HF backend logprob_batch_size: 4 max_total_sequence_length: 512 precision: "bfloat16" diff --git a/examples/configs/grpo_math_8B.yaml b/examples/configs/grpo_math_8B.yaml index d747e9249f..261db927b1 100644 --- a/examples/configs/grpo_math_8B.yaml +++ b/examples/configs/grpo_math_8B.yaml @@ -1,11 +1,15 @@ # GRPO Algorithm Configuration defaults: "grpo_math_1B.yaml" +grpo: + num_prompts_per_step: 64 + num_generations_per_prompt: 32 + policy: model_name: "meta-llama/Llama-3.1-8B-Instruct" train_global_batch_size: 512 train_micro_batch_size: 1 - generation_batch_size: 32 + generation_batch_size: 32 # Only used when generating with the HF backend logprob_batch_size: 2 max_total_sequence_length: 4096 precision: "bfloat16"