From afd6b829266b59ee8d080d646689b68daed7e68e Mon Sep 17 00:00:00 2001
From: Yi-Fu Wu
Date: Fri, 19 Dec 2025 00:06:33 -0800
Subject: [PATCH] Fix crash when using cp in dtensor path

Signed-off-by: Yi-Fu Wu
---
 .../policy/workers/dtensor_policy_worker_v2.py | 16 +++++++++++++++-
 1 file changed, 15 insertions(+), 1 deletion(-)

diff --git a/nemo_rl/models/policy/workers/dtensor_policy_worker_v2.py b/nemo_rl/models/policy/workers/dtensor_policy_worker_v2.py
index a8a3957bd5..4ec712929b 100644
--- a/nemo_rl/models/policy/workers/dtensor_policy_worker_v2.py
+++ b/nemo_rl/models/policy/workers/dtensor_policy_worker_v2.py
@@ -277,10 +277,24 @@ def __init__(
 
         # All ranks initialize model on meta device, so FSDP can shard it.
         # The actual weights will be broadcast from rank 0.
+        cp_size = self.cfg["dtensor_cfg"]["context_parallel_size"]
         with init_empty_weights():
             # NeMoAutoModelForCausalLM uses flash_attention_2 by default
             # so we need to set it to None if sequence packing is disabled
             # https://github.com/NVIDIA-NeMo/Automodel/blob/7e748be260651349307862426c0c168cebdeeec3/nemo_automodel/components/_transformers/auto_model.py#L180
+            if cp_size > 1:
+                # Match Automodel's `get_train_context` in `cp_utils.py` where only
+                # flash and efficient backends are supported
+                # Ref: https://github.com/NVIDIA-NeMo/Automodel/blob/81788d6f4848f5f066c4a6a2bece4689a6a83687/nemo_automodel/components/distributed/cp_utils.py#L57
+                from torch.nn.attention import SDPBackend
+
+                sdpa_method = [
+                    SDPBackend.FLASH_ATTENTION,
+                    SDPBackend.EFFICIENT_ATTENTION,
+                ]
+            else:
+                sdpa_method = None
+
             self.model = model_class.from_config(
                 model_config,
                 attn_implementation="flash_attention_2"
@@ -289,6 +303,7 @@ def __init__(
                 use_liger_kernel=False,
                 trust_remote_code=True,
                 torch_dtype=str(model_config.torch_dtype),
+                sdpa_method=sdpa_method,
             )
         if self.lora_enabled:
             apply_lora_to_linear_modules(self.model, self.peft_config)
@@ -297,7 +312,6 @@ def __init__(
             self.model.config.pad_token_id = tokenizer.pad_token_id
 
         tp_size = self.cfg["dtensor_cfg"]["tensor_parallel_size"]
-        cp_size = self.cfg["dtensor_cfg"]["context_parallel_size"]
         if cp_size > 1 and self.enable_seq_packing:
             raise ValueError(
                 "Context parallel is not supported for sequence packing. Refer to https://github.com/NVIDIA/NeMo-RL/blob/main/docs/model-quirks.md#context-parallel-with-fsdp2 for more details."
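
Note (editor's sketch, not part of the patch): the `sdpa_method` list built above
mirrors the backend restriction that Automodel's `get_train_context` (see the
`cp_utils.py` link in the diff) applies around the training step under context
parallelism. Below is a minimal, standalone illustration of that mechanism using
only the public PyTorch API `torch.nn.attention.sdpa_kernel`; the tensor shapes
are arbitrary, and a CUDA device with bf16 support is assumed:

    import torch
    import torch.nn.functional as F
    from torch.nn.attention import SDPBackend, sdpa_kernel

    # Arbitrary (batch, heads, seq_len, head_dim) inputs; the flash backend
    # requires a CUDA device and fp16/bf16 inputs.
    q = k = v = torch.randn(1, 8, 128, 64, device="cuda", dtype=torch.bfloat16)

    # Restrict scaled_dot_product_attention to the flash and memory-efficient
    # kernels, the same two backends the patch whitelists for context parallel.
    # Backends outside this list (e.g. the math fallback) cannot be selected
    # inside the context manager.
    with sdpa_kernel([SDPBackend.FLASH_ATTENTION, SDPBackend.EFFICIENT_ATTENTION]):
        out = F.scaled_dot_product_attention(q, k, v)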