
🧩 PPO/RLOO/OnlineDPO sequence generation: make deepspeed 3 weight gathering optional #2557

Merged
qgallouedec merged 7 commits into huggingface:main from dawidm:control-unwrap-for-generation
Jan 21, 2025

Conversation

@dawidm (Contributor) commented Jan 10, 2025

What does this PR do?

This is my approach to addressing #2529. Weight gathering was introduced in #1483 to speed up generation significantly, but with bigger models it may cause OOM (#2250) here:

with deepspeed.zero.GatheredParameters(model.parameters()):

I've introduced a ds3_gather_for_generation parameter in the trainers' configurations. With the default value True, the behavior is the same as before.
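For illustration, here is a minimal sketch of how such a flag can make the gathering optional inside unwrap_model_for_generation (the trl/models/utils.py helper visible in the traceback below). The keyword name gather_deepspeed3_params and the exact control flow are assumptions for this sketch, not the PR's verbatim diff:

from contextlib import contextmanager

import deepspeed


@contextmanager
def unwrap_model_for_generation(model, accelerator, gather_deepspeed3_params=True):
    """Yield a model ready for generation, optionally gathering ZeRO-3 shards."""
    unwrapped = accelerator.unwrap_model(model)
    plugin = accelerator.state.deepspeed_plugin
    if plugin is not None and plugin.zero_stage == 3 and gather_deepspeed3_params:
        # Gather the full weights on every rank: generation is fast, but the
        # whole model must fit in each GPU's memory (the OOM path shown below).
        with deepspeed.zero.GatheredParameters(model.parameters()):
            yield unwrapped
    else:
        # Keep weights sharded: generation is slower, but the full model is
        # never materialized on a single rank.
        yield unwrapped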

I ran the official examples for PPO and RLOO with these changes. The only changes I made to examples/accelerate_configs/deepspeed_zero3.yaml were offload_optimizer_device: cpu and num_processes: 4.

- Platform: Linux-5.15.0-122-generic-x86_64-with-glibc2.35
- Python version: 3.11.10
- PyTorch version: 2.5.1+cu124
- CUDA device(s): NVIDIA GeForce RTX 3090, NVIDIA GeForce RTX 3090, NVIDIA GeForce RTX 3090, NVIDIA GeForce RTX 3090
- Transformers version: 4.47.1
- Accelerate version: 0.34.2
- Accelerate config: not found
- Datasets version: 3.2.0
- HF Hub version: 0.27.1
- TRL version: 0.14.0.dev0+4c2ffdb
- bitsandbytes version: not installed
- DeepSpeed version: 0.15.4
- Diffusers version: not installed
- Liger-Kernel version: not installed
- LLM-Blender version: not installed
- OpenAI version: not installed
- PEFT version: not installed
accelerate launch --config_file examples/accelerate_configs/deepspeed_zero3.yaml \
    examples/scripts/ppo/ppo.py \
    --dataset_name trl-internal-testing/descriptiveness-sentiment-trl-style \
    --dataset_train_split descriptiveness \
    --output_dir models/minimal/ppo \
    --num_ppo_epochs 1 \
    --num_mini_batches 1 \
    --learning_rate 3e-6 \
    --per_device_train_batch_size 1 \
    --gradient_accumulation_steps 16 \
    --total_episodes 200 \
    --model_name_or_path EleutherAI/pythia-12b \
    --sft_model_path EleutherAI/pythia-12b \
    --reward_model_path EleutherAI/pythia-70m \
    --local_rollout_forward_batch_size 16 \
    --missing_eos_penalty 1.0 \
    --ds3_gather_for_generation {True/False}
accelerate launch --config_file examples/accelerate_configs/deepspeed_zero3.yaml \
    examples/scripts/rloo/rloo.py \
    --dataset_name trl-internal-testing/descriptiveness-sentiment-trl-style \
    --dataset_train_split descriptiveness \
    --output_dir models/minimal/rloo \
    --rloo_k 2 \
    --num_ppo_epochs 1 \
    --num_mini_batches 1 \
    --learning_rate 3e-6 \
    --per_device_train_batch_size 1 \
    --gradient_accumulation_steps 16 \
    --total_episodes 10000 \
    --model_name_or_path EleutherAI/pythia-12b \
    --sft_model_path EleutherAI/pythia-12b \
    --reward_model_path EleutherAI/pythia-70m \
    --local_rollout_forward_batch_size 16 \
    --missing_eos_penalty 1.0 \
    --ds3_gather_for_generation {True/False}

In both cases, with ds3_gather_for_generation set to True, I got the expected OOM:

[rank0]: Traceback (most recent call last):                                                    
[rank0]:   File "/root/trl/examples/scripts/ppo/ppo.py", line 163, in <module>                                                                                                                
[rank0]:     trainer.train()                                                                                                                                                                  
[rank0]:   File "/root/trl/trl/trainer/ppo_trainer.py", line 417, in train                                                                                                                    
[rank0]:     with unwrap_model_for_generation(                                                 
[rank0]:   File "/opt/conda/lib/python3.11/contextlib.py", line 137, in __enter__
[rank0]:     return next(self.gen)
[rank0]:            ^^^^^^^^^^^^^^                                                             
[rank0]:   File "/root/trl/trl/models/utils.py", line 189, in unwrap_model_for_generation
[rank0]:     with deepspeed.zero.GatheredParameters(model.parameters()):         
[rank0]:   File "/opt/conda/lib/python3.11/site-packages/deepspeed/runtime/zero/partition_parameters.py", line 2235, in __enter__
[rank0]:     self.params[0].all_gather(param_list=self.params)
[rank0]:   File "/opt/conda/lib/python3.11/site-packages/deepspeed/runtime/zero/partition_parameters.py", line 1154, in all_gather
[rank0]:     return self._all_gather(param_list, async_op=async_op, hierarchy=hierarchy)
[rank0]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^                                                                                                      
[rank0]:   File "/opt/conda/lib/python3.11/site-packages/deepspeed/utils/nvtx.py", line 18, in wrapped_fn
[rank0]:     ret_val = func(*args, **kwargs)                                                                                                                                                  
[rank0]:               ^^^^^^^^^^^^^^^^^^^^^                                                   
[rank0]:   File "/opt/conda/lib/python3.11/site-packages/deepspeed/runtime/zero/partition_parameters.py", line 1522, in _all_gather
[rank0]:     self._allgather_params_coalesced(all_gather_nonquantize_list, hierarchy, quantize=False)    
[rank0]:   File "/opt/conda/lib/python3.11/site-packages/deepspeed/runtime/zero/partition_parameters.py", line 1810, in _allgather_params_coalesced
[rank0]:     flat_tensor = torch.empty(tensor_size, dtype=param_list[0].ds_tensor.dtype,
[rank0]:                   ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]: torch.OutOfMemoryError: CUDA out of memory. Tried to allocate 150.00 MiB. GPU 0 has a total capacity of 23.68 GiB of which 121.81 MiB is free. Process 3411533 has 23.56 GiB memory in use. Of the allocated memory 22.29 GiB is allocated by PyTorch, and 886.03 MiB is reserved by PyTorch but unallocated. If reserved but unallocated memory is large try setting PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True to avoid fragmentation. See documentation for Memory Management (https://pytorch.org/docs/stable/notes/cuda.html#environment-variables)

With ds3_gather_for_generation set to False, sequence generation succeeds. It's slow, as shown in #1483, but in private experiments I've found it useful.
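For completeness, the same switch can be set when building a trainer config in Python instead of on the command line; a minimal sketch with PPOConfig, where the other kwargs just mirror the commands above:

from trl import PPOConfig

# Trade generation speed for memory: keep ZeRO-3 weights sharded during
# rollout generation instead of gathering the full model on each rank.
training_args = PPOConfig(
    output_dir="models/minimal/ppo",
    per_device_train_batch_size=1,
    gradient_accumulation_steps=16,
    total_episodes=200,
    ds3_gather_for_generation=False,
)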

Before submitting

  • Did you read the contributor guideline, Pull Request section?
  • Was this discussed/approved via a GitHub issue? Please add a link to it if that's the case.
  • Did you make sure to update the documentation with your changes? Here are the documentation guidelines.
  • Did you write any new necessary tests?

Who can review?

Anyone in the community is free to review the PR once the tests have passed. Feel free to tag
members/contributors who may be interested in your PR.

@qgallouedec (Member) left a comment:

Thanks, nice addition!

@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

@qgallouedec changed the title from "PPO/RLOO/OnlineDPO sequence generation: make deepspeed 3 weight gathering optional" to "🧩 PPO/RLOO/OnlineDPO sequence generation: make deepspeed 3 weight gathering optional" on Jan 21, 2025
@qgallouedec qgallouedec merged commit d4222a1 into huggingface:main Jan 21, 2025
@yiyepiaoling0715

How should this code be changed for training bigger models, e.g. >14B?

yxliu-TAMU pushed a commit to mincheolseong/ECEN743-GRPO-Project-Proposal that referenced this pull request on Apr 20, 2025:
PPO/RLOO/OnlineDPO sequence generation: make deepspeed 3 weight gathering optional (huggingface#2557)

* PPO/RLOO/OnlineDPO: add ds3_gather_for_generation argument to control weights gathering for generation

* code formatting

* rephrase and document

* more doc

* style [ci skip]

* Trigger CI

---------

Co-authored-by: Quentin Gallouédec <45557362+qgallouedec@users.noreply.github.com>
Co-authored-by: Quentin Gallouédec <quentin.gallouedec@huggingface.co>

Development

Successfully merging this pull request may close these issues:

  • Option to disable unwrapping model for generation in PPO/RLOO/OnlineDPO (#2529)
  • OOM when unwrap_model_for_generation (#2250)
