16 changes: 15 additions & 1 deletion nemo_rl/models/policy/workers/dtensor_policy_worker_v2.py
@@ -277,10 +277,24 @@ def __init__(
         # All ranks initialize model on meta device, so FSDP can shard it.
         # The actual weights will be broadcast from rank 0.
 
+        cp_size = self.cfg["dtensor_cfg"]["context_parallel_size"]
         with init_empty_weights():
             # NeMoAutoModelForCausalLM uses flash_attention_2 by default,
             # so we need to set attn_implementation to None if sequence packing is disabled.
             # https://github.com/NVIDIA-NeMo/Automodel/blob/7e748be260651349307862426c0c168cebdeeec3/nemo_automodel/components/_transformers/auto_model.py#L180
+            if cp_size > 1:
+                # Match Automodel's `get_train_context` in `cp_utils.py`, where only the
+                # flash and efficient attention backends are supported.
+                # Ref: https://github.com/NVIDIA-NeMo/Automodel/blob/81788d6f4848f5f066c4a6a2bece4689a6a83687/nemo_automodel/components/distributed/cp_utils.py#L57
+                from torch.nn.attention import SDPBackend
+
+                sdpa_method = [
+                    SDPBackend.FLASH_ATTENTION,
+                    SDPBackend.EFFICIENT_ATTENTION,
+                ]
+            else:
+                sdpa_method = None
+
             self.model = model_class.from_config(
                 model_config,
                 attn_implementation="flash_attention_2"
@@ -289,6 +303,7 @@ def __init__(
                 use_liger_kernel=False,
                 trust_remote_code=True,
                 torch_dtype=str(model_config.torch_dtype),
+                sdpa_method=sdpa_method,
             )
             if self.lora_enabled:
                 apply_lora_to_linear_modules(self.model, self.peft_config)
@@ -297,7 +312,6 @@ def __init__(
         self.model.config.pad_token_id = tokenizer.pad_token_id
 
         tp_size = self.cfg["dtensor_cfg"]["tensor_parallel_size"]
-        cp_size = self.cfg["dtensor_cfg"]["context_parallel_size"]
         if cp_size > 1 and self.enable_seq_packing:
             raise ValueError(
                 "Context parallel is not supported for sequence packing. Refer to https://github.com/NVIDIA/NeMo-RL/blob/main/docs/model-quirks.md#context-parallel-with-fsdp2 for more details."
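For reviewers tracing how `sdpa_method` is consumed downstream: per the `cp_utils.py` reference above, Automodel's `get_train_context` applies the selected backends through PyTorch's `torch.nn.attention.sdpa_kernel` context manager (available in PyTorch 2.3+). A minimal sketch of that pattern follows; the `forward_with_cp`, `model`, and `batch` names are hypothetical and not part of this PR:

```python
from torch.nn.attention import SDPBackend, sdpa_kernel

# Backends permitted under context parallelism; mirrors the list built in this diff.
SDPA_METHOD = [SDPBackend.FLASH_ATTENTION, SDPBackend.EFFICIENT_ATTENTION]

def forward_with_cp(model, batch):
    # sdpa_kernel restricts scaled_dot_product_attention to the listed backends
    # (excluding, e.g., the math backend) for the duration of the context.
    with sdpa_kernel(SDPA_METHOD):
        return model(**batch)
```

In the `cp_size == 1` branch, `sdpa_method=None` is passed instead, leaving backend selection to Automodel's defaults.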