Merged
30 changes: 30 additions & 0 deletions docs/design-docs/fsdp2-parallel-plan.md
@@ -0,0 +1,30 @@
# FSDP2 Parallel Plan

This guide outlines the parallelization strategy for Fully Sharded Data Parallel version 2 (FSDP2) training in NeMo RL.

## Fallback Priority

NeMo RL supports three parallelization strategies, applied in the following order of fallback priority:

### 1. Custom Parallel Plan

A user-defined custom parallel plan always takes precedence when one is specified. For implementation details and usage, see the [Custom Parallel Plan Example](#custom-parallel-plan-example).

### 2. Optimized Parallel Plan

Optimized parallel plans are available for specific model architectures and may offer better performance than Hugging Face's tensor parallel implementation. This strategy is used when no custom parallel plan is specified and the model class supports optimized parallelization.

### 3. Hugging Face Tensor Parallel Plan

The Hugging Face tensor parallel plan is the default fallback. It is available for most models via the model's `._tp_plan` attribute and is used when neither a custom nor an optimized parallel plan is available.
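
The fallback order above can be sketched as follows. This is a minimal illustration only; the function and argument names are hypothetical, not the actual NeMo RL API.

```python
# Hypothetical sketch of the fallback priority; names are illustrative,
# not the actual NeMo RL API.
def resolve_parallel_plan(custom_plan=None, optimized_plan=None, hf_tp_plan=None):
    """Return the first available plan: custom > optimized > Hugging Face."""
    if custom_plan is not None:
        return custom_plan  # 1. user-defined plan always wins
    if optimized_plan is not None:
        return optimized_plan  # 2. architecture-specific optimized plan
    return hf_tp_plan  # 3. default: the model's Hugging Face `._tp_plan`
```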

## Custom Parallel Plan Example

A custom parallel plan should be defined in a separate file, such as the example provided in `examples/custom_parallel.py`.

To use a custom parallel plan, either set `policy.dtensor_cfg.custom_parallel_plan` in the YAML config directly, or pass the override on the command line. For example:

```bash
uv run examples/run_grpo_math.py \
policy.dtensor_cfg.custom_parallel_plan=examples.custom_parallel.custom_parallel_plan
```
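
The equivalent YAML override is a nested key under `policy.dtensor_cfg` (a sketch; the key layout follows the configs under `examples/configs/`):

```yaml
policy:
  dtensor_cfg:
    custom_parallel_plan: examples.custom_parallel.custom_parallel_plan
```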
1 change: 1 addition & 0 deletions docs/index.md
@@ -61,4 +61,5 @@ design-docs/chat-datasets.md
design-docs/generation.md
design-docs/checkpointing.md
design-docs/loss-functions.md
design-docs/fsdp2-parallel-plan.md
```
1 change: 1 addition & 0 deletions examples/configs/dpo.yaml
@@ -51,6 +51,7 @@ policy:
sequence_parallel: false
activation_checkpointing: false
tensor_parallel_size: 1
custom_parallel_plan: null

dynamic_batching:
enabled: false
3 changes: 2 additions & 1 deletion examples/configs/grpo-deepscaler-1.5b-8K.yaml
@@ -50,7 +50,8 @@ policy:
sequence_parallel: false
activation_checkpointing: false
tensor_parallel_size: 1

custom_parallel_plan: null

# makes the training sequence length divisible by the tensor parallel size
# this is useful for sequence parallel training
make_sequence_length_divisible_by: ${policy.dtensor_cfg.tensor_parallel_size}
1 change: 1 addition & 0 deletions examples/configs/grpo_deepscaler-1.5b-24K.yaml
@@ -15,6 +15,7 @@ policy:
sequence_parallel: true
activation_checkpointing: true
tensor_parallel_size: 4
custom_parallel_plan: null

optimizer:
name: "torch.optim.AdamW"
1 change: 1 addition & 0 deletions examples/configs/grpo_math_1B.yaml
@@ -50,6 +50,7 @@ policy:
sequence_parallel: false
activation_checkpointing: false
tensor_parallel_size: 1
custom_parallel_plan: null

# dynamic_batching improves performance by ensuring logprob and training microbatches
# have a sufficent number of tokens to maximize GPU utilization. Specifically, variable length
@@ -39,6 +39,8 @@ policy:
sequence_parallel: false
activation_checkpointing: false
tensor_parallel_size: 1
custom_parallel_plan: null

dynamic_batching:
enabled: False

@@ -39,6 +39,8 @@ policy:
sequence_parallel: false
activation_checkpointing: false
tensor_parallel_size: 1
custom_parallel_plan: null

dynamic_batching:
enabled: False

@@ -39,6 +39,8 @@ policy:
sequence_parallel: false
activation_checkpointing: false
tensor_parallel_size: 2
custom_parallel_plan: null

dynamic_batching:
enabled: False

@@ -40,6 +40,8 @@ policy:
sequence_parallel: false
activation_checkpointing: false
tensor_parallel_size: 1
custom_parallel_plan: null

dynamic_batching:
enabled: False

@@ -43,6 +43,7 @@ policy:
sequence_parallel: false
activation_checkpointing: false
tensor_parallel_size: 1
custom_parallel_plan: null
dynamic_batching:
enabled: True
train_mb_tokens: ${mul:${policy.max_total_sequence_length}, ${policy.train_micro_batch_size}}
@@ -43,6 +43,7 @@ policy:
sequence_parallel: true
activation_checkpointing: true
tensor_parallel_size: 8
custom_parallel_plan: null
dynamic_batching:
# TODO: OOMs if enabled https://github.com/NVIDIA/NeMo-RL/issues/383
enabled: False
@@ -43,6 +43,7 @@ policy:
sequence_parallel: false
activation_checkpointing: false
tensor_parallel_size: 1
custom_parallel_plan: null
dynamic_batching:
enabled: True
train_mb_tokens: ${mul:${policy.max_total_sequence_length}, ${policy.train_micro_batch_size}}
@@ -43,6 +43,7 @@ policy:
sequence_parallel: false
activation_checkpointing: false
tensor_parallel_size: 1
custom_parallel_plan: null
dynamic_batching:
enabled: True
train_mb_tokens: ${mul:${policy.max_total_sequence_length}, ${policy.train_micro_batch_size}}
@@ -43,6 +43,7 @@ policy:
sequence_parallel: true
activation_checkpointing: true
tensor_parallel_size: 8
custom_parallel_plan: null
dynamic_batching:
enabled: True
train_mb_tokens: ${mul:${policy.max_total_sequence_length}, ${policy.train_micro_batch_size}}
@@ -43,6 +43,7 @@ policy:
sequence_parallel: true
activation_checkpointing: true
tensor_parallel_size: 8
custom_parallel_plan: null
dynamic_batching:
enabled: True
train_mb_tokens: ${mul:${policy.max_total_sequence_length}, ${policy.train_micro_batch_size}}
@@ -43,6 +43,7 @@ policy:
sequence_parallel: false
activation_checkpointing: false
tensor_parallel_size: 1
custom_parallel_plan: null
dynamic_batching:
enabled: False
make_sequence_length_divisible_by: 1
@@ -43,6 +43,7 @@ policy:
sequence_parallel: true
activation_checkpointing: false
tensor_parallel_size: 4
custom_parallel_plan: null
dynamic_batching:
enabled: True
train_mb_tokens: ${mul:${policy.max_total_sequence_length}, ${policy.train_micro_batch_size}}
@@ -43,6 +43,7 @@ policy:
sequence_parallel: false
activation_checkpointing: false
tensor_parallel_size: 1
custom_parallel_plan: null
dynamic_batching:
enabled: True
train_mb_tokens: ${mul:${policy.max_total_sequence_length}, ${policy.train_micro_batch_size}}
@@ -31,6 +31,7 @@ policy:
sequence_parallel: false
activation_checkpointing: false
tensor_parallel_size: 1
custom_parallel_plan: null
dynamic_batching:
enabled: False
make_sequence_length_divisible_by: 1
@@ -31,6 +31,7 @@ policy:
sequence_parallel: false
activation_checkpointing: false
tensor_parallel_size: 1
custom_parallel_plan: null
dynamic_batching:
enabled: False
make_sequence_length_divisible_by: 1
@@ -31,6 +31,7 @@ policy:
sequence_parallel: true
activation_checkpointing: false
tensor_parallel_size: 2
custom_parallel_plan: null
dynamic_batching:
enabled: False
make_sequence_length_divisible_by: 2
@@ -31,6 +31,7 @@ policy:
sequence_parallel: false
activation_checkpointing: false
tensor_parallel_size: 1
custom_parallel_plan: null
dynamic_batching:
enabled: False
make_sequence_length_divisible_by: 1
@@ -31,6 +31,7 @@ policy:
sequence_parallel: true
activation_checkpointing: true
tensor_parallel_size: 8
custom_parallel_plan: null
dynamic_batching:
enabled: False
make_sequence_length_divisible_by: 8
3 changes: 2 additions & 1 deletion examples/configs/sft.yaml
@@ -38,7 +38,8 @@ policy:
sequence_parallel: false
activation_checkpointing: false
tensor_parallel_size: 1

custom_parallel_plan: null

dynamic_batching:
enabled: false

1 change: 1 addition & 0 deletions examples/configs/sft_openmathinstruct2.yaml
@@ -34,6 +34,7 @@ policy:
sequence_parallel: false
activation_checkpointing: false
tensor_parallel_size: 4
custom_parallel_plan: null

# makes the training sequence length divisible by the tensor parallel size
# this is useful for sequence parallel training
28 changes: 28 additions & 0 deletions examples/custom_parallel.py
@@ -0,0 +1,28 @@
# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from torch.distributed.tensor.parallel import ColwiseParallel, RowwiseParallel
from torch.distributed.tensor.placement_types import Replicate, Shard

custom_parallel_plan = {
    # Shard the vocab embedding row-wise; input token ids are replicated.
    "model.embed_tokens": RowwiseParallel(input_layouts=Replicate()),
    # Attention: column-parallel q/k/v projections, row-parallel output projection.
    "model.layers.*.self_attn.q_proj": ColwiseParallel(),
    "model.layers.*.self_attn.k_proj": ColwiseParallel(),
    "model.layers.*.self_attn.v_proj": ColwiseParallel(),
    "model.layers.*.self_attn.o_proj": RowwiseParallel(),
    # MLP: column-parallel up/gate projections, row-parallel down projection.
    "model.layers.*.mlp.up_proj": ColwiseParallel(),
    "model.layers.*.mlp.gate_proj": ColwiseParallel(),
    "model.layers.*.mlp.down_proj": RowwiseParallel(),
    # Keep lm_head output sharded on the vocab dim and return it as a DTensor.
    "lm_head": ColwiseParallel(output_layouts=Shard(-1), use_local_output=False),
}