From 14c50770ce6693d7eb9d77812cb69adddedf7f7d Mon Sep 17 00:00:00 2001 From: ashors1 Date: Mon, 30 Jun 2025 15:39:15 -0700 Subject: [PATCH] fix gemma3 prefix Signed-off-by: ashors1 --- nemo_rl/models/dtensor/parallelize.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/nemo_rl/models/dtensor/parallelize.py b/nemo_rl/models/dtensor/parallelize.py index 664bc1a253..fb9c720c20 100644 --- a/nemo_rl/models/dtensor/parallelize.py +++ b/nemo_rl/models/dtensor/parallelize.py @@ -92,7 +92,7 @@ def _parallelize_gemma3( Tensor parallelism is not supported for Gemma3 models because of tied word embeddings. """ if isinstance(model, Gemma3ForConditionalGeneration): - model_prefix = "language_model.model" + model_prefix = "language_model" else: model_prefix = "model" @@ -399,7 +399,7 @@ def _parallelize_model( """ model_cls = type(model) if model_cls == Gemma3ForConditionalGeneration: - layers: torch.nn.ModuleList = model.language_model.model.layers # type: ignore + layers: torch.nn.ModuleList = model.language_model.layers # type: ignore num_attention_heads = model.config.text_config.num_attention_heads num_key_value_heads = model.config.text_config.num_key_value_heads else: