diff --git a/docs/design-docs/fsdp2-parallel-plan.md b/docs/design-docs/fsdp2-parallel-plan.md new file mode 100644 index 0000000000..8318b3174c --- /dev/null +++ b/docs/design-docs/fsdp2-parallel-plan.md @@ -0,0 +1,30 @@ +# FSDP2 Parallel Plan + +This guide outlines the parallelization strategy for Fully Sharded Data Parallel version 2 (FSDP2) training in NeMo RL. + +## Fallback Priority + +NeMo RL supports three parallelization strategies, applied in the following order of fallback priority: + +### 1. Custom Parallel Plan + +Your user-defined custom parallel plans always take precedence when available. For detailed implementation and usage, refer to the [Custom Parallel Plan Example](#custom-parallel-plan-example). + +### 2. Optimized Parallel Plan + +Optimized parallel plans are available for specific model architectures. They may offer superior performance compared to Hugging Face's tensor parallel implementation. This approach is used if no custom parallel plan is specified and the model class supports optimized parallelization. + +### 3. Hugging Face Tensor Parallel Plan + +The Hugging Face tensor parallel plan is the default. It's available for most models via `._tp_plan` and is used when neither a custom nor an optimized parallel plan is available. + +## Custom Parallel Plan Example + +A custom parallel plan should be defined in a separate file, such as the example provided in `examples/custom_parallel.py`. + +To implement the custom parallel plan, either update the value of `custom_parallel_plan` in the `yaml` file directly, or pass the override via the command line. For example: + +```bash +uv run examples/run_grpo_math.py \ + policy.dtensor_cfg.custom_parallel_plan=examples.custom_parallel.custom_parallel_plan +``` diff --git a/docs/index.md b/docs/index.md index 4a0a5fcaa5..1e50854dfd 100644 --- a/docs/index.md +++ b/docs/index.md @@ -61,4 +61,5 @@ design-docs/chat-datasets.md design-docs/generation.md design-docs/checkpointing.md design-docs/loss-functions.md +design-docs/fsdp2-parallel-plan.md ``` diff --git a/examples/configs/dpo.yaml b/examples/configs/dpo.yaml index 1252adb131..1a185300c5 100755 --- a/examples/configs/dpo.yaml +++ b/examples/configs/dpo.yaml @@ -51,6 +51,7 @@ policy: sequence_parallel: false activation_checkpointing: false tensor_parallel_size: 1 + custom_parallel_plan: null dynamic_batching: enabled: false diff --git a/examples/configs/grpo-deepscaler-1.5b-8K.yaml b/examples/configs/grpo-deepscaler-1.5b-8K.yaml index ecaca68d58..f37f3140c2 100644 --- a/examples/configs/grpo-deepscaler-1.5b-8K.yaml +++ b/examples/configs/grpo-deepscaler-1.5b-8K.yaml @@ -50,7 +50,8 @@ policy: sequence_parallel: false activation_checkpointing: false tensor_parallel_size: 1 - + custom_parallel_plan: null + # makes the training sequence length divisible by the tensor parallel size # this is useful for sequence parallel training make_sequence_length_divisible_by: ${policy.dtensor_cfg.tensor_parallel_size} diff --git a/examples/configs/grpo_deepscaler-1.5b-24K.yaml b/examples/configs/grpo_deepscaler-1.5b-24K.yaml index 9beb0c210b..a616ccfa6a 100644 --- a/examples/configs/grpo_deepscaler-1.5b-24K.yaml +++ b/examples/configs/grpo_deepscaler-1.5b-24K.yaml @@ -15,6 +15,7 @@ policy: sequence_parallel: true activation_checkpointing: true tensor_parallel_size: 4 + custom_parallel_plan: null optimizer: name: "torch.optim.AdamW" diff --git a/examples/configs/grpo_math_1B.yaml b/examples/configs/grpo_math_1B.yaml index 4c7469d970..27f32fc432 100644 --- a/examples/configs/grpo_math_1B.yaml 
+++ b/examples/configs/grpo_math_1B.yaml @@ -50,6 +50,7 @@ policy: sequence_parallel: false activation_checkpointing: false tensor_parallel_size: 1 + custom_parallel_plan: null # dynamic_batching improves performance by ensuring logprob and training microbatches # have a sufficent number of tokens to maximize GPU utilization. Specifically, variable length diff --git a/examples/configs/recipes/llm/dpo-llama3.1-8b-instruct-4n8g-fsdp1-quick.v2.yaml b/examples/configs/recipes/llm/dpo-llama3.1-8b-instruct-4n8g-fsdp1-quick.v2.yaml index c953e8ecd4..42b287565e 100644 --- a/examples/configs/recipes/llm/dpo-llama3.1-8b-instruct-4n8g-fsdp1-quick.v2.yaml +++ b/examples/configs/recipes/llm/dpo-llama3.1-8b-instruct-4n8g-fsdp1-quick.v2.yaml @@ -39,6 +39,8 @@ policy: sequence_parallel: false activation_checkpointing: false tensor_parallel_size: 1 + custom_parallel_plan: null + dynamic_batching: enabled: False diff --git a/examples/configs/recipes/llm/dpo-llama3.1-8b-instruct-4n8g-fsdp2tp1.v2.yaml b/examples/configs/recipes/llm/dpo-llama3.1-8b-instruct-4n8g-fsdp2tp1.v2.yaml index 2ce7ec018a..563c9462d4 100644 --- a/examples/configs/recipes/llm/dpo-llama3.1-8b-instruct-4n8g-fsdp2tp1.v2.yaml +++ b/examples/configs/recipes/llm/dpo-llama3.1-8b-instruct-4n8g-fsdp2tp1.v2.yaml @@ -39,6 +39,8 @@ policy: sequence_parallel: false activation_checkpointing: false tensor_parallel_size: 1 + custom_parallel_plan: null + dynamic_batching: enabled: False diff --git a/examples/configs/recipes/llm/dpo-llama3.1-8b-instruct-4n8g-fsdp2tp2-quick.v2.yaml b/examples/configs/recipes/llm/dpo-llama3.1-8b-instruct-4n8g-fsdp2tp2-quick.v2.yaml index c68e9af08c..74a0ecd3d1 100644 --- a/examples/configs/recipes/llm/dpo-llama3.1-8b-instruct-4n8g-fsdp2tp2-quick.v2.yaml +++ b/examples/configs/recipes/llm/dpo-llama3.1-8b-instruct-4n8g-fsdp2tp2-quick.v2.yaml @@ -39,6 +39,8 @@ policy: sequence_parallel: false activation_checkpointing: false tensor_parallel_size: 2 + custom_parallel_plan: null + dynamic_batching: enabled: False diff --git a/examples/configs/recipes/llm/dpo-llama3.2-1b-instruct-1n8g-fsdp2tp1.v2.yaml b/examples/configs/recipes/llm/dpo-llama3.2-1b-instruct-1n8g-fsdp2tp1.v2.yaml index da1a95ac0a..3c4d31f324 100644 --- a/examples/configs/recipes/llm/dpo-llama3.2-1b-instruct-1n8g-fsdp2tp1.v2.yaml +++ b/examples/configs/recipes/llm/dpo-llama3.2-1b-instruct-1n8g-fsdp2tp1.v2.yaml @@ -40,6 +40,8 @@ policy: sequence_parallel: false activation_checkpointing: false tensor_parallel_size: 1 + custom_parallel_plan: null + dynamic_batching: enabled: False diff --git a/examples/configs/recipes/llm/grpo-gemma3-1b-it-1n8g-fsdp2tp1.yaml b/examples/configs/recipes/llm/grpo-gemma3-1b-it-1n8g-fsdp2tp1.yaml index 330104005b..2eca1f773e 100644 --- a/examples/configs/recipes/llm/grpo-gemma3-1b-it-1n8g-fsdp2tp1.yaml +++ b/examples/configs/recipes/llm/grpo-gemma3-1b-it-1n8g-fsdp2tp1.yaml @@ -43,6 +43,7 @@ policy: sequence_parallel: false activation_checkpointing: false tensor_parallel_size: 1 + custom_parallel_plan: null dynamic_batching: enabled: True train_mb_tokens: ${mul:${policy.max_total_sequence_length}, ${policy.train_micro_batch_size}} diff --git a/examples/configs/recipes/llm/grpo-gemma3-27b-it-16n8g-fsdp2tp8sp-actckpt-long.yaml b/examples/configs/recipes/llm/grpo-gemma3-27b-it-16n8g-fsdp2tp8sp-actckpt-long.yaml index ff297d4a23..fbfad65fc5 100644 --- a/examples/configs/recipes/llm/grpo-gemma3-27b-it-16n8g-fsdp2tp8sp-actckpt-long.yaml +++ b/examples/configs/recipes/llm/grpo-gemma3-27b-it-16n8g-fsdp2tp8sp-actckpt-long.yaml @@ -43,6 +43,7 @@ 
policy: sequence_parallel: true activation_checkpointing: true tensor_parallel_size: 8 + custom_parallel_plan: null dynamic_batching: # TODO: OOMs if enabled https://github.com/NVIDIA/NeMo-RL/issues/383 enabled: False diff --git a/examples/configs/recipes/llm/grpo-llama3.1-8b-instruct-4n8g-fsdp2tp1-long.v3.yaml b/examples/configs/recipes/llm/grpo-llama3.1-8b-instruct-4n8g-fsdp2tp1-long.v3.yaml index ee26b2dfa3..d163f3e130 100644 --- a/examples/configs/recipes/llm/grpo-llama3.1-8b-instruct-4n8g-fsdp2tp1-long.v3.yaml +++ b/examples/configs/recipes/llm/grpo-llama3.1-8b-instruct-4n8g-fsdp2tp1-long.v3.yaml @@ -43,6 +43,7 @@ policy: sequence_parallel: false activation_checkpointing: false tensor_parallel_size: 1 + custom_parallel_plan: null dynamic_batching: enabled: True train_mb_tokens: ${mul:${policy.max_total_sequence_length}, ${policy.train_micro_batch_size}} diff --git a/examples/configs/recipes/llm/grpo-llama3.2-1b-instruct-1n8g-fsdp2tp1.v3.yaml b/examples/configs/recipes/llm/grpo-llama3.2-1b-instruct-1n8g-fsdp2tp1.v3.yaml index 2ae26cda4a..19722c4746 100644 --- a/examples/configs/recipes/llm/grpo-llama3.2-1b-instruct-1n8g-fsdp2tp1.v3.yaml +++ b/examples/configs/recipes/llm/grpo-llama3.2-1b-instruct-1n8g-fsdp2tp1.v3.yaml @@ -43,6 +43,7 @@ policy: sequence_parallel: false activation_checkpointing: false tensor_parallel_size: 1 + custom_parallel_plan: null dynamic_batching: enabled: True train_mb_tokens: ${mul:${policy.max_total_sequence_length}, ${policy.train_micro_batch_size}} diff --git a/examples/configs/recipes/llm/grpo-qwen2.5-32b-16n8g-fsdp2tp8sp-actckpt-long.v3.yaml b/examples/configs/recipes/llm/grpo-qwen2.5-32b-16n8g-fsdp2tp8sp-actckpt-long.v3.yaml index e609a45558..15e539dcf3 100644 --- a/examples/configs/recipes/llm/grpo-qwen2.5-32b-16n8g-fsdp2tp8sp-actckpt-long.v3.yaml +++ b/examples/configs/recipes/llm/grpo-qwen2.5-32b-16n8g-fsdp2tp8sp-actckpt-long.v3.yaml @@ -43,6 +43,7 @@ policy: sequence_parallel: true activation_checkpointing: true tensor_parallel_size: 8 + custom_parallel_plan: null dynamic_batching: enabled: True train_mb_tokens: ${mul:${policy.max_total_sequence_length}, ${policy.train_micro_batch_size}} diff --git a/examples/configs/recipes/llm/grpo-qwen2.5-32b-16n8g-fsdp2tp8sp-actckpt.v3.yaml b/examples/configs/recipes/llm/grpo-qwen2.5-32b-16n8g-fsdp2tp8sp-actckpt.v3.yaml index f8b37a53f8..e21526145d 100644 --- a/examples/configs/recipes/llm/grpo-qwen2.5-32b-16n8g-fsdp2tp8sp-actckpt.v3.yaml +++ b/examples/configs/recipes/llm/grpo-qwen2.5-32b-16n8g-fsdp2tp8sp-actckpt.v3.yaml @@ -43,6 +43,7 @@ policy: sequence_parallel: true activation_checkpointing: true tensor_parallel_size: 8 + custom_parallel_plan: null dynamic_batching: enabled: True train_mb_tokens: ${mul:${policy.max_total_sequence_length}, ${policy.train_micro_batch_size}} diff --git a/examples/configs/recipes/llm/grpo-qwen2.5-7b-instruct-4n8g-fsdp1.v3.yaml b/examples/configs/recipes/llm/grpo-qwen2.5-7b-instruct-4n8g-fsdp1.v3.yaml index 92ffcfef59..8b9a0f5c62 100644 --- a/examples/configs/recipes/llm/grpo-qwen2.5-7b-instruct-4n8g-fsdp1.v3.yaml +++ b/examples/configs/recipes/llm/grpo-qwen2.5-7b-instruct-4n8g-fsdp1.v3.yaml @@ -43,6 +43,7 @@ policy: sequence_parallel: false activation_checkpointing: false tensor_parallel_size: 1 + custom_parallel_plan: null dynamic_batching: enabled: False make_sequence_length_divisible_by: 1 diff --git a/examples/configs/recipes/llm/grpo-qwen2.5-7b-instruct-4n8g-fsdp2tp4sp.v3.yaml b/examples/configs/recipes/llm/grpo-qwen2.5-7b-instruct-4n8g-fsdp2tp4sp.v3.yaml index 
bf196e3d37..e9d63dc158 100644 --- a/examples/configs/recipes/llm/grpo-qwen2.5-7b-instruct-4n8g-fsdp2tp4sp.v3.yaml +++ b/examples/configs/recipes/llm/grpo-qwen2.5-7b-instruct-4n8g-fsdp2tp4sp.v3.yaml @@ -43,6 +43,7 @@ policy: sequence_parallel: true activation_checkpointing: false tensor_parallel_size: 4 + custom_parallel_plan: null dynamic_batching: enabled: True train_mb_tokens: ${mul:${policy.max_total_sequence_length}, ${policy.train_micro_batch_size}} diff --git a/examples/configs/recipes/llm/grpo-qwen2.5-math-1.5b-instruct-1n8g-fsdp2tp1.v3.yaml b/examples/configs/recipes/llm/grpo-qwen2.5-math-1.5b-instruct-1n8g-fsdp2tp1.v3.yaml index 0099e9ebd7..49abe12153 100644 --- a/examples/configs/recipes/llm/grpo-qwen2.5-math-1.5b-instruct-1n8g-fsdp2tp1.v3.yaml +++ b/examples/configs/recipes/llm/grpo-qwen2.5-math-1.5b-instruct-1n8g-fsdp2tp1.v3.yaml @@ -43,6 +43,7 @@ policy: sequence_parallel: false activation_checkpointing: false tensor_parallel_size: 1 + custom_parallel_plan: null dynamic_batching: enabled: True train_mb_tokens: ${mul:${policy.max_total_sequence_length}, ${policy.train_micro_batch_size}} diff --git a/examples/configs/recipes/llm/sft-llama3.1-8b-instruct-1n8g-fsdp1.v2.yaml b/examples/configs/recipes/llm/sft-llama3.1-8b-instruct-1n8g-fsdp1.v2.yaml index c089bcd4d0..d39bafbe91 100644 --- a/examples/configs/recipes/llm/sft-llama3.1-8b-instruct-1n8g-fsdp1.v2.yaml +++ b/examples/configs/recipes/llm/sft-llama3.1-8b-instruct-1n8g-fsdp1.v2.yaml @@ -31,6 +31,7 @@ policy: sequence_parallel: false activation_checkpointing: false tensor_parallel_size: 1 + custom_parallel_plan: null dynamic_batching: enabled: False make_sequence_length_divisible_by: 1 diff --git a/examples/configs/recipes/llm/sft-llama3.1-8b-instruct-1n8g-fsdp2tp1-long.v2.yaml b/examples/configs/recipes/llm/sft-llama3.1-8b-instruct-1n8g-fsdp2tp1-long.v2.yaml index 86537fcb87..8f62399bde 100644 --- a/examples/configs/recipes/llm/sft-llama3.1-8b-instruct-1n8g-fsdp2tp1-long.v2.yaml +++ b/examples/configs/recipes/llm/sft-llama3.1-8b-instruct-1n8g-fsdp2tp1-long.v2.yaml @@ -31,6 +31,7 @@ policy: sequence_parallel: false activation_checkpointing: false tensor_parallel_size: 1 + custom_parallel_plan: null dynamic_batching: enabled: False make_sequence_length_divisible_by: 1 diff --git a/examples/configs/recipes/llm/sft-llama3.1-8b-instruct-1n8g-fsdp2tp2sp.v2.yaml b/examples/configs/recipes/llm/sft-llama3.1-8b-instruct-1n8g-fsdp2tp2sp.v2.yaml index aec0d380c2..c4385da6e4 100644 --- a/examples/configs/recipes/llm/sft-llama3.1-8b-instruct-1n8g-fsdp2tp2sp.v2.yaml +++ b/examples/configs/recipes/llm/sft-llama3.1-8b-instruct-1n8g-fsdp2tp2sp.v2.yaml @@ -31,6 +31,7 @@ policy: sequence_parallel: true activation_checkpointing: false tensor_parallel_size: 2 + custom_parallel_plan: null dynamic_batching: enabled: False make_sequence_length_divisible_by: 2 diff --git a/examples/configs/recipes/llm/sft-llama3.2-1b-1n8g-fsdp2tp1.v2.yaml b/examples/configs/recipes/llm/sft-llama3.2-1b-1n8g-fsdp2tp1.v2.yaml index bc9b6f6326..4d31d392e9 100644 --- a/examples/configs/recipes/llm/sft-llama3.2-1b-1n8g-fsdp2tp1.v2.yaml +++ b/examples/configs/recipes/llm/sft-llama3.2-1b-1n8g-fsdp2tp1.v2.yaml @@ -31,6 +31,7 @@ policy: sequence_parallel: false activation_checkpointing: false tensor_parallel_size: 1 + custom_parallel_plan: null dynamic_batching: enabled: False make_sequence_length_divisible_by: 1 diff --git a/examples/configs/recipes/llm/sft-qwen2.5-32b-4n8g-fsdp2tp8sp-actckpt.v2.yaml 
b/examples/configs/recipes/llm/sft-qwen2.5-32b-4n8g-fsdp2tp8sp-actckpt.v2.yaml index 7d9cfe6eb1..6baa5f08f9 100644 --- a/examples/configs/recipes/llm/sft-qwen2.5-32b-4n8g-fsdp2tp8sp-actckpt.v2.yaml +++ b/examples/configs/recipes/llm/sft-qwen2.5-32b-4n8g-fsdp2tp8sp-actckpt.v2.yaml @@ -31,6 +31,7 @@ policy: sequence_parallel: true activation_checkpointing: true tensor_parallel_size: 8 + custom_parallel_plan: null dynamic_batching: enabled: False make_sequence_length_divisible_by: 8 diff --git a/examples/configs/sft.yaml b/examples/configs/sft.yaml index 09d05ef89f..b8f01ce3e6 100644 --- a/examples/configs/sft.yaml +++ b/examples/configs/sft.yaml @@ -38,7 +38,8 @@ policy: sequence_parallel: false activation_checkpointing: false tensor_parallel_size: 1 - + custom_parallel_plan: null + dynamic_batching: enabled: false diff --git a/examples/configs/sft_openmathinstruct2.yaml b/examples/configs/sft_openmathinstruct2.yaml index ef853cc4cd..b885f7388b 100644 --- a/examples/configs/sft_openmathinstruct2.yaml +++ b/examples/configs/sft_openmathinstruct2.yaml @@ -34,6 +34,7 @@ policy: sequence_parallel: false activation_checkpointing: false tensor_parallel_size: 4 + custom_parallel_plan: null # makes the training sequence length divisible by the tensor parallel size # this is useful for sequence parallel training diff --git a/examples/custom_parallel.py b/examples/custom_parallel.py new file mode 100644 index 0000000000..647ddfc563 --- /dev/null +++ b/examples/custom_parallel.py @@ -0,0 +1,28 @@ +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from torch.distributed.tensor.parallel import ColwiseParallel, RowwiseParallel +from torch.distributed.tensor.placement_types import Replicate, Shard + +custom_parallel_plan = { + "model.embed_tokens": RowwiseParallel(input_layouts=Replicate()), + "model.layers.*.self_attn.q_proj": ColwiseParallel(), + "model.layers.*.self_attn.k_proj": ColwiseParallel(), + "model.layers.*.self_attn.v_proj": ColwiseParallel(), + "model.layers.*.self_attn.o_proj": RowwiseParallel(), + "model.layers.*.mlp.up_proj": ColwiseParallel(), + "model.layers.*.mlp.gate_proj": ColwiseParallel(), + "model.layers.*.mlp.down_proj": RowwiseParallel(), + "lm_head": ColwiseParallel(output_layouts=Shard(-1), use_local_output=False), +} diff --git a/nemo_rl/models/dtensor/parallelize.py b/nemo_rl/models/dtensor/parallelize.py index 7e40e54e82..0d0c4ee413 100644 --- a/nemo_rl/models/dtensor/parallelize.py +++ b/nemo_rl/models/dtensor/parallelize.py @@ -12,7 +12,9 @@ # See the License for the specific language governing permissions and # limitations under the License. 
-from typing import Callable, Union +from functools import lru_cache +from types import FunctionType +from typing import Callable, Optional, Union import torch from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import ( @@ -23,6 +25,7 @@ from torch.distributed.tensor import DTensor from torch.distributed.tensor.parallel import ( ColwiseParallel, + ParallelStyle, PrepareModuleInput, PrepareModuleOutput, RowwiseParallel, @@ -39,6 +42,7 @@ from transformers.models.qwen3.modeling_qwen3 import Qwen3ForCausalLM from nemo_rl.distributed.model_utils import from_parallel_logits_to_logprobs +from nemo_rl.models.policy.utils import import_class_from_path class RotaryEmbedParallel(SequenceParallel): @@ -81,167 +85,107 @@ def _prepare_output_fn(use_local_output, mod, outputs, device_mesh): def _parallelize_gemma3( model: Union[Gemma3ForCausalLM, Gemma3ForConditionalGeneration], - dp_mesh: DeviceMesh, - tp_mesh: DeviceMesh, - mp_policy: MixedPrecisionPolicy, - offload_policy: torch.distributed.fsdp.OffloadPolicy, sequence_parallel: bool = False, - activation_checkpointing: bool = False, ): """Parallelizes a Gemma3ForCausalLM model across data parallel dimensions. Tensor parallelism is not supported for Gemma3 models because of tied word embeddings. """ if isinstance(model, Gemma3ForConditionalGeneration): - layers = model.language_model.model.layers model_prefix = "language_model.model" - num_attention_heads = model.config.text_config.num_attention_heads - num_key_value_heads = model.config.text_config.num_key_value_heads else: - layers = model.model.layers model_prefix = "model" - num_attention_heads = model.config.num_attention_heads - num_key_value_heads = model.config.num_key_value_heads - if tp_mesh.size() > 1: - assert num_key_value_heads % tp_mesh.size() == 0, ( - f"num_key_value_heads ({num_key_value_heads}) must be divisible by TP size ({tp_mesh.size()})" - ) - assert num_attention_heads % tp_mesh.size() == 0, ( - f"num_attention_heads ({num_attention_heads}) must be divisible by TP size ({tp_mesh.size()})" - ) - - # For gemma3 models, we don't include the model.embed_tokens and lm_head in the - # parallelization plans because they have tied weights. 
- base_model_tp_plan = { - f"{model_prefix}.layers.*.self_attn.q_proj": ColwiseParallel(), - f"{model_prefix}.layers.*.self_attn.k_proj": ColwiseParallel(), - f"{model_prefix}.layers.*.self_attn.v_proj": ColwiseParallel(), - f"{model_prefix}.layers.*.self_attn.o_proj": RowwiseParallel(), - f"{model_prefix}.layers.*.mlp.up_proj": ColwiseParallel(), - f"{model_prefix}.layers.*.mlp.gate_proj": ColwiseParallel(), - f"{model_prefix}.layers.*.mlp.down_proj": RowwiseParallel(), - } - - base_model_sp_plan = { - f"{model_prefix}.embed_tokens": PrepareModuleOutput( - output_layouts=Replicate(), - desired_output_layouts=Shard(1), - use_local_output=False, - ), - f"{model_prefix}.rotary_emb": RotaryEmbedParallel(use_local_output=True), - f"{model_prefix}.rotary_emb_local": RotaryEmbedParallel( - use_local_output=True - ), - f"{model_prefix}.layers.*.input_layernorm": SequenceParallel(), - f"{model_prefix}.layers.*.self_attn.o_proj": RowwiseParallel( - output_layouts=Shard(1) - ), - f"{model_prefix}.layers.*.post_attention_layernorm": SequenceParallel(), - f"{model_prefix}.layers.*.pre_feedforward_layernorm": SequenceParallel(), - f"{model_prefix}.layers.*.mlp.down_proj": RowwiseParallel( - output_layouts=Shard(1) - ), - f"{model_prefix}.layers.*.post_feedforward_layernorm": SequenceParallel(), - f"{model_prefix}.norm": SequenceParallel(), - f"{model_prefix}.lm_head": PrepareModuleInput( - input_layouts=(Shard(1),), - desired_input_layouts=(Replicate(),), - use_local_output=True, - ), - } - - if sequence_parallel: - # Enable sequence parallelism only if TP size > 1 - base_model_tp_plan.update(base_model_sp_plan) - - parallelize_module(model, tp_mesh, base_model_tp_plan) - - if activation_checkpointing: - for i in range(len(layers)): - layers[i].mlp = checkpoint_wrapper(layers[i].mlp) - - for layer in layers: - fully_shard( - layer, mesh=dp_mesh, mp_policy=mp_policy, offload_policy=offload_policy - ) - - return fully_shard( - model, mesh=dp_mesh, mp_policy=mp_policy, offload_policy=offload_policy - ) + # For gemma3 models, we don't include the model.embed_tokens and lm_head in the + # parallelization plans because they have tied weights. 
+ base_model_tp_plan = { + f"{model_prefix}.layers.*.self_attn.q_proj": ColwiseParallel(), + f"{model_prefix}.layers.*.self_attn.k_proj": ColwiseParallel(), + f"{model_prefix}.layers.*.self_attn.v_proj": ColwiseParallel(), + f"{model_prefix}.layers.*.self_attn.o_proj": RowwiseParallel(), + f"{model_prefix}.layers.*.mlp.up_proj": ColwiseParallel(), + f"{model_prefix}.layers.*.mlp.gate_proj": ColwiseParallel(), + f"{model_prefix}.layers.*.mlp.down_proj": RowwiseParallel(), + } + + base_model_sp_plan = { + f"{model_prefix}.embed_tokens": PrepareModuleOutput( + output_layouts=Replicate(), + desired_output_layouts=Shard(1), + use_local_output=False, + ), + f"{model_prefix}.rotary_emb": RotaryEmbedParallel(use_local_output=True), + f"{model_prefix}.rotary_emb_local": RotaryEmbedParallel(use_local_output=True), + f"{model_prefix}.layers.*.input_layernorm": SequenceParallel(), + f"{model_prefix}.layers.*.self_attn.o_proj": RowwiseParallel( + output_layouts=Shard(1) + ), + f"{model_prefix}.layers.*.post_attention_layernorm": SequenceParallel(), + f"{model_prefix}.layers.*.pre_feedforward_layernorm": SequenceParallel(), + f"{model_prefix}.layers.*.mlp.down_proj": RowwiseParallel( + output_layouts=Shard(1) + ), + f"{model_prefix}.layers.*.post_feedforward_layernorm": SequenceParallel(), + f"{model_prefix}.norm": SequenceParallel(), + f"{model_prefix}.lm_head": PrepareModuleInput( + input_layouts=(Shard(1),), + desired_input_layouts=(Replicate(),), + use_local_output=True, + ), + } + + if sequence_parallel: + # Enable sequence parallelism only if TP size > 1 + base_model_tp_plan.update(base_model_sp_plan) + + return base_model_tp_plan def _parallelize_llama( model: LlamaForCausalLM, - dp_mesh: DeviceMesh, - tp_mesh: DeviceMesh, - mp_policy: MixedPrecisionPolicy, - offload_policy: torch.distributed.fsdp.OffloadPolicy, sequence_parallel: bool = False, - activation_checkpointing: bool = False, ): """Parallelizes a LlamaForCausalLM model across data and tensor parallel dimensions.""" - if tp_mesh.size() > 1: - assert not model.config.tie_word_embeddings, ( - "Tie word embeddings not supported when TP is enabled" - ) - - base_model_tp_plan = { - "model.embed_tokens": RowwiseParallel(input_layouts=Replicate()), - "model.layers.*.self_attn.q_proj": ColwiseParallel(), - "model.layers.*.self_attn.k_proj": ColwiseParallel(), - "model.layers.*.self_attn.v_proj": ColwiseParallel(), - "model.layers.*.self_attn.o_proj": RowwiseParallel(), - "model.layers.*.mlp.up_proj": ColwiseParallel(), - "model.layers.*.mlp.gate_proj": ColwiseParallel(), - "model.layers.*.mlp.down_proj": RowwiseParallel(), - "lm_head": ColwiseParallel( - output_layouts=Shard(-1), use_local_output=False - ), - } - - base_model_sp_plan = { - "model.embed_tokens": RowwiseParallel( - input_layouts=Replicate(), output_layouts=Shard(1) - ), - "model.norm": SequenceParallel(), - "model.layers.*.input_layernorm": SequenceParallel(), - "model.layers.*.self_attn.o_proj": RowwiseParallel(output_layouts=Shard(1)), - "model.layers.*.post_attention_layernorm": SequenceParallel(), - "model.layers.*.mlp.down_proj": RowwiseParallel(output_layouts=Shard(1)), - "lm_head": ColwiseParallel( - input_layouts=Shard(1), output_layouts=Shard(-1), use_local_output=False - ), - } - - if sequence_parallel: - # Enable sequence parallelism only if TP size > 1 - base_model_tp_plan.update(base_model_sp_plan) - - parallelize_module(model, tp_mesh, base_model_tp_plan) - - if activation_checkpointing: - for i in range(len(model.model.layers)): - model.model.layers[i].mlp = 
checkpoint_wrapper(model.model.layers[i].mlp) # type: ignore - - for layer in model.model.layers: - fully_shard( - layer, mesh=dp_mesh, mp_policy=mp_policy, offload_policy=offload_policy - ) - - return fully_shard( - model, mesh=dp_mesh, mp_policy=mp_policy, offload_policy=offload_policy + assert not model.config.tie_word_embeddings, ( + "Tie word embeddings not supported when TP is enabled" ) + base_model_tp_plan = { + "model.embed_tokens": RowwiseParallel(input_layouts=Replicate()), + "model.layers.*.self_attn.q_proj": ColwiseParallel(), + "model.layers.*.self_attn.k_proj": ColwiseParallel(), + "model.layers.*.self_attn.v_proj": ColwiseParallel(), + "model.layers.*.self_attn.o_proj": RowwiseParallel(), + "model.layers.*.mlp.up_proj": ColwiseParallel(), + "model.layers.*.mlp.gate_proj": ColwiseParallel(), + "model.layers.*.mlp.down_proj": RowwiseParallel(), + "lm_head": ColwiseParallel(output_layouts=Shard(-1), use_local_output=False), + } + + base_model_sp_plan = { + "model.embed_tokens": RowwiseParallel( + input_layouts=Replicate(), output_layouts=Shard(1) + ), + "model.norm": SequenceParallel(), + "model.layers.*.input_layernorm": SequenceParallel(), + "model.layers.*.self_attn.o_proj": RowwiseParallel(output_layouts=Shard(1)), + "model.layers.*.post_attention_layernorm": SequenceParallel(), + "model.layers.*.mlp.down_proj": RowwiseParallel(output_layouts=Shard(1)), + "lm_head": ColwiseParallel( + input_layouts=Shard(1), output_layouts=Shard(-1), use_local_output=False + ), + } + + if sequence_parallel: + # Enable sequence parallelism only if TP size > 1 + base_model_tp_plan.update(base_model_sp_plan) + + return base_model_tp_plan + def _parallelize_qwen( model: Union[Qwen2ForCausalLM, Qwen3ForCausalLM], - dp_mesh: DeviceMesh, - tp_mesh: DeviceMesh, - mp_policy: MixedPrecisionPolicy, - offload_policy: torch.distributed.fsdp.OffloadPolicy, sequence_parallel: bool = False, - activation_checkpointing: bool = False, ): """Parallelizes a Qwen2ForCausalLM model across data and tensor parallel dimensions.""" @@ -262,80 +206,58 @@ def _prepare_input_fn(sequence_sharding, mod, inputs, device_mesh): f"expecting input of {mod} to be a torch.Tensor or DTensor, but got {input_tensor}" ) - if tp_mesh.size() > 1: - assert not model.config.tie_word_embeddings, ( - "Tie word embeddings not supported when TP is enabled" - ) - if sequence_parallel: - base_model_tp_plan = { - "lm_head": ColwiseParallel( - input_layouts=Shard(1), - output_layouts=Shard(-1), - use_local_output=False, - ), - "model.embed_tokens": RowwiseParallel( - input_layouts=Replicate(), - output_layouts=Shard(1), - ), - "model.rotary_emb": RotaryEmbedParallel(), - "model.norm": SequenceParallel(), - "model.layers.*.input_layernorm": SequenceParallel(), - "model.layers.*.self_attn.q_proj": ColwiseParallel( - use_local_output=False - ), - "model.layers.*.self_attn.k_proj": ColwiseParallel( - use_local_output=False - ), - "model.layers.*.self_attn.v_proj": ColwiseParallel( - use_local_output=False - ), - "model.layers.*.self_attn.o_proj": RowwiseParallel( - output_layouts=Shard(1) - ), - "model.layers.*.self_attn.q_norm": Qwen3QKNorm(), - "model.layers.*.self_attn.k_norm": Qwen3QKNorm(), - "model.layers.*.post_attention_layernorm": SequenceParallel(), - "model.layers.*.mlp.up_proj": ColwiseParallel(), - "model.layers.*.mlp.gate_proj": ColwiseParallel(), - "model.layers.*.mlp.down_proj": RowwiseParallel( - output_layouts=Shard(1) - ), - } - - else: - base_model_tp_plan = { - "lm_head": ColwiseParallel( - output_layouts=Shard(-1), 
use_local_output=False - ), - "model.embed_tokens": RowwiseParallel( - input_layouts=Replicate(), - ), - "model.layers.*.self_attn.q_proj": ColwiseParallel(), - "model.layers.*.self_attn.k_proj": ColwiseParallel(), - "model.layers.*.self_attn.v_proj": ColwiseParallel(), - "model.layers.*.self_attn.o_proj": RowwiseParallel(), - "model.layers.*.mlp.up_proj": ColwiseParallel(), - "model.layers.*.mlp.gate_proj": ColwiseParallel(), - "model.layers.*.mlp.down_proj": RowwiseParallel(), - } - - parallelize_module(model, tp_mesh, base_model_tp_plan) - - if activation_checkpointing: - for i in range(len(model.model.layers)): - model.model.layers[i].mlp = checkpoint_wrapper(model.model.layers[i].mlp) # type: ignore + assert not model.config.tie_word_embeddings, ( + "Tie word embeddings not supported when TP is enabled" + ) + if sequence_parallel: + base_model_tp_plan = { + "lm_head": ColwiseParallel( + input_layouts=Shard(1), + output_layouts=Shard(-1), + use_local_output=False, + ), + "model.embed_tokens": RowwiseParallel( + input_layouts=Replicate(), + output_layouts=Shard(1), + ), + "model.rotary_emb": RotaryEmbedParallel(), + "model.norm": SequenceParallel(), + "model.layers.*.input_layernorm": SequenceParallel(), + "model.layers.*.self_attn.q_proj": ColwiseParallel(use_local_output=False), + "model.layers.*.self_attn.k_proj": ColwiseParallel(use_local_output=False), + "model.layers.*.self_attn.v_proj": ColwiseParallel(use_local_output=False), + "model.layers.*.self_attn.o_proj": RowwiseParallel(output_layouts=Shard(1)), + "model.layers.*.self_attn.q_norm": Qwen3QKNorm(), + "model.layers.*.self_attn.k_norm": Qwen3QKNorm(), + "model.layers.*.post_attention_layernorm": SequenceParallel(), + "model.layers.*.mlp.up_proj": ColwiseParallel(), + "model.layers.*.mlp.gate_proj": ColwiseParallel(), + "model.layers.*.mlp.down_proj": RowwiseParallel(output_layouts=Shard(1)), + } - for layer in model.model.layers: - fully_shard( - layer, mesh=dp_mesh, mp_policy=mp_policy, offload_policy=offload_policy - ) + else: + base_model_tp_plan = { + "lm_head": ColwiseParallel( + output_layouts=Shard(-1), use_local_output=False + ), + "model.embed_tokens": RowwiseParallel( + input_layouts=Replicate(), + ), + "model.layers.*.self_attn.q_proj": ColwiseParallel(), + "model.layers.*.self_attn.k_proj": ColwiseParallel(), + "model.layers.*.self_attn.v_proj": ColwiseParallel(), + "model.layers.*.self_attn.o_proj": RowwiseParallel(), + "model.layers.*.mlp.up_proj": ColwiseParallel(), + "model.layers.*.mlp.gate_proj": ColwiseParallel(), + "model.layers.*.mlp.down_proj": RowwiseParallel(), + } - return fully_shard( - model, mesh=dp_mesh, mp_policy=mp_policy, offload_policy=offload_policy - ) + return base_model_tp_plan -PARALLIZE_FUNCTIONS: dict[type[torch.nn.Module], Callable[..., torch.nn.Module]] = { +PARALLIZE_FUNCTIONS: dict[ + type[torch.nn.Module], Callable[..., dict[str, ParallelStyle]] +] = { Qwen2ForCausalLM: _parallelize_qwen, Qwen3ForCausalLM: _parallelize_qwen, LlamaForCausalLM: _parallelize_llama, @@ -346,25 +268,128 @@ def _prepare_input_fn(sequence_sharding, mod, inputs, device_mesh): } +@lru_cache +def translate_parallel_style(style: str): + """Translate parallel style str to parallel type. 
+ + Taken and modified from: https://github.com/NVIDIA/NeMo/blob/6c6169db01bcca73ae8ad3ac35242fadbb9a78ba/nemo/lightning/pytorch/strategies/utils.py#L547 + """ + assert isinstance(style, str), ( + f"parallel style type should be str, but got {type(style)}" + ) + + if style == "colwise": + return ColwiseParallel() + elif style == "rowwise": + return RowwiseParallel() + elif style == "colwise_rep": + return ColwiseParallel(output_layouts=Replicate()) + elif style == "rowwise_rep": + return RowwiseParallel(input_layouts=Replicate()) + elif style == "sequence_parallel": + return SequenceParallel() + else: + raise ValueError(f"Unknown parallel style: {style}") + + +def get_hf_tp_plan(model): + """Get the Hugging Face tensor parallel plan from the model. + + This function: + - Retrieves TP strategies from the model class, instance, and inner model levels. + - Handles special cases for `embed_tokens` and `lm_head` to speed them up. + - Converts string-based parallel styles to DTensor parallelization strategies. + + Taken and modified from: https://github.com/NVIDIA/NeMo/blob/6c6169db01bcca73ae8ad3ac35242fadbb9a78ba/nemo/lightning/pytorch/strategies/utils.py#L532 + + Args: + model: A Hugging Face model instance + + Returns: + dict: A dictionary mapping model component paths to their parallelization strategies + + Raises: + AssertionError: If no TP plan is found + """ + model_cls = type(model) + if model_cls == Gemma3ForConditionalGeneration: + inner_model = model.language_model + model_prefix = "language_model" + else: + inner_model = model.model + model_prefix = "model" + + hf_tp_plan = {} + + # model._tp_plan overrides model_cls._tp_plan after xxxForCausalLM.post_init() (transformers==4.51.3) + if hasattr(model_cls, "_tp_plan") and model_cls._tp_plan is not None: + hf_tp_plan.update(model_cls._tp_plan) + + if hasattr(model, "_tp_plan") and model._tp_plan is not None: + hf_tp_plan.update(model._tp_plan) + + if hasattr(inner_model, "_tp_plan") and inner_model._tp_plan is not None: + hf_tp_plan.update( + {f"{model_prefix}.{k}": v for k, v in inner_model._tp_plan.items()} + ) + + assert len(hf_tp_plan) > 0, ( + f"Hugging Face tp plan is not supported for {model_cls}. Please set dtensor_cfg.tensor_parallel_size to 1 or provide a custom_parallel_plan. " + "See `docs/design-docs/fsdp2-parallel-plan.md` for a custom_parallel_plan usage example." + ) + + # The HF tp plan does not contain embed_tokens, so we add it and set it to rowwise_rep + if ( + f"{model_prefix}.embed_tokens" not in hf_tp_plan + and not model.config.tie_word_embeddings + ): + hf_tp_plan[f"{model_prefix}.embed_tokens"] = "rowwise_rep" + + for k, v in hf_tp_plan.items(): + # speed up the tp plan for lm_head + if ( + k == "lm_head" + and v == "colwise_rep" + and not model.config.tie_word_embeddings + ): + hf_tp_plan[k] = ColwiseParallel( + output_layouts=Shard(-1), use_local_output=False + ) + else: + hf_tp_plan[k] = translate_parallel_style(v) + + return hf_tp_plan + + def _parallelize_model( - model: Union[Qwen2ForCausalLM, LlamaForCausalLM], + model: Union[ + Qwen2ForCausalLM, + LlamaForCausalLM, + Gemma3ForCausalLM, + Gemma3ForConditionalGeneration, + ], dp_mesh: DeviceMesh, tp_mesh: DeviceMesh, param_dtype: torch.dtype, sequence_parallel: bool = False, activation_checkpointing: bool = False, cpu_offload: bool = False, + custom_parallel_plan: Optional[Union[dict, str]] = None, ): """Parallelize a model using DTensor. Args: - model (Union[Qwen2ForCausalLM, LlamaForCausalLM]): The model to parallelize.
- dp_mesh (DeviceMesh): Device mesh for data parallelism. - tp_mesh (DeviceMesh): Device mesh for tensor parallelism. - param_dtype (torch.dtype): Data type for model parameters. - sequence_parallel (bool, optional): Whether to use sequence parallelism. Defaults to False. - activation_checkpointing (bool, optional): Whether to use activation checkpointing. Defaults to False. - cpu_offload (bool, optional): Whether to enable cpu offloading for FSDP. Defaults to False. + model: The model to parallelize. + dp_mesh: Device mesh for data parallelism. + tp_mesh: Device mesh for tensor parallelism. + param_dtype: Data type for model parameters. + sequence_parallel: Whether to use sequence parallelism. Defaults to False. + activation_checkpointing: Whether to use activation checkpointing. Defaults to False. + cpu_offload: Whether to enable cpu offloading for FSDP. Defaults to False. + custom_parallel_plan: Custom parallel plan for the model. Defaults to None. + If it's a dict, it will be used as the parallel plan directly. + If it's a string, it must be a path that points to a dict or a function that returns a dict. + See `docs/design-docs/fsdp2-parallel-plan.md` for usage examples. Returns: The parallelized model. @@ -372,31 +397,96 @@ def _parallelize_model( Raises: ValueError: If the model type is not supported for parallelization. """ + model_cls = type(model) + if model_cls == Gemma3ForConditionalGeneration: + layers: torch.nn.ModuleList = model.language_model.model.layers # type: ignore + num_attention_heads = model.config.text_config.num_attention_heads + num_key_value_heads = model.config.text_config.num_key_value_heads + else: + layers: torch.nn.ModuleList = model.model.layers # type: ignore + num_attention_heads = model.config.num_attention_heads + num_key_value_heads = model.config.num_key_value_heads + + if tp_mesh.size() > 1: + assert num_key_value_heads % tp_mesh.size() == 0, ( + f"num_key_value_heads ({num_key_value_heads}) must be divisible by TP size ({tp_mesh.size()})" + ) + assert num_attention_heads % tp_mesh.size() == 0, ( + f"num_attention_heads ({num_attention_heads}) must be divisible by TP size ({tp_mesh.size()})" + ) + + # First, use the user's custom parallel plan + if custom_parallel_plan is not None: + if isinstance(custom_parallel_plan, dict): + model_parallel_plan = custom_parallel_plan + else: + try: + model_parallel_plan = import_class_from_path(custom_parallel_plan) + if isinstance(model_parallel_plan, FunctionType): + model_parallel_plan = model_parallel_plan() + assert isinstance(model_parallel_plan, dict) + except Exception: + raise ValueError( + f"Your custom parallel plan `{custom_parallel_plan}` is not valid. Please ensure it is one of the following:\n" + "1. A dictionary\n" + "2. A path to a dictionary\n" + "3. A path to a function that returns a dictionary" + ) + print("Using custom parallel plan.") + + # Second, use our optimized parallel plan + elif model_cls in PARALLIZE_FUNCTIONS: + # try to use our optimized parallel plan + try: + func = PARALLIZE_FUNCTIONS[model_cls] + model_parallel_plan = func(model, sequence_parallel) + print("Using optimized parallel plan.") + # fall back to the HF tp plan + except Exception as e: + print( + f"Optimized parallel plan is not available: {e}. Falling back to the HF tp plan." + ) + assert not sequence_parallel, ( + "sequence_parallel is not supported in the HF tp plan."
+ ) + model_parallel_plan = get_hf_tp_plan(model) + + # Finally, fall back to the default HF tp plan + else: + # an optimized parallel plan is not supported for this model class + print( + f"Optimized parallel plan is not supported for {model_cls}. Falling back to the HF tp plan." + ) + assert not sequence_parallel, ( + "sequence_parallel is not supported in the HF tp plan." + ) + model_parallel_plan = get_hf_tp_plan(model) + + parallelize_module(model, tp_mesh, model_parallel_plan) + + if activation_checkpointing: + for i in range(len(layers)): + layers[i].mlp = checkpoint_wrapper(layers[i].mlp) # type: ignore + mp_policy = MixedPrecisionPolicy( param_dtype=param_dtype, reduce_dtype=torch.float32, output_dtype=torch.float32, ) + offload_policy = ( CPUOffloadPolicy(pin_memory=False) if cpu_offload else torch.distributed.fsdp.OffloadPolicy ) - model_cls = type(model) - if model_cls not in PARALLIZE_FUNCTIONS: - raise ValueError(f"Model {model_cls} not supported as part of dtensor") - - func = PARALLIZE_FUNCTIONS[type(model)] - - return func( - model, - dp_mesh, - tp_mesh, - mp_policy, - offload_policy, - sequence_parallel, - activation_checkpointing, + for layer in layers: + fully_shard( + layer, mesh=dp_mesh, mp_policy=mp_policy, offload_policy=offload_policy + ) + + return fully_shard( + model, mesh=dp_mesh, mp_policy=mp_policy, offload_policy=offload_policy ) diff --git a/nemo_rl/models/huggingface/common.py b/nemo_rl/models/huggingface/common.py index 10fa3f4cfa..df913f95b4 100644 --- a/nemo_rl/models/huggingface/common.py +++ b/nemo_rl/models/huggingface/common.py @@ -39,15 +39,17 @@ class ModelFlag(Enum): def matches(self, model_name: str) -> bool: match self: case ModelFlag.SKIP_DTENSOR_TIED_WEIGHTS_CHECK: - return is_gemma3_model(model_name) + return is_gemma_model(model_name) case ModelFlag.VLLM_LOAD_FORMAT_AUTO: - return is_gemma3_model(model_name) + return is_gemma_model(model_name) case _: raise ValueError(f"Unknown ModelFlag: {self}") -def is_gemma3_model(model_name: str) -> bool: +def is_gemma_model(model_name: str) -> bool: hf_config = AutoConfig.from_pretrained(model_name, trust_remote_code=True) - return hasattr(hf_config, "model_type") and ( - hf_config.model_type == "gemma3" or hf_config.model_type == "gemma3_text" - ) + return hasattr(hf_config, "model_type") and hf_config.model_type in [ + "gemma2", + "gemma3", + "gemma3_text", + ] diff --git a/nemo_rl/models/policy/__init__.py b/nemo_rl/models/policy/__init__.py index e0f7916835..6e1bad7048 100644 --- a/nemo_rl/models/policy/__init__.py +++ b/nemo_rl/models/policy/__init__.py @@ -23,6 +23,7 @@ class DTensorConfig(TypedDict): sequence_parallel: bool activation_checkpointing: bool tensor_parallel_size: int + custom_parallel_plan: str class TokenizerConfig(TypedDict): diff --git a/nemo_rl/models/policy/dtensor_policy_worker.py b/nemo_rl/models/policy/dtensor_policy_worker.py index 0108d7e8b6..91d52b77e9 100644 --- a/nemo_rl/models/policy/dtensor_policy_worker.py +++ b/nemo_rl/models/policy/dtensor_policy_worker.py @@ -192,6 +192,7 @@ def __init__( activation_checkpointing=self.cfg["dtensor_cfg"][ "activation_checkpointing" ], + custom_parallel_plan=self.cfg["dtensor_cfg"]["custom_parallel_plan"], ) if self.cpu_offload: @@ -660,6 +661,9 @@ def _add_noise_to_weights(self) -> None: p.data.add_(noise) # Add noise in-place torch.cuda.synchronize() + def return_state_dict(self): + return self.model.state_dict() + def report_device_id(self) -> str: """Report the UUID of the current CUDA device using NVML.
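Note on the new `custom_parallel_plan` hook above: besides a dict or a dotted path to a dict (as in `examples/custom_parallel.py`), `_parallelize_model` also accepts a dotted path to a function that returns a dict, resolved via `import_class_from_path`. A minimal sketch of that function form is below; the module path `my_plans/llama_plan.py` is hypothetical, and the layout simply mirrors `examples/custom_parallel.py`.

```python
# Hypothetical module my_plans/llama_plan.py -- illustrative only.
# _parallelize_model() would call this zero-argument function and use its return
# value as the tensor parallel plan (a mapping of module paths to ParallelStyle).
from torch.distributed.tensor.parallel import ColwiseParallel, RowwiseParallel
from torch.distributed.tensor.placement_types import Replicate, Shard


def custom_parallel_plan() -> dict:
    # Build the plan lazily at call time; same layout as examples/custom_parallel.py.
    return {
        "model.embed_tokens": RowwiseParallel(input_layouts=Replicate()),
        "model.layers.*.self_attn.q_proj": ColwiseParallel(),
        "model.layers.*.self_attn.k_proj": ColwiseParallel(),
        "model.layers.*.self_attn.v_proj": ColwiseParallel(),
        "model.layers.*.self_attn.o_proj": RowwiseParallel(),
        "model.layers.*.mlp.up_proj": ColwiseParallel(),
        "model.layers.*.mlp.gate_proj": ColwiseParallel(),
        "model.layers.*.mlp.down_proj": RowwiseParallel(),
        "lm_head": ColwiseParallel(output_layouts=Shard(-1), use_local_output=False),
    }
```

It could then be selected with `policy.dtensor_cfg.custom_parallel_plan=my_plans.llama_plan.custom_parallel_plan`, analogous to the override shown in the new design doc.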
diff --git a/tests/unit/models/generation/test_vllm_generation.py b/tests/unit/models/generation/test_vllm_generation.py index c31ef8f65d..11192a3ad6 100644 --- a/tests/unit/models/generation/test_vllm_generation.py +++ b/tests/unit/models/generation/test_vllm_generation.py @@ -85,6 +85,7 @@ def get_basic_hf_test_config(enable_dtensor: bool = False) -> PolicyConfig: "sequence_parallel": False, "activation_checkpointing": False, "tensor_parallel_size": 1, + "custom_parallel_plan": None, }, "dynamic_batching": { "enabled": enable_dtensor, # Dynamic batching is only supported with DTensor diff --git a/tests/unit/models/huggingface/test_common.py b/tests/unit/models/huggingface/test_common.py index 74fb2f0848..faf06fbdb7 100644 --- a/tests/unit/models/huggingface/test_common.py +++ b/tests/unit/models/huggingface/test_common.py @@ -14,12 +14,18 @@ import pytest -from nemo_rl.models.huggingface.common import ModelFlag, is_gemma3_model +from nemo_rl.models.huggingface.common import ModelFlag, is_gemma_model @pytest.mark.parametrize( "model_name", [ + "google/gemma-2-2b", + "google/gemma-2-9b", + "google/gemma-2-27b", + "google/gemma-2-2b-it", + "google/gemma-2-9b-it", + "google/gemma-2-27b-it", "google/gemma-3-1b-pt", "google/gemma-3-4b-pt", "google/gemma-3-12b-pt", @@ -30,8 +36,8 @@ "google/gemma-3-27b-it", ], ) -def test_gemma3_models(model_name): - assert is_gemma3_model(model_name) +def test_gemma_models(model_name): + assert is_gemma_model(model_name) assert ModelFlag.SKIP_DTENSOR_TIED_WEIGHTS_CHECK.matches(model_name) assert ModelFlag.VLLM_LOAD_FORMAT_AUTO.matches(model_name) @@ -44,7 +50,7 @@ def test_gemma3_models(model_name): "Qwen/Qwen2.5-3B-Instruct", ], ) -def test_non_gemma3_models(model_name): - assert not is_gemma3_model(model_name) +def test_non_gemma_models(model_name): + assert not is_gemma_model(model_name) assert not ModelFlag.SKIP_DTENSOR_TIED_WEIGHTS_CHECK.matches(model_name) assert not ModelFlag.VLLM_LOAD_FORMAT_AUTO.matches(model_name) diff --git a/tests/unit/models/policy/test_dtensor_worker.py b/tests/unit/models/policy/test_dtensor_worker.py index 8ea07c7f35..c50efca2d4 100644 --- a/tests/unit/models/policy/test_dtensor_worker.py +++ b/tests/unit/models/policy/test_dtensor_worker.py @@ -30,7 +30,6 @@ from nemo_rl.distributed.virtual_cluster import RayVirtualCluster from nemo_rl.models.generation import configure_generation_config from nemo_rl.models.policy import PolicyConfig -from nemo_rl.models.policy.dtensor_policy_worker import DTensorPolicyWorker from nemo_rl.models.policy.hf_policy import HfPolicy from tests.unit.conftest import TEST_ASSETS from tests.unit.test_utils import SimpleLoss @@ -42,6 +41,7 @@ def create_test_config( sequence_parallel: bool = False, cpu_offload: bool = False, activation_checkpointing: bool = False, + custom_parallel_plan: str = None, ) -> PolicyConfig: return { "model_name": model_name, @@ -67,6 +67,7 @@ def create_test_config( "sequence_parallel": sequence_parallel, "activation_checkpointing": activation_checkpointing, "tensor_parallel_size": tp, + "custom_parallel_plan": custom_parallel_plan, }, "dynamic_batching": { "enabled": True, @@ -461,25 +462,43 @@ def test_dtensor_worker_logprob_tp2_matches_no_tp(logprob_setup): ) -def test_dtensor_fails_with_tp_and_tied_model(mock_2gpu_distributed_env): - """Test that DTensor fails with a tp > 1 and a tied model.""" +def test_dtensor_tp_and_tied_model_with_custom_parallel_plan(two_gpu_virtual_cluster): + """Test that DTensor with a tp > 1 and a tied model with a custom parallel plan works.""" 
+ from torch.distributed.tensor.parallel import ColwiseParallel + from torch.distributed.tensor.placement_types import Replicate + + custom_parallel_plan = {"lm_head": ColwiseParallel(output_layouts=Replicate())} config = create_test_config( model_name=TEST_ASSETS.TINY_LLAMA_TIED_MODEL_PATH, tp=2, cpu_offload=False, sequence_parallel=False, activation_checkpointing=False, + custom_parallel_plan=custom_parallel_plan, ) tokenizer = get_tokenizer(config["tokenizer"]) - with pytest.raises( - AssertionError, match="Tie word embeddings not supported when TP is enabled" - ): - DTensorPolicyWorker.__ray_actor_class__( - config=config, - tokenizer=tokenizer, - init_optimizer=False, - init_reference_model=False, - ) + + policy = HfPolicy( + tokenizer=tokenizer, + config=config, + init_optimizer=False, + init_reference_model=False, + cluster=two_gpu_virtual_cluster, + ) + + # Verify that the model is parallelized as expected + state_dict = ray.get(policy.worker_group.workers[0].return_state_dict.remote()) + total_shape = state_dict["lm_head.weight"].shape + sharded_shape = state_dict["lm_head.weight"].to_local().shape + assert total_shape[0] == sharded_shape[0] * 2, ( + "lm_head.weight should be sharded across 2 GPUs" + ) + assert total_shape[1] == sharded_shape[1], ( + "lm_head.weight should have the same number of columns" + ) + + # Clean up + policy.shutdown() @pytest.mark.timeout(180) diff --git a/tests/unit/models/policy/test_fsdp1_worker.py b/tests/unit/models/policy/test_fsdp1_worker.py index 591acd8749..3ec19f5148 100644 --- a/tests/unit/models/policy/test_fsdp1_worker.py +++ b/tests/unit/models/policy/test_fsdp1_worker.py @@ -58,6 +58,7 @@ "sequence_parallel": False, "activation_checkpointing": False, "tensor_parallel_size": 1, + "custom_parallel_plan": None, }, "dynamic_batching": { "enabled": False, diff --git a/tests/unit/utils/test_native_checkpoint.py b/tests/unit/utils/test_native_checkpoint.py index a53d7dd44a..68a1f3e217 100755 --- a/tests/unit/utils/test_native_checkpoint.py +++ b/tests/unit/utils/test_native_checkpoint.py @@ -59,6 +59,7 @@ "sequence_parallel": False, "activation_checkpointing": False, "tensor_parallel_size": 1, + "custom_parallel_plan": None, }, "dynamic_batching": { "enabled": False,
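As a quick reference for the HF fallback path added in this change, `get_hf_tp_plan` converts the string styles found in a model's `_tp_plan` into DTensor `ParallelStyle` objects via `translate_parallel_style`. A rough sketch of that mapping, written as a lookup table and covering only the five styles handled above, is:

```python
# Rough equivalent of translate_parallel_style() from nemo_rl/models/dtensor/parallelize.py,
# expressed as a lookup table for illustration (the real function returns a fresh
# instance per style and raises ValueError on unknown styles).
from torch.distributed.tensor.parallel import (
    ColwiseParallel,
    RowwiseParallel,
    SequenceParallel,
)
from torch.distributed.tensor.placement_types import Replicate

HF_STYLE_TO_PARALLEL_STYLE = {
    "colwise": ColwiseParallel(),
    "rowwise": RowwiseParallel(),
    "colwise_rep": ColwiseParallel(output_layouts=Replicate()),
    "rowwise_rep": RowwiseParallel(input_layouts=Replicate()),
    "sequence_parallel": SequenceParallel(),
}
```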