From d58103dd68f7cbb1b58368cd11ea106e44ca8d69 Mon Sep 17 00:00:00 2001 From: Yuki Huang Date: Tue, 13 May 2025 07:43:04 +0000 Subject: [PATCH 1/9] support hf tp plan, add custom_parallel_plan param Signed-off-by: Yuki Huang --- examples/configs/dpo.yaml | 1 + examples/configs/grpo_math_1B.yaml | 1 + examples/configs/sft.yaml | 3 +- nemo_rl/models/dtensor/parallelize.py | 430 ++++++++++-------- nemo_rl/models/policy/__init__.py | 1 + .../models/policy/dtensor_policy_worker.py | 1 + 6 files changed, 236 insertions(+), 201 deletions(-) diff --git a/examples/configs/dpo.yaml b/examples/configs/dpo.yaml index 1252adb131..1a185300c5 100755 --- a/examples/configs/dpo.yaml +++ b/examples/configs/dpo.yaml @@ -51,6 +51,7 @@ policy: sequence_parallel: false activation_checkpointing: false tensor_parallel_size: 1 + custom_parallel_plan: null dynamic_batching: enabled: false diff --git a/examples/configs/grpo_math_1B.yaml b/examples/configs/grpo_math_1B.yaml index 4c7469d970..27f32fc432 100644 --- a/examples/configs/grpo_math_1B.yaml +++ b/examples/configs/grpo_math_1B.yaml @@ -50,6 +50,7 @@ policy: sequence_parallel: false activation_checkpointing: false tensor_parallel_size: 1 + custom_parallel_plan: null # dynamic_batching improves performance by ensuring logprob and training microbatches # have a sufficent number of tokens to maximize GPU utilization. Specifically, variable length diff --git a/examples/configs/sft.yaml b/examples/configs/sft.yaml index 09d05ef89f..b8f01ce3e6 100644 --- a/examples/configs/sft.yaml +++ b/examples/configs/sft.yaml @@ -38,7 +38,8 @@ policy: sequence_parallel: false activation_checkpointing: false tensor_parallel_size: 1 - + custom_parallel_plan: null + dynamic_batching: enabled: false diff --git a/nemo_rl/models/dtensor/parallelize.py b/nemo_rl/models/dtensor/parallelize.py index 7e40e54e82..0829b20b90 100644 --- a/nemo_rl/models/dtensor/parallelize.py +++ b/nemo_rl/models/dtensor/parallelize.py @@ -93,83 +93,65 @@ def _parallelize_gemma3( Tensor parallelism is not supported for Gemma3 models because of tied word embeddings. """ if isinstance(model, Gemma3ForConditionalGeneration): - layers = model.language_model.model.layers model_prefix = "language_model.model" num_attention_heads = model.config.text_config.num_attention_heads num_key_value_heads = model.config.text_config.num_key_value_heads else: - layers = model.model.layers model_prefix = "model" num_attention_heads = model.config.num_attention_heads num_key_value_heads = model.config.num_key_value_heads - if tp_mesh.size() > 1: - assert num_key_value_heads % tp_mesh.size() == 0, ( - f"num_key_value_heads ({num_key_value_heads}) must be divisible by TP size ({tp_mesh.size()})" - ) - assert num_attention_heads % tp_mesh.size() == 0, ( - f"num_attention_heads ({num_attention_heads}) must be divisible by TP size ({tp_mesh.size()})" - ) - - # For gemma3 models, we don't include the model.embed_tokens and lm_head in the - # parallelization plans because they have tied weights. - base_model_tp_plan = { - f"{model_prefix}.layers.*.self_attn.q_proj": ColwiseParallel(), - f"{model_prefix}.layers.*.self_attn.k_proj": ColwiseParallel(), - f"{model_prefix}.layers.*.self_attn.v_proj": ColwiseParallel(), - f"{model_prefix}.layers.*.self_attn.o_proj": RowwiseParallel(), - f"{model_prefix}.layers.*.mlp.up_proj": ColwiseParallel(), - f"{model_prefix}.layers.*.mlp.gate_proj": ColwiseParallel(), - f"{model_prefix}.layers.*.mlp.down_proj": RowwiseParallel(), - } - - base_model_sp_plan = { - f"{model_prefix}.embed_tokens": PrepareModuleOutput( - output_layouts=Replicate(), - desired_output_layouts=Shard(1), - use_local_output=False, - ), - f"{model_prefix}.rotary_emb": RotaryEmbedParallel(use_local_output=True), - f"{model_prefix}.rotary_emb_local": RotaryEmbedParallel( - use_local_output=True - ), - f"{model_prefix}.layers.*.input_layernorm": SequenceParallel(), - f"{model_prefix}.layers.*.self_attn.o_proj": RowwiseParallel( - output_layouts=Shard(1) - ), - f"{model_prefix}.layers.*.post_attention_layernorm": SequenceParallel(), - f"{model_prefix}.layers.*.pre_feedforward_layernorm": SequenceParallel(), - f"{model_prefix}.layers.*.mlp.down_proj": RowwiseParallel( - output_layouts=Shard(1) - ), - f"{model_prefix}.layers.*.post_feedforward_layernorm": SequenceParallel(), - f"{model_prefix}.norm": SequenceParallel(), - f"{model_prefix}.lm_head": PrepareModuleInput( - input_layouts=(Shard(1),), - desired_input_layouts=(Replicate(),), - use_local_output=True, - ), - } - - if sequence_parallel: - # Enable sequence parallelism only if TP size > 1 - base_model_tp_plan.update(base_model_sp_plan) - - parallelize_module(model, tp_mesh, base_model_tp_plan) - - if activation_checkpointing: - for i in range(len(layers)): - layers[i].mlp = checkpoint_wrapper(layers[i].mlp) - - for layer in layers: - fully_shard( - layer, mesh=dp_mesh, mp_policy=mp_policy, offload_policy=offload_policy - ) - - return fully_shard( - model, mesh=dp_mesh, mp_policy=mp_policy, offload_policy=offload_policy + assert num_key_value_heads % tp_mesh.size() == 0, ( + f"num_key_value_heads ({num_key_value_heads}) must be divisible by TP size ({tp_mesh.size()})" + ) + assert num_attention_heads % tp_mesh.size() == 0, ( + f"num_attention_heads ({num_attention_heads}) must be divisible by TP size ({tp_mesh.size()})" ) + # For gemma3 models, we don't include the model.embed_tokens and lm_head in the + # parallelization plans because they have tied weights. + base_model_tp_plan = { + f"{model_prefix}.layers.*.self_attn.q_proj": ColwiseParallel(), + f"{model_prefix}.layers.*.self_attn.k_proj": ColwiseParallel(), + f"{model_prefix}.layers.*.self_attn.v_proj": ColwiseParallel(), + f"{model_prefix}.layers.*.self_attn.o_proj": RowwiseParallel(), + f"{model_prefix}.layers.*.mlp.up_proj": ColwiseParallel(), + f"{model_prefix}.layers.*.mlp.gate_proj": ColwiseParallel(), + f"{model_prefix}.layers.*.mlp.down_proj": RowwiseParallel(), + } + + base_model_sp_plan = { + f"{model_prefix}.embed_tokens": PrepareModuleOutput( + output_layouts=Replicate(), + desired_output_layouts=Shard(1), + use_local_output=False, + ), + f"{model_prefix}.rotary_emb": RotaryEmbedParallel(use_local_output=True), + f"{model_prefix}.rotary_emb_local": RotaryEmbedParallel(use_local_output=True), + f"{model_prefix}.layers.*.input_layernorm": SequenceParallel(), + f"{model_prefix}.layers.*.self_attn.o_proj": RowwiseParallel( + output_layouts=Shard(1) + ), + f"{model_prefix}.layers.*.post_attention_layernorm": SequenceParallel(), + f"{model_prefix}.layers.*.pre_feedforward_layernorm": SequenceParallel(), + f"{model_prefix}.layers.*.mlp.down_proj": RowwiseParallel( + output_layouts=Shard(1) + ), + f"{model_prefix}.layers.*.post_feedforward_layernorm": SequenceParallel(), + f"{model_prefix}.norm": SequenceParallel(), + f"{model_prefix}.lm_head": PrepareModuleInput( + input_layouts=(Shard(1),), + desired_input_layouts=(Replicate(),), + use_local_output=True, + ), + } + + if sequence_parallel: + # Enable sequence parallelism only if TP size > 1 + base_model_tp_plan.update(base_model_sp_plan) + + return base_model_tp_plan + def _parallelize_llama( model: LlamaForCausalLM, @@ -181,58 +163,42 @@ def _parallelize_llama( activation_checkpointing: bool = False, ): """Parallelizes a LlamaForCausalLM model across data and tensor parallel dimensions.""" - if tp_mesh.size() > 1: - assert not model.config.tie_word_embeddings, ( - "Tie word embeddings not supported when TP is enabled" - ) - - base_model_tp_plan = { - "model.embed_tokens": RowwiseParallel(input_layouts=Replicate()), - "model.layers.*.self_attn.q_proj": ColwiseParallel(), - "model.layers.*.self_attn.k_proj": ColwiseParallel(), - "model.layers.*.self_attn.v_proj": ColwiseParallel(), - "model.layers.*.self_attn.o_proj": RowwiseParallel(), - "model.layers.*.mlp.up_proj": ColwiseParallel(), - "model.layers.*.mlp.gate_proj": ColwiseParallel(), - "model.layers.*.mlp.down_proj": RowwiseParallel(), - "lm_head": ColwiseParallel( - output_layouts=Shard(-1), use_local_output=False - ), - } - - base_model_sp_plan = { - "model.embed_tokens": RowwiseParallel( - input_layouts=Replicate(), output_layouts=Shard(1) - ), - "model.norm": SequenceParallel(), - "model.layers.*.input_layernorm": SequenceParallel(), - "model.layers.*.self_attn.o_proj": RowwiseParallel(output_layouts=Shard(1)), - "model.layers.*.post_attention_layernorm": SequenceParallel(), - "model.layers.*.mlp.down_proj": RowwiseParallel(output_layouts=Shard(1)), - "lm_head": ColwiseParallel( - input_layouts=Shard(1), output_layouts=Shard(-1), use_local_output=False - ), - } - - if sequence_parallel: - # Enable sequence parallelism only if TP size > 1 - base_model_tp_plan.update(base_model_sp_plan) - - parallelize_module(model, tp_mesh, base_model_tp_plan) - - if activation_checkpointing: - for i in range(len(model.model.layers)): - model.model.layers[i].mlp = checkpoint_wrapper(model.model.layers[i].mlp) # type: ignore - - for layer in model.model.layers: - fully_shard( - layer, mesh=dp_mesh, mp_policy=mp_policy, offload_policy=offload_policy - ) - - return fully_shard( - model, mesh=dp_mesh, mp_policy=mp_policy, offload_policy=offload_policy + assert not model.config.tie_word_embeddings, ( + "Tie word embeddings not supported when TP is enabled" ) + base_model_tp_plan = { + "model.embed_tokens": RowwiseParallel(input_layouts=Replicate()), + "model.layers.*.self_attn.q_proj": ColwiseParallel(), + "model.layers.*.self_attn.k_proj": ColwiseParallel(), + "model.layers.*.self_attn.v_proj": ColwiseParallel(), + "model.layers.*.self_attn.o_proj": RowwiseParallel(), + "model.layers.*.mlp.up_proj": ColwiseParallel(), + "model.layers.*.mlp.gate_proj": ColwiseParallel(), + "model.layers.*.mlp.down_proj": RowwiseParallel(), + "lm_head": ColwiseParallel(output_layouts=Shard(-1), use_local_output=False), + } + + base_model_sp_plan = { + "model.embed_tokens": RowwiseParallel( + input_layouts=Replicate(), output_layouts=Shard(1) + ), + "model.norm": SequenceParallel(), + "model.layers.*.input_layernorm": SequenceParallel(), + "model.layers.*.self_attn.o_proj": RowwiseParallel(output_layouts=Shard(1)), + "model.layers.*.post_attention_layernorm": SequenceParallel(), + "model.layers.*.mlp.down_proj": RowwiseParallel(output_layouts=Shard(1)), + "lm_head": ColwiseParallel( + input_layouts=Shard(1), output_layouts=Shard(-1), use_local_output=False + ), + } + + if sequence_parallel: + # Enable sequence parallelism only if TP size > 1 + base_model_tp_plan.update(base_model_sp_plan) + + return base_model_tp_plan + def _parallelize_qwen( model: Union[Qwen2ForCausalLM, Qwen3ForCausalLM], @@ -262,77 +228,53 @@ def _prepare_input_fn(sequence_sharding, mod, inputs, device_mesh): f"expecting input of {mod} to be a torch.Tensor or DTensor, but got {input_tensor}" ) - if tp_mesh.size() > 1: - assert not model.config.tie_word_embeddings, ( - "Tie word embeddings not supported when TP is enabled" - ) - if sequence_parallel: - base_model_tp_plan = { - "lm_head": ColwiseParallel( - input_layouts=Shard(1), - output_layouts=Shard(-1), - use_local_output=False, - ), - "model.embed_tokens": RowwiseParallel( - input_layouts=Replicate(), - output_layouts=Shard(1), - ), - "model.rotary_emb": RotaryEmbedParallel(), - "model.norm": SequenceParallel(), - "model.layers.*.input_layernorm": SequenceParallel(), - "model.layers.*.self_attn.q_proj": ColwiseParallel( - use_local_output=False - ), - "model.layers.*.self_attn.k_proj": ColwiseParallel( - use_local_output=False - ), - "model.layers.*.self_attn.v_proj": ColwiseParallel( - use_local_output=False - ), - "model.layers.*.self_attn.o_proj": RowwiseParallel( - output_layouts=Shard(1) - ), - "model.layers.*.self_attn.q_norm": Qwen3QKNorm(), - "model.layers.*.self_attn.k_norm": Qwen3QKNorm(), - "model.layers.*.post_attention_layernorm": SequenceParallel(), - "model.layers.*.mlp.up_proj": ColwiseParallel(), - "model.layers.*.mlp.gate_proj": ColwiseParallel(), - "model.layers.*.mlp.down_proj": RowwiseParallel( - output_layouts=Shard(1) - ), - } - - else: - base_model_tp_plan = { - "lm_head": ColwiseParallel( - output_layouts=Shard(-1), use_local_output=False - ), - "model.embed_tokens": RowwiseParallel( - input_layouts=Replicate(), - ), - "model.layers.*.self_attn.q_proj": ColwiseParallel(), - "model.layers.*.self_attn.k_proj": ColwiseParallel(), - "model.layers.*.self_attn.v_proj": ColwiseParallel(), - "model.layers.*.self_attn.o_proj": RowwiseParallel(), - "model.layers.*.mlp.up_proj": ColwiseParallel(), - "model.layers.*.mlp.gate_proj": ColwiseParallel(), - "model.layers.*.mlp.down_proj": RowwiseParallel(), - } - - parallelize_module(model, tp_mesh, base_model_tp_plan) - - if activation_checkpointing: - for i in range(len(model.model.layers)): - model.model.layers[i].mlp = checkpoint_wrapper(model.model.layers[i].mlp) # type: ignore + assert not model.config.tie_word_embeddings, ( + "Tie word embeddings not supported when TP is enabled" + ) + if sequence_parallel: + base_model_tp_plan = { + "lm_head": ColwiseParallel( + input_layouts=Shard(1), + output_layouts=Shard(-1), + use_local_output=False, + ), + "model.embed_tokens": RowwiseParallel( + input_layouts=Replicate(), + output_layouts=Shard(1), + ), + "model.rotary_emb": RotaryEmbedParallel(), + "model.norm": SequenceParallel(), + "model.layers.*.input_layernorm": SequenceParallel(), + "model.layers.*.self_attn.q_proj": ColwiseParallel(use_local_output=False), + "model.layers.*.self_attn.k_proj": ColwiseParallel(use_local_output=False), + "model.layers.*.self_attn.v_proj": ColwiseParallel(use_local_output=False), + "model.layers.*.self_attn.o_proj": RowwiseParallel(output_layouts=Shard(1)), + "model.layers.*.self_attn.q_norm": Qwen3QKNorm(), + "model.layers.*.self_attn.k_norm": Qwen3QKNorm(), + "model.layers.*.post_attention_layernorm": SequenceParallel(), + "model.layers.*.mlp.up_proj": ColwiseParallel(), + "model.layers.*.mlp.gate_proj": ColwiseParallel(), + "model.layers.*.mlp.down_proj": RowwiseParallel(output_layouts=Shard(1)), + } - for layer in model.model.layers: - fully_shard( - layer, mesh=dp_mesh, mp_policy=mp_policy, offload_policy=offload_policy - ) + else: + base_model_tp_plan = { + "lm_head": ColwiseParallel( + output_layouts=Shard(-1), use_local_output=False + ), + "model.embed_tokens": RowwiseParallel( + input_layouts=Replicate(), + ), + "model.layers.*.self_attn.q_proj": ColwiseParallel(), + "model.layers.*.self_attn.k_proj": ColwiseParallel(), + "model.layers.*.self_attn.v_proj": ColwiseParallel(), + "model.layers.*.self_attn.o_proj": RowwiseParallel(), + "model.layers.*.mlp.up_proj": ColwiseParallel(), + "model.layers.*.mlp.gate_proj": ColwiseParallel(), + "model.layers.*.mlp.down_proj": RowwiseParallel(), + } - return fully_shard( - model, mesh=dp_mesh, mp_policy=mp_policy, offload_policy=offload_policy - ) + return base_model_tp_plan PARALLIZE_FUNCTIONS: dict[type[torch.nn.Module], Callable[..., torch.nn.Module]] = { @@ -346,6 +288,45 @@ def _prepare_input_fn(sequence_sharding, mod, inputs, device_mesh): } +def translate_parallel_style(style: str): + """Translate parallel style str to parallel type""" + assert isinstance(style, str), ( + f"parallel style type should be str, but got {type(style)}" + ) + + if style == "colwise": + return ColwiseParallel() + elif style == "rowwise": + return RowwiseParallel() + elif style == "colwise_rep": + return ColwiseParallel(output_layouts=Replicate()) + elif style == "rowwise_rep": + return RowwiseParallel(input_layouts=Replicate()) + elif style == "sequence_parallel": + return SequenceParallel() + else: + raise ValueError(f"Unknown parallel style: {style}") + + +def get_hf_tp_plan(model): + """Get the Hugging Face tensor parallel plan from the model.""" + hf_tp_plan = {} + + # model_cls._tp_plan will override model_cls after xxxForCausalLM.post_init() (transformers==4.51.3) + model_cls = type(model) + if hasattr(model_cls, "_tp_plan") and model_cls._tp_plan is not None: + hf_tp_plan.update(model_cls._tp_plan) + + if hasattr(model, "_tp_plan") and model._tp_plan is not None: + hf_tp_plan.update(model._tp_plan) + + if hasattr(model.model, "_tp_plan") and model.model._tp_plan is not None: + hf_tp_plan.update({f"model.{k}": v for k, v in model.model._tp_plan.items()}) + + hf_tp_plan = {k: translate_parallel_style(v) for k, v in hf_tp_plan.items()} + return hf_tp_plan + + def _parallelize_model( model: Union[Qwen2ForCausalLM, LlamaForCausalLM], dp_mesh: DeviceMesh, @@ -354,6 +335,7 @@ def _parallelize_model( sequence_parallel: bool = False, activation_checkpointing: bool = False, cpu_offload: bool = False, + custom_parallel_plan: dict = None, ): """Parallelize a model using DTensor. @@ -372,6 +354,8 @@ def _parallelize_model( Raises: ValueError: If the model type is not supported for parallelization. """ + model_cls = type(model) + mp_policy = MixedPrecisionPolicy( param_dtype=param_dtype, reduce_dtype=torch.float32, @@ -383,20 +367,66 @@ def _parallelize_model( else torch.distributed.fsdp.OffloadPolicy ) - model_cls = type(model) - if model_cls not in PARALLIZE_FUNCTIONS: - raise ValueError(f"Model {model_cls} not supported as part of dtensor") - - func = PARALLIZE_FUNCTIONS[type(model)] - - return func( - model, - dp_mesh, - tp_mesh, - mp_policy, - offload_policy, - sequence_parallel, - activation_checkpointing, + if tp_mesh.size() > 1: + # first use user's custom parallel plan + if custom_parallel_plan is not None: + model_parallel_plan = { + k: translate_parallel_style(v) for k, v in custom_parallel_plan.items() + } + + # second use our optimized parallel plan + elif model_cls in PARALLIZE_FUNCTIONS: + # try to use our optimized parallel plan + try: + func = PARALLIZE_FUNCTIONS[model_cls] + model_parallel_plan = func( + model, + dp_mesh, + tp_mesh, + mp_policy, + offload_policy, + sequence_parallel, + activation_checkpointing, + ) + # fall back to the HF tp plan + except Exception as e: + print( + f"Optimized parallel plan is not available: {e}. Falling back to the HF tp plan." + ) + assert not sequence_parallel, ( + "sequence_parallel is not support in HF tp plan." + ) + model_parallel_plan = get_hf_tp_plan(model) + + # final use the default HF tp plan + else: + # optimized parallel plan is not support for the model class + print( + f"Optimized parallel plan is not support for {model_cls}. Falling back to the HF tp plan." + ) + assert not sequence_parallel, ( + "sequence_parallel is not support in HF tp plan." + ) + model_parallel_plan = get_hf_tp_plan(model) + + parallelize_module(model, tp_mesh, model_parallel_plan) + + if model_cls == Gemma3ForConditionalGeneration: + layers = model.language_model.model.layers + else: + layers = model.model.layers + + if activation_checkpointing: + for i in range(len(layers)): + layers[i].mlp = checkpoint_wrapper(layers[i].mlp) # type: ignore + + for layer in layers: + fully_shard( + layer, mesh=dp_mesh, mp_policy=mp_policy, offload_policy=offload_policy + ) + + return fully_shard( + model, mesh=dp_mesh, mp_policy=mp_policy, offload_policy=offload_policy ) diff --git a/nemo_rl/models/policy/__init__.py b/nemo_rl/models/policy/__init__.py index e0f7916835..f43af36d2b 100644 --- a/nemo_rl/models/policy/__init__.py +++ b/nemo_rl/models/policy/__init__.py @@ -23,6 +23,7 @@ class DTensorConfig(TypedDict): sequence_parallel: bool activation_checkpointing: bool tensor_parallel_size: int + custom_parallel_plan: dict class TokenizerConfig(TypedDict): diff --git a/nemo_rl/models/policy/dtensor_policy_worker.py b/nemo_rl/models/policy/dtensor_policy_worker.py index 0108d7e8b6..1ed222f314 100644 --- a/nemo_rl/models/policy/dtensor_policy_worker.py +++ b/nemo_rl/models/policy/dtensor_policy_worker.py @@ -192,6 +192,7 @@ def __init__( activation_checkpointing=self.cfg["dtensor_cfg"][ "activation_checkpointing" ], + custom_parallel_plan=self.cfg["dtensor_cfg"]["custom_parallel_plan"], ) if self.cpu_offload: From 807a0dc09ab9502865f97dacb34d307b692af0bc Mon Sep 17 00:00:00 2001 From: Yuki Huang Date: Tue, 13 May 2025 10:59:43 +0000 Subject: [PATCH 2/9] tidy up Signed-off-by: Yuki Huang --- nemo_rl/models/dtensor/parallelize.py | 89 +++++++++++---------------- 1 file changed, 36 insertions(+), 53 deletions(-) diff --git a/nemo_rl/models/dtensor/parallelize.py b/nemo_rl/models/dtensor/parallelize.py index 0829b20b90..c5c1122787 100644 --- a/nemo_rl/models/dtensor/parallelize.py +++ b/nemo_rl/models/dtensor/parallelize.py @@ -81,12 +81,7 @@ def _prepare_output_fn(use_local_output, mod, outputs, device_mesh): def _parallelize_gemma3( model: Union[Gemma3ForCausalLM, Gemma3ForConditionalGeneration], - dp_mesh: DeviceMesh, - tp_mesh: DeviceMesh, - mp_policy: MixedPrecisionPolicy, - offload_policy: torch.distributed.fsdp.OffloadPolicy, sequence_parallel: bool = False, - activation_checkpointing: bool = False, ): """Parallelizes a Gemma3ForCausalLM model across data parallel dimensions. @@ -94,19 +89,8 @@ def _parallelize_gemma3( """ if isinstance(model, Gemma3ForConditionalGeneration): model_prefix = "language_model.model" - num_attention_heads = model.config.text_config.num_attention_heads - num_key_value_heads = model.config.text_config.num_key_value_heads else: model_prefix = "model" - num_attention_heads = model.config.num_attention_heads - num_key_value_heads = model.config.num_key_value_heads - - assert num_key_value_heads % tp_mesh.size() == 0, ( - f"num_key_value_heads ({num_key_value_heads}) must be divisible by TP size ({tp_mesh.size()})" - ) - assert num_attention_heads % tp_mesh.size() == 0, ( - f"num_attention_heads ({num_attention_heads}) must be divisible by TP size ({tp_mesh.size()})" - ) # For gemma3 models, we don't include the model.embed_tokens and lm_head in the # parallelization plans because they have tied weights. @@ -155,12 +139,7 @@ def _parallelize_gemma3( def _parallelize_llama( model: LlamaForCausalLM, - dp_mesh: DeviceMesh, - tp_mesh: DeviceMesh, - mp_policy: MixedPrecisionPolicy, - offload_policy: torch.distributed.fsdp.OffloadPolicy, sequence_parallel: bool = False, - activation_checkpointing: bool = False, ): """Parallelizes a LlamaForCausalLM model across data and tensor parallel dimensions.""" assert not model.config.tie_word_embeddings, ( @@ -202,12 +181,7 @@ def _parallelize_llama( def _parallelize_qwen( model: Union[Qwen2ForCausalLM, Qwen3ForCausalLM], - dp_mesh: DeviceMesh, - tp_mesh: DeviceMesh, - mp_policy: MixedPrecisionPolicy, - offload_policy: torch.distributed.fsdp.OffloadPolicy, sequence_parallel: bool = False, - activation_checkpointing: bool = False, ): """Parallelizes a Qwen2ForCausalLM model across data and tensor parallel dimensions.""" @@ -289,7 +263,10 @@ def _prepare_input_fn(sequence_sharding, mod, inputs, device_mesh): def translate_parallel_style(style: str): - """Translate parallel style str to parallel type""" + """Translate parallel style str to parallel type. + + Taken and modified from: https://github.com/NVIDIA/NeMo/blob/6c6169db01bcca73ae8ad3ac35242fadbb9a78ba/nemo/lightning/pytorch/strategies/utils.py#L547 + """ assert isinstance(style, str), ( f"parallel style type should be str, but got {type(style)}" ) @@ -309,7 +286,10 @@ def translate_parallel_style(style: str): def get_hf_tp_plan(model): - """Get the Hugging Face tensor parallel plan from the model.""" + """Get the Hugging Face tensor parallel plan from the model. + + Taken and modified from: https://github.com/NVIDIA/NeMo/blob/6c6169db01bcca73ae8ad3ac35242fadbb9a78ba/nemo/lightning/pytorch/strategies/utils.py#L532 + """ hf_tp_plan = {} # model_cls._tp_plan will override model_cls after xxxForCausalLM.post_init() (transformers==4.51.3) @@ -355,19 +335,23 @@ def _parallelize_model( ValueError: If the model type is not supported for parallelization. """ model_cls = type(model) - - mp_policy = MixedPrecisionPolicy( - param_dtype=param_dtype, - reduce_dtype=torch.float32, - output_dtype=torch.float32, - ) - offload_policy = ( - CPUOffloadPolicy(pin_memory=False) - if cpu_offload - else torch.distributed.fsdp.OffloadPolicy - ) + if model_cls == Gemma3ForConditionalGeneration: + layers = model.language_model.model.layers + num_attention_heads = model.config.text_config.num_attention_heads + num_key_value_heads = model.config.text_config.num_key_value_heads + else: + layers = model.model.layers + num_attention_heads = model.config.num_attention_heads + num_key_value_heads = model.config.num_key_value_heads if tp_mesh.size() > 1: + assert num_key_value_heads % tp_mesh.size() == 0, ( + f"num_key_value_heads ({num_key_value_heads}) must be divisible by TP size ({tp_mesh.size()})" + ) + assert num_attention_heads % tp_mesh.size() == 0, ( + f"num_attention_heads ({num_attention_heads}) must be divisible by TP size ({tp_mesh.size()})" + ) + # first use user's custom parallel plan if custom_parallel_plan is not None: model_parallel_plan = { @@ -379,15 +363,7 @@ def _parallelize_model( # try to use our optimized parallel plan try: func = PARALLIZE_FUNCTIONS[model_cls] - model_parallel_plan = func( - model, - dp_mesh, - tp_mesh, - mp_policy, - offload_policy, - sequence_parallel, - activation_checkpointing, - ) + model_parallel_plan = func(model, sequence_parallel) # fall back to the HF tp plan except Exception as e: print( @@ -411,15 +387,22 @@ def _parallelize_model( parallelize_module(model, tp_mesh, model_parallel_plan) - if model_cls == Gemma3ForConditionalGeneration: - layers = model.language_model.model.layers - else: - layers = model.model.layers - if activation_checkpointing: for i in range(len(layers)): layers[i].mlp = checkpoint_wrapper(layers[i].mlp) # type: ignore + mp_policy = MixedPrecisionPolicy( + param_dtype=param_dtype, + reduce_dtype=torch.float32, + output_dtype=torch.float32, + ) + + offload_policy = ( + CPUOffloadPolicy(pin_memory=False) + if cpu_offload + else torch.distributed.fsdp.OffloadPolicy + ) + for layer in layers: fully_shard( layer, mesh=dp_mesh, mp_policy=mp_policy, offload_policy=offload_policy From 440cd35a0fa7734f057ea3818d572d83130196cb Mon Sep 17 00:00:00 2001 From: Yuki Huang Date: Wed, 14 May 2025 23:38:05 -0700 Subject: [PATCH 3/9] fix model with model.language_model Signed-off-by: Yuki Huang --- nemo_rl/models/dtensor/parallelize.py | 15 ++++++++++++--- 1 file changed, 12 insertions(+), 3 deletions(-) diff --git a/nemo_rl/models/dtensor/parallelize.py b/nemo_rl/models/dtensor/parallelize.py index c5c1122787..24a905e0d7 100644 --- a/nemo_rl/models/dtensor/parallelize.py +++ b/nemo_rl/models/dtensor/parallelize.py @@ -290,18 +290,27 @@ def get_hf_tp_plan(model): Taken and modified from: https://github.com/NVIDIA/NeMo/blob/6c6169db01bcca73ae8ad3ac35242fadbb9a78ba/nemo/lightning/pytorch/strategies/utils.py#L532 """ + model_cls = type(model) + if model_cls == Gemma3ForConditionalGeneration: + inner_model = model.language_model + model_prefix = "language_model" + else: + inner_model = model.model + model_prefix = "model" + hf_tp_plan = {} # model_cls._tp_plan will override model_cls after xxxForCausalLM.post_init() (transformers==4.51.3) - model_cls = type(model) if hasattr(model_cls, "_tp_plan") and model_cls._tp_plan is not None: hf_tp_plan.update(model_cls._tp_plan) if hasattr(model, "_tp_plan") and model._tp_plan is not None: hf_tp_plan.update(model._tp_plan) - if hasattr(model.model, "_tp_plan") and model.model._tp_plan is not None: - hf_tp_plan.update({f"model.{k}": v for k, v in model.model._tp_plan.items()}) + if hasattr(inner_model, "_tp_plan") and inner_model._tp_plan is not None: + hf_tp_plan.update( + {f"{model_prefix}.{k}": v for k, v in inner_model._tp_plan.items()} + ) hf_tp_plan = {k: translate_parallel_style(v) for k, v in hf_tp_plan.items()} return hf_tp_plan From be8ddebf1f63bd5360136d3da2158ada132a3663 Mon Sep 17 00:00:00 2001 From: Yuki Huang Date: Mon, 19 May 2025 07:44:54 +0000 Subject: [PATCH 4/9] special with embed_tokens and lm_head for speed up Signed-off-by: Yuki Huang --- nemo_rl/models/dtensor/parallelize.py | 27 ++++++++++++++++++++++++++- 1 file changed, 26 insertions(+), 1 deletion(-) diff --git a/nemo_rl/models/dtensor/parallelize.py b/nemo_rl/models/dtensor/parallelize.py index 24a905e0d7..8839b12bb5 100644 --- a/nemo_rl/models/dtensor/parallelize.py +++ b/nemo_rl/models/dtensor/parallelize.py @@ -12,6 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. +from functools import lru_cache from typing import Callable, Union import torch @@ -262,6 +263,7 @@ def _prepare_input_fn(sequence_sharding, mod, inputs, device_mesh): } +@lru_cache def translate_parallel_style(style: str): """Translate parallel style str to parallel type. @@ -312,7 +314,30 @@ def get_hf_tp_plan(model): {f"{model_prefix}.{k}": v for k, v in inner_model._tp_plan.items()} ) - hf_tp_plan = {k: translate_parallel_style(v) for k, v in hf_tp_plan.items()} + assert len(hf_tp_plan) > 0, ( + f"Hugging Face tp plan is not supported for {model_cls}, please set dtensor_cfg.tensor_parallel_size to 1 or provide a custom parallel plan." + ) + + # hf tp plan not contain embed_tokens, we add it and set to rowwise_rep + if ( + f"{model_prefix}.embed_tokens" not in hf_tp_plan + and not model.config.tie_word_embeddings + ): + hf_tp_plan[f"{model_prefix}.embed_tokens"] = "rowwise_rep" + + for k, v in hf_tp_plan.items(): + # speed up the tp plan for lm_head + if ( + k == "lm_head" + and v == "colwise_rep" + and not model.config.tie_word_embeddings + ): + hf_tp_plan[k] = ColwiseParallel( + output_layouts=Shard(-1), use_local_output=False + ) + else: + hf_tp_plan[k] = translate_parallel_style(v) + return hf_tp_plan From 3e6918f718b22025f9b8285e1ce90618b651c895 Mon Sep 17 00:00:00 2001 From: Yuki Huang Date: Tue, 20 May 2025 08:39:03 +0000 Subject: [PATCH 5/9] add doc and update custom_parallel_plan Signed-off-by: Yuki Huang --- docs/design-docs/fsdp2-parallel-plan.md | 49 +++++++++++++++++++++++++ docs/index.md | 1 + nemo_rl/models/dtensor/parallelize.py | 19 ++++++++-- nemo_rl/models/policy/__init__.py | 2 +- 4 files changed, 66 insertions(+), 5 deletions(-) create mode 100644 docs/design-docs/fsdp2-parallel-plan.md diff --git a/docs/design-docs/fsdp2-parallel-plan.md b/docs/design-docs/fsdp2-parallel-plan.md new file mode 100644 index 0000000000..db04ee8484 --- /dev/null +++ b/docs/design-docs/fsdp2-parallel-plan.md @@ -0,0 +1,49 @@ +# FSDP2 Parallel Plan + +This guide outlines the parallelization strategy for FSDP2 training in NeMo-RL. + +## Fallback Priority + +Three parallelization approaches are supported, with the following fallback priority. + +**Custom Parallel Plan** + +User-defined custom parallel plans take precedence when available. + +For implementation details and usage guidelines, please refer to [Custom Parallel Plan Example](#custom-parallel-plan-example). + +**Optimized Parallel Plan** + +Optimized parallel plans are available for specific model architectures and may offer superior performance compared to the Hugging Face tensor parallel implementation. + +This approach is used when no custom parallel plan is specified and the model class supports optimized parallelization. + +**Hugging Face Tensor Parallel Plan** + +Hugging Face provides tensor parallelism for most models through `._tp_plan`. + +It serves as the default when neither custom nor optimized parallel plans are available. + +## Custom Parallel Plan Example + +Custom parallel plan should be defined in a file, exemplified by `examples/custom_parallel.py`. + +To implement the custom parallel plan, configure `policy.dtensor_cfg.custom_parallel_plan=examples.custom_parallel.custom_parallel_plan`. + +```python +from torch.distributed.tensor.parallel import ColwiseParallel, RowwiseParallel +from torch.distributed.tensor.placement_types import Replicate, Shard + + +custom_parallel_plan = { + "model.embed_tokens": RowwiseParallel(input_layouts=Replicate()), + "model.layers.*.self_attn.q_proj": ColwiseParallel(), + "model.layers.*.self_attn.k_proj": ColwiseParallel(), + "model.layers.*.self_attn.v_proj": ColwiseParallel(), + "model.layers.*.self_attn.o_proj": RowwiseParallel(), + "model.layers.*.mlp.up_proj": ColwiseParallel(), + "model.layers.*.mlp.gate_proj": ColwiseParallel(), + "model.layers.*.mlp.down_proj": RowwiseParallel(), + "lm_head": ColwiseParallel(output_layouts=Shard(-1), use_local_output=False), +} +``` diff --git a/docs/index.md b/docs/index.md index 4a0a5fcaa5..1e50854dfd 100644 --- a/docs/index.md +++ b/docs/index.md @@ -61,4 +61,5 @@ design-docs/chat-datasets.md design-docs/generation.md design-docs/checkpointing.md design-docs/loss-functions.md +design-docs/fsdp2-parallel-plan.md ``` diff --git a/nemo_rl/models/dtensor/parallelize.py b/nemo_rl/models/dtensor/parallelize.py index 8839b12bb5..843c22c7e7 100644 --- a/nemo_rl/models/dtensor/parallelize.py +++ b/nemo_rl/models/dtensor/parallelize.py @@ -13,6 +13,7 @@ # limitations under the License. from functools import lru_cache +from types import FunctionType from typing import Callable, Union import torch @@ -40,6 +41,7 @@ from transformers.models.qwen3.modeling_qwen3 import Qwen3ForCausalLM from nemo_rl.distributed.model_utils import from_parallel_logits_to_logprobs +from nemo_rl.models.policy.utils import import_class_from_path class RotaryEmbedParallel(SequenceParallel): @@ -349,7 +351,7 @@ def _parallelize_model( sequence_parallel: bool = False, activation_checkpointing: bool = False, cpu_offload: bool = False, - custom_parallel_plan: dict = None, + custom_parallel_plan: Union[dict, str] = None, ): """Parallelize a model using DTensor. @@ -361,6 +363,10 @@ def _parallelize_model( sequence_parallel (bool, optional): Whether to use sequence parallelism. Defaults to False. activation_checkpointing (bool, optional): Whether to use activation checkpointing. Defaults to False. cpu_offload (bool, optional): Whether to enable cpu offloading for FSDP. Defaults to False. + custom_parallel_plan (Union[dict, str], optional): Custom parallel plan for the model. Defaults to None. + If it's a dict, it will be used as the parallel plan directly. + If it's a string, it must be a path that points to a dict or a function that returns a dict. + The usage example can refer to `docs/design-docs/fsdp2-parallel-plan.md`. Returns: The parallelized model. @@ -388,9 +394,13 @@ def _parallelize_model( # first use user's custom parallel plan if custom_parallel_plan is not None: - model_parallel_plan = { - k: translate_parallel_style(v) for k, v in custom_parallel_plan.items() - } + model_parallel_plan = import_class_from_path(custom_parallel_plan) + if isinstance(model_parallel_plan, FunctionType): + model_parallel_plan = model_parallel_plan() + assert isinstance(model_parallel_plan, dict), ( + "custom_parallel_plan must be a path that points to a dict or a function that returns a dict" + ) + print(f"Using custom parallel plan.") # second use our optimized parallel plan elif model_cls in PARALLIZE_FUNCTIONS: @@ -398,6 +408,7 @@ def _parallelize_model( try: func = PARALLIZE_FUNCTIONS[model_cls] model_parallel_plan = func(model, sequence_parallel) + print(f"Using optimized parallel plan.") # fall back to the HF tp plan except Exception as e: print( diff --git a/nemo_rl/models/policy/__init__.py b/nemo_rl/models/policy/__init__.py index f43af36d2b..6e1bad7048 100644 --- a/nemo_rl/models/policy/__init__.py +++ b/nemo_rl/models/policy/__init__.py @@ -23,7 +23,7 @@ class DTensorConfig(TypedDict): sequence_parallel: bool activation_checkpointing: bool tensor_parallel_size: int - custom_parallel_plan: dict + custom_parallel_plan: str class TokenizerConfig(TypedDict): From 9f259bb5516da4aeff03ac16523b091efd9988e1 Mon Sep 17 00:00:00 2001 From: Yuki Huang Date: Tue, 20 May 2025 09:49:37 +0000 Subject: [PATCH 6/9] add unit test Signed-off-by: Yuki Huang --- nemo_rl/models/dtensor/parallelize.py | 25 +++++++---- .../models/policy/dtensor_policy_worker.py | 3 ++ .../models/generation/test_vllm_generation.py | 1 + .../unit/models/policy/test_dtensor_worker.py | 43 +++++++++++++------ tests/unit/models/policy/test_fsdp1_worker.py | 1 + tests/unit/utils/test_native_checkpoint.py | 1 + 6 files changed, 54 insertions(+), 20 deletions(-) diff --git a/nemo_rl/models/dtensor/parallelize.py b/nemo_rl/models/dtensor/parallelize.py index 843c22c7e7..ad7ba9569b 100644 --- a/nemo_rl/models/dtensor/parallelize.py +++ b/nemo_rl/models/dtensor/parallelize.py @@ -394,13 +394,22 @@ def _parallelize_model( # first use user's custom parallel plan if custom_parallel_plan is not None: - model_parallel_plan = import_class_from_path(custom_parallel_plan) - if isinstance(model_parallel_plan, FunctionType): - model_parallel_plan = model_parallel_plan() - assert isinstance(model_parallel_plan, dict), ( - "custom_parallel_plan must be a path that points to a dict or a function that returns a dict" - ) - print(f"Using custom parallel plan.") + if isinstance(custom_parallel_plan, dict): + model_parallel_plan = custom_parallel_plan + else: + try: + model_parallel_plan = import_class_from_path(custom_parallel_plan) + if isinstance(model_parallel_plan, FunctionType): + model_parallel_plan = model_parallel_plan() + assert isinstance(model_parallel_plan, dict) + except: + raise ValueError( + f"Your custom parallel plan is `{custom_parallel_plan}` which is not valid. Please ensure it is one of the following:\n" + "1. A dictionary\n" + "2. A path to a dictionary\n" + "3. A path to a function that returns a dictionary" + ) + print("Using custom parallel plan.") # second use our optimized parallel plan elif model_cls in PARALLIZE_FUNCTIONS: @@ -408,7 +417,7 @@ def _parallelize_model( try: func = PARALLIZE_FUNCTIONS[model_cls] model_parallel_plan = func(model, sequence_parallel) - print(f"Using optimized parallel plan.") + print("Using optimized parallel plan.") # fall back to the HF tp plan except Exception as e: print( diff --git a/nemo_rl/models/policy/dtensor_policy_worker.py b/nemo_rl/models/policy/dtensor_policy_worker.py index 1ed222f314..91d52b77e9 100644 --- a/nemo_rl/models/policy/dtensor_policy_worker.py +++ b/nemo_rl/models/policy/dtensor_policy_worker.py @@ -661,6 +661,9 @@ def _add_noise_to_weights(self) -> None: p.data.add_(noise) # Add noise in-place torch.cuda.synchronize() + def return_state_dict(self): + return self.model.state_dict() + def report_device_id(self) -> str: """Report the UUID of the current CUDA device using NVML. diff --git a/tests/unit/models/generation/test_vllm_generation.py b/tests/unit/models/generation/test_vllm_generation.py index c31ef8f65d..11192a3ad6 100644 --- a/tests/unit/models/generation/test_vllm_generation.py +++ b/tests/unit/models/generation/test_vllm_generation.py @@ -85,6 +85,7 @@ def get_basic_hf_test_config(enable_dtensor: bool = False) -> PolicyConfig: "sequence_parallel": False, "activation_checkpointing": False, "tensor_parallel_size": 1, + "custom_parallel_plan": None, }, "dynamic_batching": { "enabled": enable_dtensor, # Dynamic batching is only supported with DTensor diff --git a/tests/unit/models/policy/test_dtensor_worker.py b/tests/unit/models/policy/test_dtensor_worker.py index 8ea07c7f35..c50efca2d4 100644 --- a/tests/unit/models/policy/test_dtensor_worker.py +++ b/tests/unit/models/policy/test_dtensor_worker.py @@ -30,7 +30,6 @@ from nemo_rl.distributed.virtual_cluster import RayVirtualCluster from nemo_rl.models.generation import configure_generation_config from nemo_rl.models.policy import PolicyConfig -from nemo_rl.models.policy.dtensor_policy_worker import DTensorPolicyWorker from nemo_rl.models.policy.hf_policy import HfPolicy from tests.unit.conftest import TEST_ASSETS from tests.unit.test_utils import SimpleLoss @@ -42,6 +41,7 @@ def create_test_config( sequence_parallel: bool = False, cpu_offload: bool = False, activation_checkpointing: bool = False, + custom_parallel_plan: str = None, ) -> PolicyConfig: return { "model_name": model_name, @@ -67,6 +67,7 @@ def create_test_config( "sequence_parallel": sequence_parallel, "activation_checkpointing": activation_checkpointing, "tensor_parallel_size": tp, + "custom_parallel_plan": custom_parallel_plan, }, "dynamic_batching": { "enabled": True, @@ -461,25 +462,43 @@ def test_dtensor_worker_logprob_tp2_matches_no_tp(logprob_setup): ) -def test_dtensor_fails_with_tp_and_tied_model(mock_2gpu_distributed_env): - """Test that DTensor fails with a tp > 1 and a tied model.""" +def test_dtensor_tp_and_tied_model_with_custom_parallel_plan(two_gpu_virtual_cluster): + """Test that DTensor with a tp > 1 and a tied model with a custom parallel plan works.""" + from torch.distributed.tensor.parallel import ColwiseParallel + from torch.distributed.tensor.placement_types import Replicate + + custom_parallel_plan = {"lm_head": ColwiseParallel(output_layouts=Replicate())} config = create_test_config( model_name=TEST_ASSETS.TINY_LLAMA_TIED_MODEL_PATH, tp=2, cpu_offload=False, sequence_parallel=False, activation_checkpointing=False, + custom_parallel_plan=custom_parallel_plan, ) tokenizer = get_tokenizer(config["tokenizer"]) - with pytest.raises( - AssertionError, match="Tie word embeddings not supported when TP is enabled" - ): - DTensorPolicyWorker.__ray_actor_class__( - config=config, - tokenizer=tokenizer, - init_optimizer=False, - init_reference_model=False, - ) + + policy = HfPolicy( + tokenizer=tokenizer, + config=config, + init_optimizer=False, + init_reference_model=False, + cluster=two_gpu_virtual_cluster, + ) + + # Verify that the model is parallelized as expected + state_dict = ray.get(policy.worker_group.workers[0].return_state_dict.remote()) + total_shape = state_dict["lm_head.weight"].shape + sharded_shape = state_dict["lm_head.weight"].to_local().shape + assert total_shape[0] == sharded_shape[0] * 2, ( + "lm_head.weight should be sharded across 2 GPUs" + ) + assert total_shape[1] == sharded_shape[1], ( + "lm_head.weight should have the same number of columns" + ) + + # Clean up + policy.shutdown() @pytest.mark.timeout(180) diff --git a/tests/unit/models/policy/test_fsdp1_worker.py b/tests/unit/models/policy/test_fsdp1_worker.py index 591acd8749..3ec19f5148 100644 --- a/tests/unit/models/policy/test_fsdp1_worker.py +++ b/tests/unit/models/policy/test_fsdp1_worker.py @@ -58,6 +58,7 @@ "sequence_parallel": False, "activation_checkpointing": False, "tensor_parallel_size": 1, + "custom_parallel_plan": None, }, "dynamic_batching": { "enabled": False, diff --git a/tests/unit/utils/test_native_checkpoint.py b/tests/unit/utils/test_native_checkpoint.py index a53d7dd44a..68a1f3e217 100755 --- a/tests/unit/utils/test_native_checkpoint.py +++ b/tests/unit/utils/test_native_checkpoint.py @@ -59,6 +59,7 @@ "sequence_parallel": False, "activation_checkpointing": False, "tensor_parallel_size": 1, + "custom_parallel_plan": None, }, "dynamic_batching": { "enabled": False, From 10c4f05a19b0cc97562a79f6746aee64da60e3d3 Mon Sep 17 00:00:00 2001 From: Yuki Huang Date: Tue, 20 May 2025 09:58:05 +0000 Subject: [PATCH 7/9] update config Signed-off-by: Yuki Huang --- examples/configs/grpo-deepscaler-1.5b-8K.yaml | 3 ++- examples/configs/grpo_deepscaler-1.5b-24K.yaml | 1 + .../llm/dpo-llama3.1-8b-instruct-4n8g-fsdp1-quick.v2.yaml | 2 ++ .../recipes/llm/dpo-llama3.1-8b-instruct-4n8g-fsdp2tp1.v2.yaml | 2 ++ .../llm/dpo-llama3.1-8b-instruct-4n8g-fsdp2tp2-quick.v2.yaml | 2 ++ .../recipes/llm/dpo-llama3.2-1b-instruct-1n8g-fsdp2tp1.v2.yaml | 2 ++ .../configs/recipes/llm/grpo-gemma3-1b-it-1n8g-fsdp2tp1.yaml | 1 + .../llm/grpo-gemma3-27b-it-16n8g-fsdp2tp8sp-actckpt-long.yaml | 1 + .../llm/grpo-llama3.1-8b-instruct-4n8g-fsdp2tp1-long.v3.yaml | 1 + .../llm/grpo-llama3.2-1b-instruct-1n8g-fsdp2tp1.v3.yaml | 1 + .../llm/grpo-qwen2.5-32b-16n8g-fsdp2tp8sp-actckpt-long.v3.yaml | 1 + .../llm/grpo-qwen2.5-32b-16n8g-fsdp2tp8sp-actckpt.v3.yaml | 1 + .../recipes/llm/grpo-qwen2.5-7b-instruct-4n8g-fsdp1.v3.yaml | 1 + .../llm/grpo-qwen2.5-7b-instruct-4n8g-fsdp2tp4sp.v3.yaml | 1 + .../llm/grpo-qwen2.5-math-1.5b-instruct-1n8g-fsdp2tp1.v3.yaml | 1 + .../recipes/llm/sft-llama3.1-8b-instruct-1n8g-fsdp1.v2.yaml | 1 + .../llm/sft-llama3.1-8b-instruct-1n8g-fsdp2tp1-long.v2.yaml | 1 + .../llm/sft-llama3.1-8b-instruct-1n8g-fsdp2tp2sp.v2.yaml | 1 + .../configs/recipes/llm/sft-llama3.2-1b-1n8g-fsdp2tp1.v2.yaml | 1 + .../llm/sft-qwen2.5-32b-4n8g-fsdp2tp8sp-actckpt.v2.yaml | 1 + examples/configs/sft_openmathinstruct2.yaml | 1 + 21 files changed, 26 insertions(+), 1 deletion(-) diff --git a/examples/configs/grpo-deepscaler-1.5b-8K.yaml b/examples/configs/grpo-deepscaler-1.5b-8K.yaml index ecaca68d58..f37f3140c2 100644 --- a/examples/configs/grpo-deepscaler-1.5b-8K.yaml +++ b/examples/configs/grpo-deepscaler-1.5b-8K.yaml @@ -50,7 +50,8 @@ policy: sequence_parallel: false activation_checkpointing: false tensor_parallel_size: 1 - + custom_parallel_plan: null + # makes the training sequence length divisible by the tensor parallel size # this is useful for sequence parallel training make_sequence_length_divisible_by: ${policy.dtensor_cfg.tensor_parallel_size} diff --git a/examples/configs/grpo_deepscaler-1.5b-24K.yaml b/examples/configs/grpo_deepscaler-1.5b-24K.yaml index 9beb0c210b..a616ccfa6a 100644 --- a/examples/configs/grpo_deepscaler-1.5b-24K.yaml +++ b/examples/configs/grpo_deepscaler-1.5b-24K.yaml @@ -15,6 +15,7 @@ policy: sequence_parallel: true activation_checkpointing: true tensor_parallel_size: 4 + custom_parallel_plan: null optimizer: name: "torch.optim.AdamW" diff --git a/examples/configs/recipes/llm/dpo-llama3.1-8b-instruct-4n8g-fsdp1-quick.v2.yaml b/examples/configs/recipes/llm/dpo-llama3.1-8b-instruct-4n8g-fsdp1-quick.v2.yaml index c953e8ecd4..42b287565e 100644 --- a/examples/configs/recipes/llm/dpo-llama3.1-8b-instruct-4n8g-fsdp1-quick.v2.yaml +++ b/examples/configs/recipes/llm/dpo-llama3.1-8b-instruct-4n8g-fsdp1-quick.v2.yaml @@ -39,6 +39,8 @@ policy: sequence_parallel: false activation_checkpointing: false tensor_parallel_size: 1 + custom_parallel_plan: null + dynamic_batching: enabled: False diff --git a/examples/configs/recipes/llm/dpo-llama3.1-8b-instruct-4n8g-fsdp2tp1.v2.yaml b/examples/configs/recipes/llm/dpo-llama3.1-8b-instruct-4n8g-fsdp2tp1.v2.yaml index 2ce7ec018a..563c9462d4 100644 --- a/examples/configs/recipes/llm/dpo-llama3.1-8b-instruct-4n8g-fsdp2tp1.v2.yaml +++ b/examples/configs/recipes/llm/dpo-llama3.1-8b-instruct-4n8g-fsdp2tp1.v2.yaml @@ -39,6 +39,8 @@ policy: sequence_parallel: false activation_checkpointing: false tensor_parallel_size: 1 + custom_parallel_plan: null + dynamic_batching: enabled: False diff --git a/examples/configs/recipes/llm/dpo-llama3.1-8b-instruct-4n8g-fsdp2tp2-quick.v2.yaml b/examples/configs/recipes/llm/dpo-llama3.1-8b-instruct-4n8g-fsdp2tp2-quick.v2.yaml index c68e9af08c..74a0ecd3d1 100644 --- a/examples/configs/recipes/llm/dpo-llama3.1-8b-instruct-4n8g-fsdp2tp2-quick.v2.yaml +++ b/examples/configs/recipes/llm/dpo-llama3.1-8b-instruct-4n8g-fsdp2tp2-quick.v2.yaml @@ -39,6 +39,8 @@ policy: sequence_parallel: false activation_checkpointing: false tensor_parallel_size: 2 + custom_parallel_plan: null + dynamic_batching: enabled: False diff --git a/examples/configs/recipes/llm/dpo-llama3.2-1b-instruct-1n8g-fsdp2tp1.v2.yaml b/examples/configs/recipes/llm/dpo-llama3.2-1b-instruct-1n8g-fsdp2tp1.v2.yaml index da1a95ac0a..3c4d31f324 100644 --- a/examples/configs/recipes/llm/dpo-llama3.2-1b-instruct-1n8g-fsdp2tp1.v2.yaml +++ b/examples/configs/recipes/llm/dpo-llama3.2-1b-instruct-1n8g-fsdp2tp1.v2.yaml @@ -40,6 +40,8 @@ policy: sequence_parallel: false activation_checkpointing: false tensor_parallel_size: 1 + custom_parallel_plan: null + dynamic_batching: enabled: False diff --git a/examples/configs/recipes/llm/grpo-gemma3-1b-it-1n8g-fsdp2tp1.yaml b/examples/configs/recipes/llm/grpo-gemma3-1b-it-1n8g-fsdp2tp1.yaml index 330104005b..2eca1f773e 100644 --- a/examples/configs/recipes/llm/grpo-gemma3-1b-it-1n8g-fsdp2tp1.yaml +++ b/examples/configs/recipes/llm/grpo-gemma3-1b-it-1n8g-fsdp2tp1.yaml @@ -43,6 +43,7 @@ policy: sequence_parallel: false activation_checkpointing: false tensor_parallel_size: 1 + custom_parallel_plan: null dynamic_batching: enabled: True train_mb_tokens: ${mul:${policy.max_total_sequence_length}, ${policy.train_micro_batch_size}} diff --git a/examples/configs/recipes/llm/grpo-gemma3-27b-it-16n8g-fsdp2tp8sp-actckpt-long.yaml b/examples/configs/recipes/llm/grpo-gemma3-27b-it-16n8g-fsdp2tp8sp-actckpt-long.yaml index ff297d4a23..fbfad65fc5 100644 --- a/examples/configs/recipes/llm/grpo-gemma3-27b-it-16n8g-fsdp2tp8sp-actckpt-long.yaml +++ b/examples/configs/recipes/llm/grpo-gemma3-27b-it-16n8g-fsdp2tp8sp-actckpt-long.yaml @@ -43,6 +43,7 @@ policy: sequence_parallel: true activation_checkpointing: true tensor_parallel_size: 8 + custom_parallel_plan: null dynamic_batching: # TODO: OOMs if enabled https://github.com/NVIDIA/NeMo-RL/issues/383 enabled: False diff --git a/examples/configs/recipes/llm/grpo-llama3.1-8b-instruct-4n8g-fsdp2tp1-long.v3.yaml b/examples/configs/recipes/llm/grpo-llama3.1-8b-instruct-4n8g-fsdp2tp1-long.v3.yaml index ee26b2dfa3..d163f3e130 100644 --- a/examples/configs/recipes/llm/grpo-llama3.1-8b-instruct-4n8g-fsdp2tp1-long.v3.yaml +++ b/examples/configs/recipes/llm/grpo-llama3.1-8b-instruct-4n8g-fsdp2tp1-long.v3.yaml @@ -43,6 +43,7 @@ policy: sequence_parallel: false activation_checkpointing: false tensor_parallel_size: 1 + custom_parallel_plan: null dynamic_batching: enabled: True train_mb_tokens: ${mul:${policy.max_total_sequence_length}, ${policy.train_micro_batch_size}} diff --git a/examples/configs/recipes/llm/grpo-llama3.2-1b-instruct-1n8g-fsdp2tp1.v3.yaml b/examples/configs/recipes/llm/grpo-llama3.2-1b-instruct-1n8g-fsdp2tp1.v3.yaml index 2ae26cda4a..19722c4746 100644 --- a/examples/configs/recipes/llm/grpo-llama3.2-1b-instruct-1n8g-fsdp2tp1.v3.yaml +++ b/examples/configs/recipes/llm/grpo-llama3.2-1b-instruct-1n8g-fsdp2tp1.v3.yaml @@ -43,6 +43,7 @@ policy: sequence_parallel: false activation_checkpointing: false tensor_parallel_size: 1 + custom_parallel_plan: null dynamic_batching: enabled: True train_mb_tokens: ${mul:${policy.max_total_sequence_length}, ${policy.train_micro_batch_size}} diff --git a/examples/configs/recipes/llm/grpo-qwen2.5-32b-16n8g-fsdp2tp8sp-actckpt-long.v3.yaml b/examples/configs/recipes/llm/grpo-qwen2.5-32b-16n8g-fsdp2tp8sp-actckpt-long.v3.yaml index e609a45558..15e539dcf3 100644 --- a/examples/configs/recipes/llm/grpo-qwen2.5-32b-16n8g-fsdp2tp8sp-actckpt-long.v3.yaml +++ b/examples/configs/recipes/llm/grpo-qwen2.5-32b-16n8g-fsdp2tp8sp-actckpt-long.v3.yaml @@ -43,6 +43,7 @@ policy: sequence_parallel: true activation_checkpointing: true tensor_parallel_size: 8 + custom_parallel_plan: null dynamic_batching: enabled: True train_mb_tokens: ${mul:${policy.max_total_sequence_length}, ${policy.train_micro_batch_size}} diff --git a/examples/configs/recipes/llm/grpo-qwen2.5-32b-16n8g-fsdp2tp8sp-actckpt.v3.yaml b/examples/configs/recipes/llm/grpo-qwen2.5-32b-16n8g-fsdp2tp8sp-actckpt.v3.yaml index f8b37a53f8..e21526145d 100644 --- a/examples/configs/recipes/llm/grpo-qwen2.5-32b-16n8g-fsdp2tp8sp-actckpt.v3.yaml +++ b/examples/configs/recipes/llm/grpo-qwen2.5-32b-16n8g-fsdp2tp8sp-actckpt.v3.yaml @@ -43,6 +43,7 @@ policy: sequence_parallel: true activation_checkpointing: true tensor_parallel_size: 8 + custom_parallel_plan: null dynamic_batching: enabled: True train_mb_tokens: ${mul:${policy.max_total_sequence_length}, ${policy.train_micro_batch_size}} diff --git a/examples/configs/recipes/llm/grpo-qwen2.5-7b-instruct-4n8g-fsdp1.v3.yaml b/examples/configs/recipes/llm/grpo-qwen2.5-7b-instruct-4n8g-fsdp1.v3.yaml index 92ffcfef59..8b9a0f5c62 100644 --- a/examples/configs/recipes/llm/grpo-qwen2.5-7b-instruct-4n8g-fsdp1.v3.yaml +++ b/examples/configs/recipes/llm/grpo-qwen2.5-7b-instruct-4n8g-fsdp1.v3.yaml @@ -43,6 +43,7 @@ policy: sequence_parallel: false activation_checkpointing: false tensor_parallel_size: 1 + custom_parallel_plan: null dynamic_batching: enabled: False make_sequence_length_divisible_by: 1 diff --git a/examples/configs/recipes/llm/grpo-qwen2.5-7b-instruct-4n8g-fsdp2tp4sp.v3.yaml b/examples/configs/recipes/llm/grpo-qwen2.5-7b-instruct-4n8g-fsdp2tp4sp.v3.yaml index bf196e3d37..e9d63dc158 100644 --- a/examples/configs/recipes/llm/grpo-qwen2.5-7b-instruct-4n8g-fsdp2tp4sp.v3.yaml +++ b/examples/configs/recipes/llm/grpo-qwen2.5-7b-instruct-4n8g-fsdp2tp4sp.v3.yaml @@ -43,6 +43,7 @@ policy: sequence_parallel: true activation_checkpointing: false tensor_parallel_size: 4 + custom_parallel_plan: null dynamic_batching: enabled: True train_mb_tokens: ${mul:${policy.max_total_sequence_length}, ${policy.train_micro_batch_size}} diff --git a/examples/configs/recipes/llm/grpo-qwen2.5-math-1.5b-instruct-1n8g-fsdp2tp1.v3.yaml b/examples/configs/recipes/llm/grpo-qwen2.5-math-1.5b-instruct-1n8g-fsdp2tp1.v3.yaml index 0099e9ebd7..49abe12153 100644 --- a/examples/configs/recipes/llm/grpo-qwen2.5-math-1.5b-instruct-1n8g-fsdp2tp1.v3.yaml +++ b/examples/configs/recipes/llm/grpo-qwen2.5-math-1.5b-instruct-1n8g-fsdp2tp1.v3.yaml @@ -43,6 +43,7 @@ policy: sequence_parallel: false activation_checkpointing: false tensor_parallel_size: 1 + custom_parallel_plan: null dynamic_batching: enabled: True train_mb_tokens: ${mul:${policy.max_total_sequence_length}, ${policy.train_micro_batch_size}} diff --git a/examples/configs/recipes/llm/sft-llama3.1-8b-instruct-1n8g-fsdp1.v2.yaml b/examples/configs/recipes/llm/sft-llama3.1-8b-instruct-1n8g-fsdp1.v2.yaml index c089bcd4d0..d39bafbe91 100644 --- a/examples/configs/recipes/llm/sft-llama3.1-8b-instruct-1n8g-fsdp1.v2.yaml +++ b/examples/configs/recipes/llm/sft-llama3.1-8b-instruct-1n8g-fsdp1.v2.yaml @@ -31,6 +31,7 @@ policy: sequence_parallel: false activation_checkpointing: false tensor_parallel_size: 1 + custom_parallel_plan: null dynamic_batching: enabled: False make_sequence_length_divisible_by: 1 diff --git a/examples/configs/recipes/llm/sft-llama3.1-8b-instruct-1n8g-fsdp2tp1-long.v2.yaml b/examples/configs/recipes/llm/sft-llama3.1-8b-instruct-1n8g-fsdp2tp1-long.v2.yaml index 86537fcb87..8f62399bde 100644 --- a/examples/configs/recipes/llm/sft-llama3.1-8b-instruct-1n8g-fsdp2tp1-long.v2.yaml +++ b/examples/configs/recipes/llm/sft-llama3.1-8b-instruct-1n8g-fsdp2tp1-long.v2.yaml @@ -31,6 +31,7 @@ policy: sequence_parallel: false activation_checkpointing: false tensor_parallel_size: 1 + custom_parallel_plan: null dynamic_batching: enabled: False make_sequence_length_divisible_by: 1 diff --git a/examples/configs/recipes/llm/sft-llama3.1-8b-instruct-1n8g-fsdp2tp2sp.v2.yaml b/examples/configs/recipes/llm/sft-llama3.1-8b-instruct-1n8g-fsdp2tp2sp.v2.yaml index aec0d380c2..c4385da6e4 100644 --- a/examples/configs/recipes/llm/sft-llama3.1-8b-instruct-1n8g-fsdp2tp2sp.v2.yaml +++ b/examples/configs/recipes/llm/sft-llama3.1-8b-instruct-1n8g-fsdp2tp2sp.v2.yaml @@ -31,6 +31,7 @@ policy: sequence_parallel: true activation_checkpointing: false tensor_parallel_size: 2 + custom_parallel_plan: null dynamic_batching: enabled: False make_sequence_length_divisible_by: 2 diff --git a/examples/configs/recipes/llm/sft-llama3.2-1b-1n8g-fsdp2tp1.v2.yaml b/examples/configs/recipes/llm/sft-llama3.2-1b-1n8g-fsdp2tp1.v2.yaml index bc9b6f6326..4d31d392e9 100644 --- a/examples/configs/recipes/llm/sft-llama3.2-1b-1n8g-fsdp2tp1.v2.yaml +++ b/examples/configs/recipes/llm/sft-llama3.2-1b-1n8g-fsdp2tp1.v2.yaml @@ -31,6 +31,7 @@ policy: sequence_parallel: false activation_checkpointing: false tensor_parallel_size: 1 + custom_parallel_plan: null dynamic_batching: enabled: False make_sequence_length_divisible_by: 1 diff --git a/examples/configs/recipes/llm/sft-qwen2.5-32b-4n8g-fsdp2tp8sp-actckpt.v2.yaml b/examples/configs/recipes/llm/sft-qwen2.5-32b-4n8g-fsdp2tp8sp-actckpt.v2.yaml index 7d9cfe6eb1..6baa5f08f9 100644 --- a/examples/configs/recipes/llm/sft-qwen2.5-32b-4n8g-fsdp2tp8sp-actckpt.v2.yaml +++ b/examples/configs/recipes/llm/sft-qwen2.5-32b-4n8g-fsdp2tp8sp-actckpt.v2.yaml @@ -31,6 +31,7 @@ policy: sequence_parallel: true activation_checkpointing: true tensor_parallel_size: 8 + custom_parallel_plan: null dynamic_batching: enabled: False make_sequence_length_divisible_by: 8 diff --git a/examples/configs/sft_openmathinstruct2.yaml b/examples/configs/sft_openmathinstruct2.yaml index ef853cc4cd..b885f7388b 100644 --- a/examples/configs/sft_openmathinstruct2.yaml +++ b/examples/configs/sft_openmathinstruct2.yaml @@ -34,6 +34,7 @@ policy: sequence_parallel: false activation_checkpointing: false tensor_parallel_size: 4 + custom_parallel_plan: null # makes the training sequence length divisible by the tensor parallel size # this is useful for sequence parallel training From 8ae4d329ee124c291cbeb2b6378133ef878f4935 Mon Sep 17 00:00:00 2001 From: Yuki Huang Date: Wed, 21 May 2025 19:50:50 -0700 Subject: [PATCH 8/9] fix gemma2 Signed-off-by: Yuki Huang --- nemo_rl/models/huggingface/common.py | 14 ++++++++------ tests/unit/models/huggingface/test_common.py | 16 +++++++++++----- 2 files changed, 19 insertions(+), 11 deletions(-) diff --git a/nemo_rl/models/huggingface/common.py b/nemo_rl/models/huggingface/common.py index 10fa3f4cfa..df913f95b4 100644 --- a/nemo_rl/models/huggingface/common.py +++ b/nemo_rl/models/huggingface/common.py @@ -39,15 +39,17 @@ class ModelFlag(Enum): def matches(self, model_name: str) -> bool: match self: case ModelFlag.SKIP_DTENSOR_TIED_WEIGHTS_CHECK: - return is_gemma3_model(model_name) + return is_gemma_model(model_name) case ModelFlag.VLLM_LOAD_FORMAT_AUTO: - return is_gemma3_model(model_name) + return is_gemma_model(model_name) case _: raise ValueError(f"Unknown ModelFlag: {self}") -def is_gemma3_model(model_name: str) -> bool: +def is_gemma_model(model_name: str) -> bool: hf_config = AutoConfig.from_pretrained(model_name, trust_remote_code=True) - return hasattr(hf_config, "model_type") and ( - hf_config.model_type == "gemma3" or hf_config.model_type == "gemma3_text" - ) + return hasattr(hf_config, "model_type") and hf_config.model_type in [ + "gemma2", + "gemma3", + "gemma3_text", + ] diff --git a/tests/unit/models/huggingface/test_common.py b/tests/unit/models/huggingface/test_common.py index 74fb2f0848..faf06fbdb7 100644 --- a/tests/unit/models/huggingface/test_common.py +++ b/tests/unit/models/huggingface/test_common.py @@ -14,12 +14,18 @@ import pytest -from nemo_rl.models.huggingface.common import ModelFlag, is_gemma3_model +from nemo_rl.models.huggingface.common import ModelFlag, is_gemma_model @pytest.mark.parametrize( "model_name", [ + "google/gemma-2-2b", + "google/gemma-2-9b", + "google/gemma-2-27b", + "google/gemma-2-2b-it", + "google/gemma-2-9b-it", + "google/gemma-2-27b-it", "google/gemma-3-1b-pt", "google/gemma-3-4b-pt", "google/gemma-3-12b-pt", @@ -30,8 +36,8 @@ "google/gemma-3-27b-it", ], ) -def test_gemma3_models(model_name): - assert is_gemma3_model(model_name) +def test_gemma_models(model_name): + assert is_gemma_model(model_name) assert ModelFlag.SKIP_DTENSOR_TIED_WEIGHTS_CHECK.matches(model_name) assert ModelFlag.VLLM_LOAD_FORMAT_AUTO.matches(model_name) @@ -44,7 +50,7 @@ def test_gemma3_models(model_name): "Qwen/Qwen2.5-3B-Instruct", ], ) -def test_non_gemma3_models(model_name): - assert not is_gemma3_model(model_name) +def test_non_gemma_models(model_name): + assert not is_gemma_model(model_name) assert not ModelFlag.SKIP_DTENSOR_TIED_WEIGHTS_CHECK.matches(model_name) assert not ModelFlag.VLLM_LOAD_FORMAT_AUTO.matches(model_name) From 72a8f35c26314a6b48b1ea4d9ec3f89d1d854d95 Mon Sep 17 00:00:00 2001 From: Yuki Huang Date: Wed, 21 May 2025 23:59:32 -0700 Subject: [PATCH 9/9] update doc and fix type Signed-off-by: Yuki Huang --- docs/design-docs/fsdp2-parallel-plan.md | 45 ++++++--------------- examples/custom_parallel.py | 28 +++++++++++++ nemo_rl/models/dtensor/parallelize.py | 53 ++++++++++++++++++------- 3 files changed, 79 insertions(+), 47 deletions(-) create mode 100644 examples/custom_parallel.py diff --git a/docs/design-docs/fsdp2-parallel-plan.md b/docs/design-docs/fsdp2-parallel-plan.md index db04ee8484..8318b3174c 100644 --- a/docs/design-docs/fsdp2-parallel-plan.md +++ b/docs/design-docs/fsdp2-parallel-plan.md @@ -1,49 +1,30 @@ # FSDP2 Parallel Plan -This guide outlines the parallelization strategy for FSDP2 training in NeMo-RL. +This guide outlines the parallelization strategy for Fully Sharded Data Parallel version 2 (FSDP2) training in NeMo RL. ## Fallback Priority -Three parallelization approaches are supported, with the following fallback priority. +NeMo RL supports three parallelization strategies, applied in the following order of fallback priority: -**Custom Parallel Plan** +### 1. Custom Parallel Plan -User-defined custom parallel plans take precedence when available. +Your user-defined custom parallel plans always take precedence when available. For detailed implementation and usage, refer to the [Custom Parallel Plan Example](#custom-parallel-plan-example). -For implementation details and usage guidelines, please refer to [Custom Parallel Plan Example](#custom-parallel-plan-example). +### 2. Optimized Parallel Plan -**Optimized Parallel Plan** +Optimized parallel plans are available for specific model architectures. They may offer superior performance compared to Hugging Face's tensor parallel implementation. This approach is used if no custom parallel plan is specified and the model class supports optimized parallelization. -Optimized parallel plans are available for specific model architectures and may offer superior performance compared to the Hugging Face tensor parallel implementation. +### 3. Hugging Face Tensor Parallel Plan -This approach is used when no custom parallel plan is specified and the model class supports optimized parallelization. - -**Hugging Face Tensor Parallel Plan** - -Hugging Face provides tensor parallelism for most models through `._tp_plan`. - -It serves as the default when neither custom nor optimized parallel plans are available. +The Hugging Face tensor parallel plan is the default. It's available for most models via `._tp_plan` and is used when neither a custom nor an optimized parallel plan is available. ## Custom Parallel Plan Example -Custom parallel plan should be defined in a file, exemplified by `examples/custom_parallel.py`. - -To implement the custom parallel plan, configure `policy.dtensor_cfg.custom_parallel_plan=examples.custom_parallel.custom_parallel_plan`. - -```python -from torch.distributed.tensor.parallel import ColwiseParallel, RowwiseParallel -from torch.distributed.tensor.placement_types import Replicate, Shard +A custom parallel plan should be defined in a separate file, such as the example provided in `examples/custom_parallel.py`. +To implement the custom parallel plan, either update the value of `custom_parallel_plan` in the `yaml` file directly, or pass the override via the command line. For example: -custom_parallel_plan = { - "model.embed_tokens": RowwiseParallel(input_layouts=Replicate()), - "model.layers.*.self_attn.q_proj": ColwiseParallel(), - "model.layers.*.self_attn.k_proj": ColwiseParallel(), - "model.layers.*.self_attn.v_proj": ColwiseParallel(), - "model.layers.*.self_attn.o_proj": RowwiseParallel(), - "model.layers.*.mlp.up_proj": ColwiseParallel(), - "model.layers.*.mlp.gate_proj": ColwiseParallel(), - "model.layers.*.mlp.down_proj": RowwiseParallel(), - "lm_head": ColwiseParallel(output_layouts=Shard(-1), use_local_output=False), -} +```bash +uv run examples/run_grpo_math.py \ + policy.dtensor_cfg.custom_parallel_plan=examples.custom_parallel.custom_parallel_plan ``` diff --git a/examples/custom_parallel.py b/examples/custom_parallel.py new file mode 100644 index 0000000000..647ddfc563 --- /dev/null +++ b/examples/custom_parallel.py @@ -0,0 +1,28 @@ +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from torch.distributed.tensor.parallel import ColwiseParallel, RowwiseParallel +from torch.distributed.tensor.placement_types import Replicate, Shard + +custom_parallel_plan = { + "model.embed_tokens": RowwiseParallel(input_layouts=Replicate()), + "model.layers.*.self_attn.q_proj": ColwiseParallel(), + "model.layers.*.self_attn.k_proj": ColwiseParallel(), + "model.layers.*.self_attn.v_proj": ColwiseParallel(), + "model.layers.*.self_attn.o_proj": RowwiseParallel(), + "model.layers.*.mlp.up_proj": ColwiseParallel(), + "model.layers.*.mlp.gate_proj": ColwiseParallel(), + "model.layers.*.mlp.down_proj": RowwiseParallel(), + "lm_head": ColwiseParallel(output_layouts=Shard(-1), use_local_output=False), +} diff --git a/nemo_rl/models/dtensor/parallelize.py b/nemo_rl/models/dtensor/parallelize.py index ad7ba9569b..0d0c4ee413 100644 --- a/nemo_rl/models/dtensor/parallelize.py +++ b/nemo_rl/models/dtensor/parallelize.py @@ -14,7 +14,7 @@ from functools import lru_cache from types import FunctionType -from typing import Callable, Union +from typing import Callable, Optional, Union import torch from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import ( @@ -25,6 +25,7 @@ from torch.distributed.tensor import DTensor from torch.distributed.tensor.parallel import ( ColwiseParallel, + ParallelStyle, PrepareModuleInput, PrepareModuleOutput, RowwiseParallel, @@ -254,7 +255,9 @@ def _prepare_input_fn(sequence_sharding, mod, inputs, device_mesh): return base_model_tp_plan -PARALLIZE_FUNCTIONS: dict[type[torch.nn.Module], Callable[..., torch.nn.Module]] = { +PARALLIZE_FUNCTIONS: dict[ + type[torch.nn.Module], Callable[..., dict[str, ParallelStyle]] +] = { Qwen2ForCausalLM: _parallelize_qwen, Qwen3ForCausalLM: _parallelize_qwen, LlamaForCausalLM: _parallelize_llama, @@ -292,7 +295,21 @@ def translate_parallel_style(style: str): def get_hf_tp_plan(model): """Get the Hugging Face tensor parallel plan from the model. + This function: + - Retrieves TP strategies from model class, instance, and inner model levels. + - Handles special cases for `embed_tokens` and `lm_head` for speed up. + - Converts string-based parallel styles to DTensor parallelization strategies. + Taken and modified from: https://github.com/NVIDIA/NeMo/blob/6c6169db01bcca73ae8ad3ac35242fadbb9a78ba/nemo/lightning/pytorch/strategies/utils.py#L532 + + Args: + model: A Hugging Face model instance + + Returns: + dict: A dictionary mapping model component paths to their parallelization strategies + + Raises: + AssertionError: If no TP plan is found """ model_cls = type(model) if model_cls == Gemma3ForConditionalGeneration: @@ -317,7 +334,8 @@ def get_hf_tp_plan(model): ) assert len(hf_tp_plan) > 0, ( - f"Hugging Face tp plan is not supported for {model_cls}, please set dtensor_cfg.tensor_parallel_size to 1 or provide a custom parallel plan." + f"Hugging Face tp plan is not supported for {model_cls}, please set dtensor_cfg.tensor_parallel_size to 1 or provide a custom_parallel_plan. " + "The usage example of custom_parallel_plan can refer to `docs/design-docs/fsdp2-parallel-plan.md`." ) # hf tp plan not contain embed_tokens, we add it and set to rowwise_rep @@ -344,26 +362,31 @@ def get_hf_tp_plan(model): def _parallelize_model( - model: Union[Qwen2ForCausalLM, LlamaForCausalLM], + model: Union[ + Qwen2ForCausalLM, + LlamaForCausalLM, + Gemma3ForCausalLM, + Gemma3ForConditionalGeneration, + ], dp_mesh: DeviceMesh, tp_mesh: DeviceMesh, param_dtype: torch.dtype, sequence_parallel: bool = False, activation_checkpointing: bool = False, cpu_offload: bool = False, - custom_parallel_plan: Union[dict, str] = None, + custom_parallel_plan: Optional[Union[dict, str]] = None, ): """Parallelize a model using DTensor. Args: - model (Union[Qwen2ForCausalLM, LlamaForCausalLM]): The model to parallelize. - dp_mesh (DeviceMesh): Device mesh for data parallelism. - tp_mesh (DeviceMesh): Device mesh for tensor parallelism. - param_dtype (torch.dtype): Data type for model parameters. - sequence_parallel (bool, optional): Whether to use sequence parallelism. Defaults to False. - activation_checkpointing (bool, optional): Whether to use activation checkpointing. Defaults to False. - cpu_offload (bool, optional): Whether to enable cpu offloading for FSDP. Defaults to False. - custom_parallel_plan (Union[dict, str], optional): Custom parallel plan for the model. Defaults to None. + model: The model to parallelize. + dp_mesh: Device mesh for data parallelism. + tp_mesh: Device mesh for tensor parallelism. + param_dtype: Data type for model parameters. + sequence_parallel: Whether to use sequence parallelism. Defaults to False. + activation_checkpointing: Whether to use activation checkpointing. Defaults to False. + cpu_offload: Whether to enable cpu offloading for FSDP. Defaults to False. + custom_parallel_plan: Custom parallel plan for the model. Defaults to None. If it's a dict, it will be used as the parallel plan directly. If it's a string, it must be a path that points to a dict or a function that returns a dict. The usage example can refer to `docs/design-docs/fsdp2-parallel-plan.md`. @@ -376,11 +399,11 @@ def _parallelize_model( """ model_cls = type(model) if model_cls == Gemma3ForConditionalGeneration: - layers = model.language_model.model.layers + layers: torch.nn.ModuleList = model.language_model.model.layers # type: ignore num_attention_heads = model.config.text_config.num_attention_heads num_key_value_heads = model.config.text_config.num_key_value_heads else: - layers = model.model.layers + layers: torch.nn.ModuleList = model.model.layers # type: ignore num_attention_heads = model.config.num_attention_heads num_key_value_heads = model.config.num_key_value_heads