diff --git a/src/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py b/src/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py index a98574551922..d0c1016e57d2 100644 --- a/src/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py +++ b/src/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py @@ -1282,7 +1282,7 @@ def forward( (cache_position is not None and cache_position[0] == 0) or (past_key_values is None or past_key_values.get_seq_length() == 0) ) - if (prefill_compiled_stage or prefill_noncompiled_stage) or self.rope_deltas is None: + if (prefill_compiled_stage or prefill_noncompiled_stage) or rope_deltas is None: position_ids, rope_deltas = self.get_rope_index( input_ids, image_grid_thw, @@ -1290,18 +1290,20 @@ def forward( second_per_grid_ts=second_per_grid_ts, attention_mask=attention_mask, ) - self.rope_deltas = rope_deltas else: batch_size, seq_length, _ = inputs_embeds.shape position_ids = torch.arange(seq_length, device=inputs_embeds.device) position_ids = position_ids.view(1, 1, -1).expand(3, batch_size, -1) if cache_position is not None: - delta = (cache_position[0] + self.rope_deltas).to(inputs_embeds.device) + delta = (cache_position[0] + rope_deltas).to(inputs_embeds.device) else: delta = torch.zeros((batch_size, seq_length), device=inputs_embeds.device) delta = delta.repeat_interleave(batch_size // delta.shape[0], dim=1) position_ids = position_ids + delta.to(position_ids.device) + if rope_deltas is not None: + self.rope_deltas = rope_deltas + outputs = self.language_model( input_ids=None, position_ids=position_ids, @@ -1321,7 +1323,7 @@ def forward( past_key_values=outputs.past_key_values, hidden_states=outputs.hidden_states, attentions=outputs.attentions, - rope_deltas=self.rope_deltas, + rope_deltas=rope_deltas, ) return output if return_dict else output.to_tuple() @@ -1479,6 +1481,7 @@ def forward( pixel_values_videos=pixel_values_videos, image_grid_thw=image_grid_thw, video_grid_thw=video_grid_thw, + rope_deltas=rope_deltas, second_per_grid_ts=second_per_grid_ts, position_ids=position_ids, attention_mask=attention_mask, @@ -1561,7 +1564,7 @@ def prepare_inputs_for_generation( second_per_grid_ts=second_per_grid_ts, attention_mask=attention_mask, ) - self.model.rope_deltas = rope_deltas + model_inputs["rope_deltas"] = rope_deltas # then use the prev pre-calculated rope-deltas to get the correct position ids elif "position_ids" in model_inputs: batch_size, seq_length = model_inputs["position_ids"].shape diff --git a/src/transformers/models/qwen2_5_vl/modular_qwen2_5_vl.py b/src/transformers/models/qwen2_5_vl/modular_qwen2_5_vl.py index 2a2ee775b7be..0753aad38826 100644 --- a/src/transformers/models/qwen2_5_vl/modular_qwen2_5_vl.py +++ b/src/transformers/models/qwen2_5_vl/modular_qwen2_5_vl.py @@ -606,7 +606,7 @@ def forward( (cache_position is not None and cache_position[0] == 0) or (past_key_values is None or past_key_values.get_seq_length() == 0) ) - if (prefill_compiled_stage or prefill_noncompiled_stage) or self.rope_deltas is None: + if (prefill_compiled_stage or prefill_noncompiled_stage) or rope_deltas is None: position_ids, rope_deltas = self.get_rope_index( input_ids, image_grid_thw, @@ -614,18 +614,20 @@ def forward( second_per_grid_ts=second_per_grid_ts, attention_mask=attention_mask, ) - self.rope_deltas = rope_deltas else: batch_size, seq_length, _ = inputs_embeds.shape position_ids = torch.arange(seq_length, device=inputs_embeds.device) position_ids = position_ids.view(1, 1, -1).expand(3, batch_size, -1) if cache_position is not None: - delta = (cache_position[0] + self.rope_deltas).to(inputs_embeds.device) + delta = (cache_position[0] + rope_deltas).to(inputs_embeds.device) else: delta = torch.zeros((batch_size, seq_length), device=inputs_embeds.device) delta = delta.repeat_interleave(batch_size // delta.shape[0], dim=1) position_ids = position_ids + delta.to(position_ids.device) + if rope_deltas is not None: + self.rope_deltas = rope_deltas + outputs = self.language_model( input_ids=None, position_ids=position_ids, @@ -645,7 +647,7 @@ def forward( past_key_values=outputs.past_key_values, hidden_states=outputs.hidden_states, attentions=outputs.attentions, - rope_deltas=self.rope_deltas, + rope_deltas=rope_deltas, ) return output if return_dict else output.to_tuple() @@ -735,6 +737,7 @@ def forward( pixel_values_videos=pixel_values_videos, image_grid_thw=image_grid_thw, video_grid_thw=video_grid_thw, + rope_deltas=rope_deltas, second_per_grid_ts=second_per_grid_ts, position_ids=position_ids, attention_mask=attention_mask, @@ -817,7 +820,7 @@ def prepare_inputs_for_generation( second_per_grid_ts=second_per_grid_ts, attention_mask=attention_mask, ) - self.model.rope_deltas = rope_deltas + model_inputs["rope_deltas"] = rope_deltas # then use the prev pre-calculated rope-deltas to get the correct position ids elif "position_ids" in model_inputs: batch_size, seq_length = model_inputs["position_ids"].shape