From ebb66effad9eb1092cde4a9f5b8f90b91cc12b75 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E2=80=9CElla?= Date: Mon, 9 Mar 2026 18:46:31 +0100 Subject: [PATCH 1/2] only call rotary_emb when config.use_mem_rope is True --- src/transformers/models/zamba2/modeling_zamba2.py | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/src/transformers/models/zamba2/modeling_zamba2.py b/src/transformers/models/zamba2/modeling_zamba2.py index 99ff7d260756..47ccd4fe240f 100644 --- a/src/transformers/models/zamba2/modeling_zamba2.py +++ b/src/transformers/models/zamba2/modeling_zamba2.py @@ -1337,7 +1337,12 @@ def forward( past_key_values=past_key_values, position_ids=position_ids, ) - position_embeddings = self.rotary_emb(hidden_states, position_ids=position_ids) + + # create position embeddings to be shared across the decoder layers + if self.config.use_mem_rope: + position_embeddings = self.rotary_emb(hidden_states, position_ids=position_ids) + else: + position_embeddings = None all_hidden_states = () if output_hidden_states else None all_self_attns = () if output_attentions else None From 20cae8432201fa4866d0d0403cacfce65baf9b99 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E2=80=9CElla?= Date: Mon, 9 Mar 2026 19:25:22 +0100 Subject: [PATCH 2/2] add fix in modular --- src/transformers/models/zamba2/modular_zamba2.py | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/src/transformers/models/zamba2/modular_zamba2.py b/src/transformers/models/zamba2/modular_zamba2.py index b56ebbd3895e..f11dc88886c5 100644 --- a/src/transformers/models/zamba2/modular_zamba2.py +++ b/src/transformers/models/zamba2/modular_zamba2.py @@ -1058,7 +1058,12 @@ def forward( past_key_values=past_key_values, position_ids=position_ids, ) - position_embeddings = self.rotary_emb(hidden_states, position_ids=position_ids) + + # create position embeddings to be shared across the decoder layers + if self.config.use_mem_rope: + position_embeddings = self.rotary_emb(hidden_states, position_ids=position_ids) + else: + position_embeddings = None all_hidden_states = () if output_hidden_states else None all_self_attns = () if output_attentions else None