From 79e4fa0703bdabafd693c0321e668604a236c354 Mon Sep 17 00:00:00 2001
From: Parth Chadha
Date: Wed, 23 Apr 2025 09:02:36 -0700
Subject: [PATCH] fix: use find_tied_parameters api from HF for tied weight
 keys

Signed-off-by: Parth Chadha
---
 nemo_reinforcer/models/policy/dtensor_policy_worker.py | 4 ++--
 nemo_reinforcer/models/policy/fsdp1_policy_worker.py   | 4 ++--
 2 files changed, 4 insertions(+), 4 deletions(-)

diff --git a/nemo_reinforcer/models/policy/dtensor_policy_worker.py b/nemo_reinforcer/models/policy/dtensor_policy_worker.py
index cf0f06bbfa..ac94d49120 100644
--- a/nemo_reinforcer/models/policy/dtensor_policy_worker.py
+++ b/nemo_reinforcer/models/policy/dtensor_policy_worker.py
@@ -25,7 +25,7 @@
     FSDPModule,
 )
 from transformers import AutoModelForCausalLM, AutoTokenizer
-from transformers.modeling_utils import _get_tied_weight_keys
+from transformers.integrations.accelerate import find_tied_parameters
 
 from nemo_reinforcer.models.dtensor.parallelize import _parallelize_model
 from nemo_reinforcer.algorithms.interfaces import LossFunction
@@ -256,7 +256,7 @@ def train(
         mbs: Optional[int] = None,
     ) -> Dict[str, Any]:
         """Train the policy on a batch of data with a given loss function."""
-        num_tied_weights = len(_get_tied_weight_keys(self.model))
+        num_tied_weights = len(find_tied_parameters(self.model))
         skip_tie_check = os.environ.get("NRL_SKIP_TIED_WEIGHT_CHECK")
         if (
             num_tied_weights != 0
diff --git a/nemo_reinforcer/models/policy/fsdp1_policy_worker.py b/nemo_reinforcer/models/policy/fsdp1_policy_worker.py
index 89b46fd6ac..5e8a8f6bc5 100644
--- a/nemo_reinforcer/models/policy/fsdp1_policy_worker.py
+++ b/nemo_reinforcer/models/policy/fsdp1_policy_worker.py
@@ -39,7 +39,7 @@
 )
 
 from transformers import AutoModelForCausalLM, AutoTokenizer
-from transformers.modeling_utils import _get_tied_weight_keys
+from transformers.integrations.accelerate import find_tied_parameters
 from nemo_reinforcer.models.policy import PolicyConfig
 from nemo_reinforcer.models.policy.utils import import_class_from_path
 from nemo_reinforcer.distributed.virtual_cluster import (
@@ -229,7 +229,7 @@ def train(
     ) -> Dict[str, Any]:
         """Train the policy on a batch of data with a given loss function."""
         # Check if the model has tied weights
-        num_tied_weights = len(_get_tied_weight_keys(self.model))
+        num_tied_weights = len(find_tied_parameters(self.model))
         skip_tie_check = os.environ.get("NRL_SKIP_TIED_WEIGHT_CHECK")
         if num_tied_weights != 0 and not skip_tie_check:
             raise ValueError(
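
Note (editor's sketch, not part of the patch): find_tied_parameters is
accelerate's tied-weight utility, imported here via the
transformers.integrations.accelerate path shown in the diff. Unlike the
removed private helper _get_tied_weight_keys, which returned a flat list of
tied weight keys, it returns groups of tied parameter names; the check above
only needs len(...) to be non-zero, so either return shape works. A minimal
standalone illustration, assuming a model such as GPT-2 whose LM head is tied
to its input embeddings:

    import os

    from transformers import AutoModelForCausalLM
    from transformers.integrations.accelerate import find_tied_parameters

    model = AutoModelForCausalLM.from_pretrained("gpt2")

    # Returns groups of tied parameter names, e.g.
    # [["lm_head.weight", "transformer.wte.weight"]] for GPT-2.
    tied_groups = find_tied_parameters(model)

    num_tied_weights = len(tied_groups)
    skip_tie_check = os.environ.get("NRL_SKIP_TIED_WEIGHT_CHECK")
    if num_tied_weights != 0 and not skip_tie_check:
        # Illustrative message only; the actual error text is truncated
        # in the diff above.
        raise ValueError(f"Found tied parameter groups: {tied_groups}")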