diff --git a/src/transformers/models/zamba2/modeling_zamba2.py b/src/transformers/models/zamba2/modeling_zamba2.py
index 99ff7d260756..47ccd4fe240f 100644
--- a/src/transformers/models/zamba2/modeling_zamba2.py
+++ b/src/transformers/models/zamba2/modeling_zamba2.py
@@ -1337,7 +1337,12 @@ def forward(
             past_key_values=past_key_values,
             position_ids=position_ids,
         )
-        position_embeddings = self.rotary_emb(hidden_states, position_ids=position_ids)
+
+        # create position embeddings to be shared across the decoder layers
+        if self.config.use_mem_rope:
+            position_embeddings = self.rotary_emb(hidden_states, position_ids=position_ids)
+        else:
+            position_embeddings = None
 
         all_hidden_states = () if output_hidden_states else None
         all_self_attns = () if output_attentions else None
diff --git a/src/transformers/models/zamba2/modular_zamba2.py b/src/transformers/models/zamba2/modular_zamba2.py
index b56ebbd3895e..f11dc88886c5 100644
--- a/src/transformers/models/zamba2/modular_zamba2.py
+++ b/src/transformers/models/zamba2/modular_zamba2.py
@@ -1058,7 +1058,12 @@ def forward(
             past_key_values=past_key_values,
             position_ids=position_ids,
         )
-        position_embeddings = self.rotary_emb(hidden_states, position_ids=position_ids)
+
+        # create position embeddings to be shared across the decoder layers
+        if self.config.use_mem_rope:
+            position_embeddings = self.rotary_emb(hidden_states, position_ids=position_ids)
+        else:
+            position_embeddings = None
 
         all_hidden_states = () if output_hidden_states else None
         all_self_attns = () if output_attentions else None
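For context, here is a minimal standalone sketch of the pattern this patch applies: compute the rotary `(cos, sin)` tables once in the model's `forward` and share them across all decoder layers, but only when the config flag enabling memory-layer RoPE is set. The sketch assumes (as the patch suggests) that the rotary module is only instantiated when the flag is on, so calling it unconditionally would fail. All names here (`RotaryEmbedding`, `ToyDecoderLayer`, `ToyModel`) are illustrative stand-ins, not the actual Zamba2 modules.

```python
import torch
from torch import nn


class RotaryEmbedding(nn.Module):
    """Toy rotary-embedding stub: returns (cos, sin) tables per position."""

    def __init__(self, head_dim: int):
        super().__init__()
        inv_freq = 1.0 / (10000 ** (torch.arange(0, head_dim, 2).float() / head_dim))
        self.register_buffer("inv_freq", inv_freq)

    def forward(self, hidden_states, position_ids):
        # (batch, seq_len, head_dim/2) angles, duplicated out to head_dim
        angles = position_ids[..., None].float() * self.inv_freq
        emb = torch.cat((angles, angles), dim=-1)
        return emb.cos(), emb.sin()


class ToyDecoderLayer(nn.Module):
    def forward(self, hidden_states, position_embeddings=None):
        # A real layer would rotate query/key states here; the point is only
        # that the layer tolerates position_embeddings being None.
        if position_embeddings is not None:
            cos, sin = position_embeddings
            hidden_states = hidden_states * cos + hidden_states * sin  # placeholder rotation
        return hidden_states


class ToyModel(nn.Module):
    def __init__(self, use_mem_rope: bool, num_layers: int = 2, head_dim: int = 8):
        super().__init__()
        self.use_mem_rope = use_mem_rope
        # Mirror the assumed upstream layout: the rotary module only exists
        # when the flag is on, so an unconditional call would crash.
        self.rotary_emb = RotaryEmbedding(head_dim) if use_mem_rope else None
        self.layers = nn.ModuleList([ToyDecoderLayer() for _ in range(num_layers)])

    def forward(self, hidden_states, position_ids):
        # Create position embeddings once, shared across the decoder layers.
        if self.use_mem_rope:
            position_embeddings = self.rotary_emb(hidden_states, position_ids)
        else:
            position_embeddings = None
        for layer in self.layers:
            hidden_states = layer(hidden_states, position_embeddings=position_embeddings)
        return hidden_states


x, pos = torch.randn(1, 4, 8), torch.arange(4)[None]
ToyModel(use_mem_rope=True)(x, pos)   # tables built once, reused by every layer
ToyModel(use_mem_rope=False)(x, pos)  # rotary module never touched, no crash
```

Computing the `(cos, sin)` pair once at the model level and threading it through the layers avoids recomputing it per layer; gating that computation on the same flag that controls module construction keeps the two in sync, which is what the patch restores for `use_mem_rope=False`.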