From de74e15b04153e4a58b5b4a10730409c5a7a8e98 Mon Sep 17 00:00:00 2001 From: zhangyue66 Date: Thu, 13 Nov 2025 16:14:09 +0800 Subject: [PATCH 01/19] init --- src/transformers/models/__init__.py | 1 + .../models/auto/configuration_auto.py | 2 + .../models/auto/image_processing_auto.py | 1 + src/transformers/models/auto/modeling_auto.py | 1 + .../models/auto/processing_auto.py | 1 + .../models/auto/tokenization_auto.py | 1 + .../models/paddleocr_vl/__init__.py | 28 + .../configuration_paddleocr_vl.py | 205 ++ .../image_processing_paddleocr_vl.py | 499 +++++ .../paddleocr_vl/modeling_paddleocr_vl.py | 1659 ++++++++++++++ .../paddleocr_vl/modular_paddleocr_vl.py | 1933 +++++++++++++++++ .../paddleocr_vl/processing_paddleocr_vl.py | 152 ++ 12 files changed, 4483 insertions(+) create mode 100644 src/transformers/models/paddleocr_vl/__init__.py create mode 100644 src/transformers/models/paddleocr_vl/configuration_paddleocr_vl.py create mode 100644 src/transformers/models/paddleocr_vl/image_processing_paddleocr_vl.py create mode 100644 src/transformers/models/paddleocr_vl/modeling_paddleocr_vl.py create mode 100644 src/transformers/models/paddleocr_vl/modular_paddleocr_vl.py create mode 100644 src/transformers/models/paddleocr_vl/processing_paddleocr_vl.py diff --git a/src/transformers/models/__init__.py b/src/transformers/models/__init__.py index 3534ce6719d0..fcb759682ea2 100644 --- a/src/transformers/models/__init__.py +++ b/src/transformers/models/__init__.py @@ -257,6 +257,7 @@ from .ovis2 import * from .owlv2 import * from .owlvit import * + from .paddleocr_vl import * from .paligemma import * from .parakeet import * from .patchtsmixer import * diff --git a/src/transformers/models/auto/configuration_auto.py b/src/transformers/models/auto/configuration_auto.py index 9a3b2ec5ecc2..b0f489b02af8 100644 --- a/src/transformers/models/auto/configuration_auto.py +++ b/src/transformers/models/auto/configuration_auto.py @@ -302,6 +302,7 @@ ("ovis2", "Ovis2Config"), ("owlv2", "Owlv2Config"), ("owlvit", "OwlViTConfig"), + ("paddleocr_vl", "PaddleOCRVLConfig"), ("paligemma", "PaliGemmaConfig"), ("parakeet_ctc", "ParakeetCTCConfig"), ("parakeet_encoder", "ParakeetEncoderConfig"), @@ -761,6 +762,7 @@ ("ovis2", "Ovis2"), ("owlv2", "OWLv2"), ("owlvit", "OWL-ViT"), + ("paddleocr_vl", "PaddleOCRVL"), ("paligemma", "PaliGemma"), ("parakeet", "Parakeet"), ("parakeet_ctc", "Parakeet"), diff --git a/src/transformers/models/auto/image_processing_auto.py b/src/transformers/models/auto/image_processing_auto.py index 5ed14d4f3b1d..483c41957d6a 100644 --- a/src/transformers/models/auto/image_processing_auto.py +++ b/src/transformers/models/auto/image_processing_auto.py @@ -154,6 +154,7 @@ ("ovis2", ("Ovis2ImageProcessor", "Ovis2ImageProcessorFast")), ("owlv2", ("Owlv2ImageProcessor", "Owlv2ImageProcessorFast")), ("owlvit", ("OwlViTImageProcessor", "OwlViTImageProcessorFast")), + ("paddleocr_vl", ("PaddleOCRVLImageProcessor", "PaddleOCRVLImageProcessorFast")), ("paligemma", ("SiglipImageProcessor", "SiglipImageProcessorFast")), ("perceiver", ("PerceiverImageProcessor", "PerceiverImageProcessorFast")), ("perception_lm", (None, "PerceptionLMImageProcessorFast")), diff --git a/src/transformers/models/auto/modeling_auto.py b/src/transformers/models/auto/modeling_auto.py index 257fb95fdea7..9d27faf8945c 100644 --- a/src/transformers/models/auto/modeling_auto.py +++ b/src/transformers/models/auto/modeling_auto.py @@ -1052,6 +1052,7 @@ class _BaseModelWithGenerate(PreTrainedModel, GenerationMixin): ("mistral3", "Mistral3ForConditionalGeneration"), ("mllama", "MllamaForConditionalGeneration"), ("ovis2", "Ovis2ForConditionalGeneration"), + ("paddleocr_vl", "PaddleOCRVLForConditionalGeneration"), ("paligemma", "PaliGemmaForConditionalGeneration"), ("perception_lm", "PerceptionLMForConditionalGeneration"), ("pix2struct", "Pix2StructForConditionalGeneration"), diff --git a/src/transformers/models/auto/processing_auto.py b/src/transformers/models/auto/processing_auto.py index 691e4afc96e7..9c89881ed459 100644 --- a/src/transformers/models/auto/processing_auto.py +++ b/src/transformers/models/auto/processing_auto.py @@ -113,6 +113,7 @@ ("ovis2", "Ovis2Processor"), ("owlv2", "Owlv2Processor"), ("owlvit", "OwlViTProcessor"), + ("paddleocr_vl", "PaddleOCRVLProcessor"), ("paligemma", "PaliGemmaProcessor"), ("perception_lm", "PerceptionLMProcessor"), ("phi4_multimodal", "Phi4MultimodalProcessor"), diff --git a/src/transformers/models/auto/tokenization_auto.py b/src/transformers/models/auto/tokenization_auto.py index 65a3885a1c46..fea362224a81 100644 --- a/src/transformers/models/auto/tokenization_auto.py +++ b/src/transformers/models/auto/tokenization_auto.py @@ -522,6 +522,7 @@ ("ovis2", (None, "Qwen2TokenizerFast" if is_tokenizers_available() else None)), ("owlv2", ("CLIPTokenizer", "CLIPTokenizerFast" if is_tokenizers_available() else None)), ("owlvit", ("CLIPTokenizer", "CLIPTokenizerFast" if is_tokenizers_available() else None)), + ("paddleocr_vl", (None, "LlamaTokenizerFast" if is_tokenizers_available() else None)), ("paligemma", ("LlamaTokenizer", "LlamaTokenizerFast" if is_tokenizers_available() else None)), ("parakeet", (None, "ParakeetTokenizerFast" if is_tokenizers_available() else None)), ( diff --git a/src/transformers/models/paddleocr_vl/__init__.py b/src/transformers/models/paddleocr_vl/__init__.py new file mode 100644 index 000000000000..6deb0275de89 --- /dev/null +++ b/src/transformers/models/paddleocr_vl/__init__.py @@ -0,0 +1,28 @@ +# Copyright 2025 The PaddlePaddle Team and The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from typing import TYPE_CHECKING + +from ...utils import _LazyModule +from ...utils.import_utils import define_import_structure + + +if TYPE_CHECKING: + from .configuration_paddleocr_vl import * + from .modeling_paddleocr_vl import * + from .processing_paddleocr_vl import * +else: + import sys + + _file = globals()["__file__"] + sys.modules[__name__] = _LazyModule(__name__, _file, define_import_structure(_file), module_spec=__spec__) diff --git a/src/transformers/models/paddleocr_vl/configuration_paddleocr_vl.py b/src/transformers/models/paddleocr_vl/configuration_paddleocr_vl.py new file mode 100644 index 000000000000..dc72f06943c3 --- /dev/null +++ b/src/transformers/models/paddleocr_vl/configuration_paddleocr_vl.py @@ -0,0 +1,205 @@ +# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨 +# This file was automatically generated from src/transformers/models/paddleocr_vl/modular_paddleocr_vl.py. +# Do NOT edit this file manually as any edits will be overwritten by the generation of +# the file from the modular. If any change should be done, please apply the change to the +# modular_paddleocr_vl.py file directly. One of our CI enforces this. +# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨 +# Copyright 2025 The PaddlePaddle Team and The HuggingFace Inc. team. All rights reserved. +# +# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX +# and OPT implementations in this library. It has been modified from its +# original forms to accommodate minor architectural differences compared +# to GPT-NeoX and OPT used by the Meta AI team that trained the model. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from ...configuration_utils import PretrainedConfig +from ...modeling_rope_utils import rope_config_validation + + +class PaddleOCRVisionConfig(PretrainedConfig): + model_type = "paddleocr_vl" + base_config_key = "vision_config" + + def __init__( + self, + hidden_size=768, + intermediate_size=3072, + num_hidden_layers=12, + num_attention_heads=12, + num_channels=3, + image_size=224, + patch_size=14, + hidden_act="gelu_pytorch_tanh", + layer_norm_eps=1e-6, + attention_dropout=0.0, + spatial_merge_size=2, + temporal_patch_size=2, + tokens_per_second=2, + **kwargs, + ): + super().__init__(**kwargs) + + self.hidden_size = hidden_size + self.intermediate_size = intermediate_size + self.num_hidden_layers = num_hidden_layers + self.num_attention_heads = num_attention_heads + self.num_channels = num_channels + self.patch_size = patch_size + self.image_size = image_size + self.attention_dropout = attention_dropout + self.layer_norm_eps = layer_norm_eps + self.hidden_act = hidden_act + self.spatial_merge_size = spatial_merge_size + self.temporal_patch_size = temporal_patch_size + self.tokens_per_second = tokens_per_second + + +class PaddleOCRVLConfig(PretrainedConfig): + """ + Configuration class. + + This class stores the configuration of an Ernie model, defining the model architecture. + It inherits from PretrainedConfig and can be used to control model outputs. + """ + + model_type = "paddleocr_vl" + keys_to_ignore_at_inference = ["past_key_values"] + sub_configs = {"vision_config": PaddleOCRVisionConfig} + + # Default tensor parallel plan for base model `Qwen3` + base_model_tp_plan = { + "layers.*.self_attn.q_proj": "colwise", + "layers.*.self_attn.k_proj": "colwise", + "layers.*.self_attn.v_proj": "colwise", + "layers.*.self_attn.o_proj": "rowwise", + "layers.*.mlp.gate_proj": "colwise", + "layers.*.mlp.up_proj": "colwise", + "layers.*.mlp.down_proj": "rowwise", + } + base_model_pp_plan = { + "embed_tokens": (["input_ids"], ["inputs_embeds"]), + "layers": (["hidden_states", "attention_mask"], ["hidden_states"]), + "norm": (["hidden_states"], ["hidden_states"]), + } + + def __init__( + self, + vocab_size=32000, + hidden_size=768, + intermediate_size=11008, + max_position_embeddings=32768, + num_hidden_layers=2, + num_attention_heads=2, + image_token_id=101304, + video_token_id=101305, + vision_start_token_id=101306, + rms_norm_eps=1e-6, + use_cache=False, + use_flash_attention=False, + pad_token_id=0, + bos_token_id=1, + eos_token_id=2, + head_dim=128, + hidden_act="silu", + use_bias=False, + rope_theta=10000, + weight_share_add_bias=True, + ignored_index=-100, + attention_probs_dropout_prob=0.0, + hidden_dropout_prob=0.0, + compression_ratio: float = 1.0, + num_key_value_heads=None, + max_sequence_length=None, + tie_word_embeddings=False, + vision_config=None, + rope_scaling=None, + **kwargs, + ): + """ + Initialize configuration with default or specified parameters. + + Args: + vocab_size (int): Size of the vocabulary (number of unique tokens) + hidden_size (int): Dimensionality of the encoder layers and the pooler layer + intermediate_size (int): Dimensionality of the "intermediate" (feed-forward) layer + max_position_embeddings (int): Maximum sequence length the model can handle + num_hidden_layers (int): Number of hidden layers in the Transformer encoder + num_attention_heads (int): Number of attention heads for each attention layer + rms_norm_eps (float): The epsilon used by the RMS normalization layers + use_cache (bool): Whether to use caching for faster generation (decoding) + use_flash_attention (bool): Whether to use FlashAttention for optimized attention computation + pad_token_id (int): Token ID used for padding sequences + bos_token_id (int): Token ID used for beginning-of-sequence + eos_token_id (int): Token ID used for end-of-sequence + use_bias (bool): Whether to use bias terms in linear layers + rope_theta (float): The base period of the RoPE embeddings + weight_share_add_bias (bool): Whether to share bias weights in certain layers + ignored_index (int): Target value that is ignored during loss computation + attention_probs_dropout_prob (float): Dropout probability for attention weights + hidden_dropout_prob (float): Dropout probability for hidden layers + compression_ratio (float): Ratio for KV cache compression (1.0 = no compression) + num_key_value_heads (int): Number of key/value heads (for Grouped Query Attention) + max_sequence_length (int): Maximum sequence length for positional embeddings + **kwargs: Additional keyword arguments passed to parent class + """ + + # Set default for tied embeddings if not specified. + super().__init__( + pad_token_id=pad_token_id, + bos_token_id=bos_token_id, + eos_token_id=eos_token_id, + **kwargs, + ) + if isinstance(vision_config, dict): + self.vision_config = self.sub_configs["vision_config"](**vision_config) + elif vision_config is None: + self.vision_config = self.sub_configs["vision_config"]() + self.vocab_size = vocab_size + self.hidden_size = hidden_size + self.intermediate_size = intermediate_size + self.max_position_embeddings = max_position_embeddings + self.num_hidden_layers = num_hidden_layers + self.num_attention_heads = num_attention_heads + self.rms_norm_eps = rms_norm_eps + self.use_cache = use_cache + self.use_flash_attention = use_flash_attention + self.pad_token_id = pad_token_id + self.bos_token_id = bos_token_id + self.eos_token_id = eos_token_id + self.image_token_id = image_token_id + self.video_token_id = video_token_id + self.vision_start_token_id = vision_start_token_id + self.head_dim = head_dim + self.hidden_act = hidden_act + self.sliding_window = None + self.hidden_size = hidden_size + self.use_bias = use_bias + self.weight_share_add_bias = weight_share_add_bias + self.rope_theta = rope_theta + self.ignored_index = ignored_index + self.attention_probs_dropout_prob = attention_probs_dropout_prob + self.hidden_dropout_prob = hidden_dropout_prob + self.compression_ratio = compression_ratio + self.num_key_value_heads = num_key_value_heads + self.max_sequence_length = max_sequence_length + self.rope_scaling = rope_scaling + if self.rope_scaling is not None and "type" in self.rope_scaling: + if self.rope_scaling["type"] == "mrope": + self.rope_scaling["type"] = "default" + self.rope_scaling["rope_type"] = self.rope_scaling["type"] + rope_config_validation(self, ignore_keys={"mrope_section"}) + super().__init__(tie_word_embeddings=tie_word_embeddings, **kwargs) + + +__all__ = ["PaddleOCRVLConfig", "PaddleOCRVisionConfig"] diff --git a/src/transformers/models/paddleocr_vl/image_processing_paddleocr_vl.py b/src/transformers/models/paddleocr_vl/image_processing_paddleocr_vl.py new file mode 100644 index 000000000000..d2b7dad85fa1 --- /dev/null +++ b/src/transformers/models/paddleocr_vl/image_processing_paddleocr_vl.py @@ -0,0 +1,499 @@ +# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨 +# This file was automatically generated from src/transformers/models/paddleocr_vl/modular_paddleocr_vl.py. +# Do NOT edit this file manually as any edits will be overwritten by the generation of +# the file from the modular. If any change should be done, please apply the change to the +# modular_paddleocr_vl.py file directly. One of our CI enforces this. +# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨 +# Copyright 2025 The PaddlePaddle Team and The HuggingFace Inc. team. All rights reserved. +# +# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX +# and OPT implementations in this library. It has been modified from its +# original forms to accommodate minor architectural differences compared +# to GPT-NeoX and OPT used by the Meta AI team that trained the model. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import math +from typing import Optional, Union + +import numpy as np + +from ...image_processing_utils import BaseImageProcessor, BatchFeature +from ...image_transforms import convert_to_rgb, resize, to_channel_dimension_format +from ...image_utils import ( + OPENAI_CLIP_MEAN, + OPENAI_CLIP_STD, + ChannelDimension, + ImageInput, + PILImageResampling, + get_image_size, + infer_channel_dimension_format, + is_scaled_image, + make_flat_list_of_images, + make_list_of_images, + to_numpy_array, + valid_images, + validate_preprocess_arguments, +) +from ...processing_utils import ImagesKwargs +from ...utils import TensorType, logging + + +logger = logging.get_logger(__name__) + + +class PaddleOCRVLImageProcessorKwargs(ImagesKwargs, total=False): + r""" + min_pixels (`int`, *optional*, defaults to `56 * 56`): + The min pixels of the image to resize the image. + max_pixels (`int`, *optional*, defaults to `28 * 28 * 1280`): + The max pixels of the image to resize the image. + patch_size (`int`, *optional*, defaults to 14): + The spatial patch size of the vision encoder. + temporal_patch_size (`int`, *optional*, defaults to 2): + The temporal patch size of the vision encoder. + merge_size (`int`, *optional*, defaults to 2): + The merge size of the vision encoder to llm encoder. + """ + + min_pixels: int + max_pixels: int + patch_size: int + temporal_patch_size: int + merge_size: int + + +def smart_resize( + height: int, + width: int, + factor: int = 28, + min_pixels: int = 28 * 28 * 130, + max_pixels: int = 28 * 28 * 1280, +): + if height < factor: + width = round((width * factor) / height) + height = factor + + if width < factor: + height = round((height * factor) / width) + width = factor + + if max(height, width) / min(height, width) > 200: + raise ValueError( + f"absolute aspect ratio must be smaller than 200, got {max(height, width) / min(height, width)}" + ) + h_bar = round(height / factor) * factor + w_bar = round(width / factor) * factor + if h_bar * w_bar > max_pixels: + beta = math.sqrt((height * width) / max_pixels) + h_bar = math.floor(height / beta / factor) * factor + w_bar = math.floor(width / beta / factor) * factor + elif h_bar * w_bar < min_pixels: + beta = math.sqrt(min_pixels / (height * width)) + h_bar = math.ceil(height * beta / factor) * factor + w_bar = math.ceil(width * beta / factor) * factor + return h_bar, w_bar + + +class PaddleOCRVLImageProcessor(BaseImageProcessor): + r""" + Constructs a PaddleOCRVL image processor that dynamically resizes images based on the original images. + + Args: + do_resize (`bool`, *optional*, defaults to `True`): + Whether to resize the image's (height, width) dimensions. + size (`Dict[str, int]`, *optional*, defaults to `None`): + Size of the image after resizing. `shortest_edge` and `longest_edge` keys must be present. + resample (`PILImageResampling`, *optional*, defaults to `Resampling.BICUBIC`): + Resampling filter to use when resizing the image. + do_rescale (`bool`, *optional*, defaults to `True`): + Whether to rescale the image by the specified scale `rescale_factor`. + rescale_factor (`int` or `float`, *optional*, defaults to `1/255`): + Scale factor to use if rescaling the image. + do_normalize (`bool`, *optional*, defaults to `True`): + Whether to normalize the image. + image_mean (`float` or `List[float]`, *optional*, defaults to `[0.48145466, 0.4578275, 0.40821073]`): + Mean to use if normalizing the image. This is a float or list of floats for each channel in the image. + image_std (`float` or `List[float]`, *optional*, defaults to `[0.26862954, 0.26130258, 0.27577711]`): + Standard deviation to use if normalizing the image. This is a float or list of floats for each channel in the image. + do_convert_rgb (`bool`, *optional*, defaults to `True`): + Whether to convert the image to RGB. + min_pixels (`int`, *optional*, defaults to `28 * 28 * 130`): + The min pixels of the image to resize the image. + max_pixels (`int`, *optional*, defaults to `28 * 28 * 1670`): + The max pixels of the image to resize the image. + patch_size (`int`, *optional*, defaults to 14): + The spacial patch size of the vision encoder. + temporal_patch_size (`int`, *optional*, defaults to 2): + The temporal patch size of the vision encoder. + merge_size (`int`, *optional*, defaults to 2): + The merge size of the vision encoder to llm encoder. + """ + + model_input_names = [ + "pixel_values", + "image_grid_thw", + ] + valid_kwargs = PaddleOCRVLImageProcessorKwargs + + def __init__( + self, + do_resize: bool = True, + size: Optional[dict[str, int]] = None, + resample: PILImageResampling = PILImageResampling.BICUBIC, + do_rescale: bool = True, + rescale_factor: Union[int, float] = 1 / 255, + do_normalize: bool = True, + image_mean: Optional[Union[float, list[float]]] = None, + image_std: Optional[Union[float, list[float]]] = None, + do_convert_rgb: bool = True, + min_pixels: int = 28 * 28 * 130, + max_pixels: int = 28 * 28 * 1280, + patch_size: int = 14, + temporal_patch_size: int = 1, + merge_size: int = 2, + **kwargs, + ) -> None: + super().__init__(**kwargs) + if size is not None and ("shortest_edge" not in size or "longest_edge" not in size): + raise ValueError("size must contain 'shortest_edge' and 'longest_edge' keys.") + else: + size = {"shortest_edge": 56 * 56, "longest_edge": 28 * 28 * 1280} + # backward compatibility: override size with min_pixels and max_pixels if they are provided + if min_pixels is not None: + size["shortest_edge"] = min_pixels + if max_pixels is not None: + size["longest_edge"] = max_pixels + self.min_pixels = min_pixels + self.max_pixels = max_pixels + self.size = size + self.do_resize = do_resize + self.resample = resample + self.do_rescale = do_rescale + self.rescale_factor = rescale_factor + self.do_normalize = do_normalize + self.image_mean = image_mean if image_mean is not None else OPENAI_CLIP_MEAN + self.image_std = image_std if image_std is not None else OPENAI_CLIP_STD + self.patch_size = patch_size + self.temporal_patch_size = temporal_patch_size + self.merge_size = merge_size + self.do_convert_rgb = do_convert_rgb + + def _preprocess( + self, + images: ImageInput, + do_resize: Optional[bool] = None, + size: Optional[dict[str, int]] = None, + resample: PILImageResampling = None, + do_rescale: Optional[bool] = None, + rescale_factor: Optional[float] = None, + do_normalize: Optional[bool] = None, + image_mean: Optional[Union[float, list[float]]] = None, + image_std: Optional[Union[float, list[float]]] = None, + patch_size: Optional[int] = None, + temporal_patch_size: Optional[int] = None, + merge_size: Optional[int] = None, + do_convert_rgb: Optional[bool] = None, + data_format: Optional[ChannelDimension] = ChannelDimension.FIRST, + input_data_format: Optional[Union[str, ChannelDimension]] = None, + ): + """ + Preprocess an image or batch of images. Copy of the `preprocess` method from `CLIPImageProcessor`. + Args: + images (`ImageInput`): + Image or batch of images to preprocess. Expects pixel values ranging from 0 to 255. If pixel values range from 0 to 1, set `do_rescale=False`. + do_resize (`bool`, *optional*, defaults to `self.do_resize`): + Whether to resize the image. + size (`Dict[str, int]`, *optional*, defaults to `self.size`): + Size of the image after resizing. `shortest_edge` and `longest_edge` keys must be present. + resample (`PILImageResampling`, *optional*, defaults to `self.resample`): + Resampling filter to use if resizing the image. This can be one of the `PILImageResampling` enums. + do_rescale (`bool`, *optional*, defaults to `self.do_rescale`): + Whether to rescale the image. + rescale_factor (`float`, *optional*, defaults to `self.rescale_factor`): + Scale factor to use if rescaling the image. + do_normalize (`bool`, *optional*, defaults to `self.do_normalize`): + Whether to normalize the image. + image_mean (`float` or `List[float]`, *optional*, defaults to `self.image_mean`): + Mean to use if normalizing the image. Can be a float or a list of floats corresponding to the number of channels in the image. + image_std (`float` or `List[float]`, *optional*, defaults to `self.image_std`): + Standard deviation to use if normalizing the image. Can be a float or a list of floats corresponding to the number of channels in the image. + patch_size (`int`, *optional*, defaults to `self.patch_size`): + The spatial patch size of the vision encoder. + temporal_patch_size (`int`, *optional*, defaults to `self.temporal_patch_size`): + The temporal patch size of the vision encoder. + merge_size (`int`, *optional*, defaults to `self.merge_size`): + The merge size of the vision encoder to llm encoder. + do_convert_rgb (`bool`, *optional*, defaults to `self.do_convert_rgb`): + Whether to convert the image to RGB. + data_format (`ChannelDimension`, *optional*, defaults to `ChannelDimension.FIRST`): + The channel dimension format for the output image. Can be one of: + - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format. + - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format. + - Unset: Use the channel dimension format of the input image. + input_data_format (`ChannelDimension` or `str`, *optional*): + The channel dimension format for the input image. Can be one of: + - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format. + - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format. + - `"none"` or `ChannelDimension.NONE`: image in (height, width) format. - `"none"` or `ChannelDimension.NONE`: image in (height, width) format. + """ + images = make_list_of_images(images) + + if do_convert_rgb: + images = [convert_to_rgb(image) for image in images] + + # All transformations expect numpy arrays. + images = [to_numpy_array(image) for image in images] + + if is_scaled_image(images[0]) and do_rescale: + logger.warning_once( + "It looks like you are trying to rescale already rescaled images. If the input" + " images have pixel values between 0 and 1, set `do_rescale=False` to avoid rescaling them again." + ) + if input_data_format is None: + # We assume that all images have the same channel dimension format. + input_data_format = infer_channel_dimension_format(images[0]) + + height, width = get_image_size(images[0], channel_dim=input_data_format) + resized_height, resized_width = height, width + processed_images = [] + + for image in images: + if do_resize: + resized_height, resized_width = smart_resize( + height, + width, + factor=self.patch_size * self.merge_size, + min_pixels=self.min_pixels, + max_pixels=self.max_pixels, + ) + image = resize( + image, + size=(resized_height, resized_width), + resample=resample, + input_data_format=input_data_format, + ) + + if do_rescale: + image = self.rescale(image, scale=rescale_factor, input_data_format=input_data_format) + + if do_normalize: + image = self.normalize( + image=image, + mean=image_mean, + std=image_std, + input_data_format=input_data_format, + ) + image = to_channel_dimension_format(image, data_format, input_channel_dim=input_data_format) + processed_images.append(image) + + patches = np.array(processed_images) + if data_format == ChannelDimension.LAST: + patches = patches.transpose(0, 3, 1, 2) + if patches.shape[0] == 1: + patches = np.tile(patches, (self.temporal_patch_size, 1, 1, 1)) + + channel = patches.shape[1] + grid_t = patches.shape[0] // self.temporal_patch_size + grid_h, grid_w = ( + resized_height // self.patch_size, + resized_width // self.patch_size, + ) + patches = patches.reshape( + grid_t, + self.temporal_patch_size, + channel, + grid_h, + self.patch_size, + grid_w, + self.patch_size, + ) + patches = patches.transpose(0, 3, 5, 2, 1, 4, 6) + assert self.temporal_patch_size == 1 + flatten_patches = patches.reshape(grid_t * grid_h * grid_w, channel, self.patch_size, self.patch_size) + return flatten_patches, (grid_t, grid_h, grid_w) + + def preprocess( + self, + images: ImageInput, + do_resize: Optional[bool] = None, + size: Optional[dict[str, int]] = None, + min_pixels: Optional[int] = None, + max_pixels: Optional[int] = None, + resample: Optional[PILImageResampling] = None, + do_rescale: Optional[bool] = None, + rescale_factor: Optional[float] = None, + do_normalize: Optional[bool] = None, + image_mean: Optional[Union[float, list[float]]] = None, + image_std: Optional[Union[float, list[float]]] = None, + patch_size: Optional[int] = None, + temporal_patch_size: Optional[int] = None, + merge_size: Optional[int] = None, + do_convert_rgb: Optional[bool] = None, + return_tensors: Optional[Union[str, TensorType]] = None, + data_format: Optional[ChannelDimension] = ChannelDimension.FIRST, + input_data_format: Optional[Union[str, ChannelDimension]] = None, + ): + """ + Args: + images (`ImageInput`): + Image to preprocess. Expects a single or batch of images with pixel values ranging from 0 to 255. If + passing in images with pixel values between 0 and 1, set `do_rescale=False`. + do_resize (`bool`, *optional*, defaults to `self.do_resize`): + Whether to resize the image. + size (`dict[str, int]`, *optional*, defaults to `self.size`): + Size of the image after resizing. Shortest edge of the image is resized to size["shortest_edge"], with + the longest edge resized to keep the input aspect ratio. + resample (`int`, *optional*, defaults to `self.resample`): + Resampling filter to use if resizing the image. This can be one of the enum `PILImageResampling`. Only + has an effect if `do_resize` is set to `True`. + do_rescale (`bool`, *optional*, defaults to `self.do_rescale`): + Whether to rescale the image. + rescale_factor (`float`, *optional*, defaults to `self.rescale_factor`): + Rescale factor to rescale the image by if `do_rescale` is set to `True`. + do_normalize (`bool`, *optional*, defaults to `self.do_normalize`): + Whether to normalize the image. + image_mean (`float` or `list[float]`, *optional*, defaults to `self.image_mean`): + Image mean to use for normalization. Only has an effect if `do_normalize` is set to `True`. + image_std (`float` or `list[float]`, *optional*, defaults to `self.image_std`): + Image standard deviation to use for normalization. Only has an effect if `do_normalize` is set to + `True`. + min_pixels (`int`, *optional*, defaults to `self.min_pixels`): + The min pixels of the image to resize the image. + max_pixels (`int`, *optional*, defaults to `self.max_pixels`): + The max pixels of the image to resize the image. + patch_size (`int`, *optional*, defaults to `self.patch_size`): + The spatial patch size of the vision encoder. + temporal_patch_size (`int`, *optional*, defaults to `self.temporal_patch_size`): + The temporal patch size of the vision encoder. + merge_size (`int`, *optional*, defaults to `self.merge_size`): + The merge size of the vision encoder to llm encoder. + do_convert_rgb (`bool`, *optional*, defaults to `self.do_convert_rgb`): + Whether to convert the image to RGB. + return_tensors (`str` or `TensorType`, *optional*): + The type of tensors to return. Can be one of: + - Unset: Return a list of `np.ndarray`. + - `TensorType.PYTORCH` or `'pt'`: Return a batch of type `torch.Tensor`. + - `TensorType.NUMPY` or `'np'`: Return a batch of type `np.ndarray`. + data_format (`ChannelDimension` or `str`, *optional*, defaults to `ChannelDimension.FIRST`): + The channel dimension format for the output image. Can be one of: + - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format. + - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format. + - Unset: Use the channel dimension format of the input image. + input_data_format (`ChannelDimension` or `str`, *optional*): + The channel dimension format for the input image. If unset, the channel dimension format is inferred + from the input image. Can be one of: + - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format. + - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format. + - `"none"` or `ChannelDimension.NONE`: image in (height, width) format. + + """ + min_pixels = min_pixels if min_pixels is not None else self.min_pixels + max_pixels = max_pixels if max_pixels is not None else self.max_pixels + + if size is not None: + if "shortest_edge" not in size or "longest_edge" not in size: + raise ValueError("size must contain 'shortest_edge' and 'longest_edge' keys.") + min_pixels = size["shortest_edge"] + elif min_pixels is not None and max_pixels is not None: + # backward compatibility: override size with min_pixels and max_pixels if they are provided + size = {"shortest_edge": min_pixels, "longest_edge": max_pixels} + else: + size = {**self.size} + + do_resize = do_resize if do_resize is not None else self.do_resize + + resample = resample if resample is not None else self.resample + do_rescale = do_rescale if do_rescale is not None else self.do_rescale + rescale_factor = rescale_factor if rescale_factor is not None else self.rescale_factor + do_normalize = do_normalize if do_normalize is not None else self.do_normalize + image_mean = image_mean if image_mean is not None else self.image_mean + image_std = image_std if image_std is not None else self.image_std + patch_size = patch_size if patch_size is not None else self.patch_size + temporal_patch_size = temporal_patch_size if temporal_patch_size is not None else self.temporal_patch_size + merge_size = merge_size if merge_size is not None else self.merge_size + do_convert_rgb = do_convert_rgb if do_convert_rgb is not None else self.do_convert_rgb + + if images is not None: + images = self.fetch_images(images) + images = make_flat_list_of_images(images) + + if images is not None and not valid_images(images): + raise ValueError("Invalid image type. Must be of type PIL.Image.Image, numpy.ndarray, or torch.Tensor") + + validate_preprocess_arguments( + rescale_factor=rescale_factor, + do_normalize=do_normalize, + image_mean=image_mean, + image_std=image_std, + do_resize=do_resize, + size=size, + resample=resample, + ) + + data = {} + pixel_values, vision_grid_thws = [], [] + for image in images: + patches, image_grid_thw = self._preprocess( + image, + do_resize=do_resize, + size=size, + resample=resample, + do_rescale=do_rescale, + rescale_factor=rescale_factor, + do_normalize=do_normalize, + image_mean=image_mean, + image_std=image_std, + patch_size=patch_size, + temporal_patch_size=temporal_patch_size, + merge_size=merge_size, + data_format=data_format, + do_convert_rgb=do_convert_rgb, + input_data_format=input_data_format, + ) + pixel_values.extend(patches) + vision_grid_thws.append(image_grid_thw) + pixel_values = np.array(pixel_values) + vision_grid_thws = np.array(vision_grid_thws) + data.update({"pixel_values": pixel_values, "image_grid_thw": vision_grid_thws}) + + return BatchFeature(data=data, tensor_type=return_tensors) + + def get_number_of_image_patches(self, height: int, width: int, images_kwargs=None): + """ + A utility that returns number of image patches for a given image size. + + Args: + height (`int`): + Height of the input image. + width (`int`): + Width of the input image. + images_kwargs (`dict`, *optional*) + Any kwargs to override defaults of the image processor. + Returns: + `int`: Number of image patches per image. + """ + min_pixels = images_kwargs["min_pixels"] if "min_pixels" in images_kwargs else self.size["shortest_edge"] + max_pixels = images_kwargs["max_pixels"] if "max_pixels" in images_kwargs else self.size["longest_edge"] + patch_size = images_kwargs.get("patch_size", self.patch_size) + merge_size = images_kwargs.get("merge_size", self.merge_size) + + factor = patch_size * merge_size + resized_height, resized_width = smart_resize( + height, width, factor, min_pixels=min_pixels, max_pixels=max_pixels + ) + grid_h, grid_w = resized_height // patch_size, resized_width // patch_size + return grid_h * grid_w + + +__all__ = ["PaddleOCRVLImageProcessor"] diff --git a/src/transformers/models/paddleocr_vl/modeling_paddleocr_vl.py b/src/transformers/models/paddleocr_vl/modeling_paddleocr_vl.py new file mode 100644 index 000000000000..1585295c0c33 --- /dev/null +++ b/src/transformers/models/paddleocr_vl/modeling_paddleocr_vl.py @@ -0,0 +1,1659 @@ +# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨 +# This file was automatically generated from src/transformers/models/paddleocr_vl/modular_paddleocr_vl.py. +# Do NOT edit this file manually as any edits will be overwritten by the generation of +# the file from the modular. If any change should be done, please apply the change to the +# modular_paddleocr_vl.py file directly. One of our CI enforces this. +# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨 +# Copyright 2025 The PaddlePaddle Team and The HuggingFace Inc. team. All rights reserved. +# +# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX +# and OPT implementations in this library. It has been modified from its +# original forms to accommodate minor architectural differences compared +# to GPT-NeoX and OPT used by the Meta AI team that trained the model. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import math +from collections.abc import Callable +from dataclasses import dataclass +from typing import Any, Optional, Union + +import numpy as np +import torch +from einops import rearrange +from torch import nn +from torch.nn import CrossEntropyLoss +from torch.nn.init import _calculate_fan_in_and_fan_out + +from ...activations import ACT2FN, GELUActivation +from ...cache_utils import Cache, DynamicCache +from ...generation import GenerationMixin +from ...integrations import use_kernel_forward_from_hub +from ...masking_utils import create_causal_mask +from ...modeling_layers import GradientCheckpointingLayer +from ...modeling_outputs import BaseModelOutput, BaseModelOutputWithPast, BaseModelOutputWithPooling, ModelOutput +from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update +from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel +from ...processing_utils import Unpack +from ...utils import TransformersKwargs, auto_docstring, logging, torch_int +from ...utils.generic import check_model_inputs +from .configuration_paddleocr_vl import PaddleOCRVisionConfig, PaddleOCRVLConfig + + +logger = logging.get_logger(__name__) + + +class Projector(nn.Module): + def __init__(self, text_config: PaddleOCRVLConfig, vision_config: PaddleOCRVisionConfig): + super().__init__() + self.text_config = text_config + self.vision_config = vision_config + self.merge_kernel_size = (2, 2) + + self.hidden_size = self.vision_config.hidden_size * self.merge_kernel_size[0] * self.merge_kernel_size[1] + + self.pre_norm = torch.nn.LayerNorm(self.vision_config.hidden_size, eps=1e-05) + self.linear_1 = nn.Linear(self.hidden_size, self.hidden_size, bias=True) + self.act = GELUActivation() + self.linear_2 = nn.Linear(self.hidden_size, self.text_config.hidden_size, bias=True) + + def forward(self, image_features: torch.Tensor, image_grid_thw: list[tuple[int, int, int]]) -> torch.Tensor: + m1, m2 = self.merge_kernel_size + + processed_features = [] + for image_feature, image_grid in zip(image_features, image_grid_thw): + image_feature = self.pre_norm(image_feature) + t, h, w = image_grid + + image_feature = rearrange( + image_feature, "(t h p1 w p2) d -> (t h w) (p1 p2 d)", t=t, h=h // m1, p1=m1, w=w // m2, p2=m2 + ) + hidden_states = self.linear_1(image_feature) + hidden_states = self.act(hidden_states) + hidden_states = self.linear_2(hidden_states) + processed_features.append(hidden_states) + + return processed_features + + +class PaddleOCRRotaryEmbedding(nn.Module): + def __init__(self, dim: int, theta: float = 10000.0) -> None: + super().__init__() + inv_freq = 1.0 / (theta ** (torch.arange(0, dim, 2, dtype=torch.float) / dim)) + self.register_buffer("inv_freq", inv_freq, persistent=False) + + def forward(self, seqlen: int) -> torch.Tensor: + seq = torch.arange(seqlen, device=self.inv_freq.device, dtype=self.inv_freq.dtype) + freqs = torch.outer(seq, self.inv_freq) + return freqs + + +class Ernie4_5RotaryEmbedding(nn.Module): + inv_freq: torch.Tensor # fix linting for `register_buffer` + + def __init__(self, config: PaddleOCRVLConfig, device=None): + super().__init__() + self.max_seq_len_cached = config.max_position_embeddings + self.original_max_seq_len = config.max_position_embeddings + + self.config = config + + self.rope_type = self.config.rope_parameters["rope_type"] + rope_init_fn: Callable = self.compute_default_rope_parameters + if self.rope_type != "default": + rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type] + inv_freq, self.attention_scaling = rope_init_fn(self.config, device) + + self.register_buffer("inv_freq", inv_freq, persistent=False) + self.original_inv_freq = inv_freq + + @staticmethod + def compute_default_rope_parameters( + config: Optional[PaddleOCRVLConfig] = None, + device: Optional["torch.device"] = None, + seq_len: Optional[int] = None, + ) -> tuple["torch.Tensor", float]: + """ + Computes the inverse frequencies according to the original RoPE implementation + Args: + config ([`~transformers.PreTrainedConfig`]): + The model configuration. + device (`torch.device`): + The device to use for initialization of the inverse frequencies. + seq_len (`int`, *optional*): + The current sequence length. Unused for this type of RoPE. + Returns: + Tuple of (`torch.Tensor`, `float`), containing the inverse frequencies for the RoPE embeddings and the + post-processing scaling factor applied to the computed cos/sin (unused in this type of RoPE). + """ + base = config.rope_parameters["rope_theta"] + dim = getattr(config, "head_dim", None) or config.hidden_size // config.num_attention_heads + + attention_factor = 1.0 # Unused in this type of RoPE + + # Compute the inverse frequencies + inv_freq = 1.0 / ( + base ** (torch.arange(0, dim, 2, dtype=torch.int64).to(device=device, dtype=torch.float) / dim) + ) + return inv_freq, attention_factor + + @torch.no_grad() + @dynamic_rope_update # power user: used with advanced RoPE types (e.g. dynamic rope) + def forward(self, x, position_ids): + inv_freq_expanded = self.inv_freq[None, None, :, None].float().expand(3, position_ids.shape[1], -1, 1) + position_ids_expanded = position_ids[:, :, None, :].float() # shape (3, bs, 1, positions) + # Force float32 (see https://github.com/huggingface/transformers/pull/29285) + device_type = x.device.type + device_type = device_type if isinstance(device_type, str) and device_type != "mps" else "cpu" + with torch.autocast(device_type=device_type, enabled=False): + freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(2, 3) + emb = torch.cat((freqs, freqs), dim=-1) + cos = emb.cos() + sin = emb.sin() + + # Advanced RoPE types (e.g. yarn) apply a post-processing scaling factor, equivalent to scaling attention + cos = cos * self.attention_scaling + sin = sin * self.attention_scaling + + return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype) + + +class Ernie4_5MLP(nn.Module): + def __init__(self, config: PaddleOCRVLConfig): + super().__init__() + self.config = config + self.hidden_size = config.hidden_size + self.intermediate_size = config.intermediate_size + + self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=config.use_bias) + self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=config.use_bias) + self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=config.use_bias) + self.act_fn = ACT2FN[config.hidden_act] + + def forward(self, x): + down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x)) + return down_proj + + +def apply_multimodal_rotary_pos_emb(q, k, cos, sin, mrope_section, unsqueeze_dim=1): + """Applies Rotary Position Embedding with Multimodal Sections to the query and key tensors (https://qwenlm.github.io/blog/qwen2-vl/). + + Explanation: + Multimodal 3D rotary position embedding is an extension to 1D rotary position embedding. The input embedding + sequence contains vision (images / videos) embedding and text embedding or just contains text embedding. For + vision embedding part, we apply rotary position embedding on temporal, height and width dimension separately. + Here we split the channel dimension to 3 chunks for the temporal, height and width rotary position embedding. + For text embedding part, we just apply 1D rotary position embedding. The three rotary position index (temporal, + height and width) of text embedding is always the same, so the text embedding rotary position embedding has no + difference with modern LLMs. + + Args: + q (`torch.Tensor`): The query tensor. + k (`torch.Tensor`): The key tensor. + cos (`torch.Tensor`): The cosine part of the rotary embedding. + sin (`torch.Tensor`): The sine part of the rotary embedding. + position_ids (`torch.Tensor`): + The position indices of the tokens corresponding to the query and key tensors. For example, this can be + used to pass offsetted position ids when working with a KV-cache. + mrope_section(`List(int)`): + Multimodal rope section is for channel dimension of temporal, height and width in rope calculation. + unsqueeze_dim (`int`, *optional*, defaults to 1): + The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and + sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note + that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and + k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes + cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have + the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2. + Returns: + `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding. + """ + mrope_section = mrope_section * 2 + cos = torch.cat([m[i % 3] for i, m in enumerate(cos.split(mrope_section, dim=-1))], dim=-1).unsqueeze( + unsqueeze_dim + ) + sin = torch.cat([m[i % 3] for i, m in enumerate(sin.split(mrope_section, dim=-1))], dim=-1).unsqueeze( + unsqueeze_dim + ) + + q_embed = (q * cos) + (rotate_half(q) * sin) + k_embed = (k * cos) + (rotate_half(k) * sin) + return q_embed, k_embed + + +def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor: + """ + This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch, + num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim) + """ + batch, num_key_value_heads, slen, head_dim = hidden_states.shape + if n_rep == 1: + return hidden_states + hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim) + return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim) + + +def rotate_half(x): + """Rotates half the hidden dims of the input.""" + x1 = x[..., : x.shape[-1] // 2] + x2 = x[..., x.shape[-1] // 2 :] + return torch.cat((-x2, x1), dim=-1) + + +def eager_attention_forward( + module: nn.Module, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + attention_mask: Optional[torch.Tensor], + scaling: float, + dropout: float = 0.0, + **kwargs: Unpack[TransformersKwargs], +): + key_states = repeat_kv(key, module.num_key_value_groups) + value_states = repeat_kv(value, module.num_key_value_groups) + + attn_weights = torch.matmul(query, key_states.transpose(2, 3)) * scaling + if attention_mask is not None: + causal_mask = attention_mask[:, :, :, : key_states.shape[-2]] + attn_weights = attn_weights + causal_mask + + attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype) + attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training) + attn_output = torch.matmul(attn_weights, value_states) + attn_output = attn_output.transpose(1, 2).contiguous() + + return attn_output, attn_weights + + +class Ernie4_5Attention(nn.Module): + """Multi-headed attention from 'Attention Is All You Need' paper""" + + def __init__(self, config: PaddleOCRVLConfig, layer_idx: int): + super().__init__() + self.config = config + self.layer_idx = layer_idx + self.head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads) + self.num_key_value_groups = config.num_attention_heads // config.num_key_value_heads + self.scaling = self.head_dim**-0.5 + + self.attention_dropout = 0.0 + self.is_causal = True + + self.q_proj = nn.Linear(config.hidden_size, config.num_attention_heads * self.head_dim, bias=config.use_bias) + self.k_proj = nn.Linear(config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.use_bias) + self.v_proj = nn.Linear(config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.use_bias) + self.o_proj = nn.Linear(config.num_attention_heads * self.head_dim, config.hidden_size, bias=config.use_bias) + self.rope_scaling = config.rope_scaling + + def forward( + self, + hidden_states: torch.Tensor, + position_embeddings: tuple[torch.Tensor, torch.Tensor], + attention_mask: Optional[torch.Tensor], + past_key_value: Optional[Cache] = None, + cache_position: Optional[torch.LongTensor] = None, + **kwargs: Unpack[TransformersKwargs], + ) -> tuple[torch.Tensor, torch.Tensor]: + input_shape = hidden_states.shape[:-1] + hidden_shape = (*input_shape, -1, self.head_dim) + + query_states = self.q_proj(hidden_states).view(hidden_shape).transpose(1, 2) + key_states = self.k_proj(hidden_states).view(hidden_shape).transpose(1, 2) + value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2) + + if "position_ids" in kwargs and kwargs["position_ids"] is not None: + position_ids = kwargs["position_ids"] + if position_ids.dim() == 3 and position_ids.shape[0] > 1: + kwargs["position_ids"] = position_ids[0:1] + + cos, sin = position_embeddings + query_states, key_states = apply_multimodal_rotary_pos_emb( + query_states, key_states, cos, sin, self.rope_scaling["mrope_section"] + ) + + if past_key_value is not None: + # sin and cos are specific to RoPE models; cache_position needed for the static cache + cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} + key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) + + attention_interface: Callable = eager_attention_forward + if self.config._attn_implementation != "eager": + attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] + + attn_output, attn_weights = attention_interface( + self, + query_states, + key_states, + value_states, + attention_mask, + dropout=0.0 if not self.training else self.attention_dropout, + scaling=self.scaling, + **kwargs, + ) + + attn_output = attn_output.reshape(*input_shape, -1).contiguous() + attn_output = self.o_proj(attn_output) + return attn_output, attn_weights + + +@use_kernel_forward_from_hub("RMSNorm") +class Ernie4_5RMSNorm(nn.Module): + def __init__(self, hidden_size, eps=1e-6): + """ + Ernie4_5RMSNorm is equivalent to T5LayerNorm + """ + super().__init__() + self.weight = nn.Parameter(torch.ones(hidden_size)) + self.variance_epsilon = eps + + def forward(self, hidden_states): + input_dtype = hidden_states.dtype + hidden_states = hidden_states.to(torch.float32) + variance = hidden_states.pow(2).mean(-1, keepdim=True) + hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon) + return self.weight * hidden_states.to(input_dtype) + + def extra_repr(self): + return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}" + + +class Ernie4_5DecoderLayer(GradientCheckpointingLayer): + def __init__(self, config: PaddleOCRVLConfig, layer_idx: int): + super().__init__() + self.hidden_size = config.hidden_size + + self.self_attn = Ernie4_5Attention(config=config, layer_idx=layer_idx) + + self.mlp = Ernie4_5MLP(config) + self.input_layernorm = Ernie4_5RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.post_attention_layernorm = Ernie4_5RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[Cache] = None, + use_cache: Optional[bool] = False, + cache_position: Optional[torch.LongTensor] = None, + position_embeddings: Optional[tuple[torch.Tensor, torch.Tensor]] = None, + **kwargs: Unpack[TransformersKwargs], + ) -> torch.Tensor: + residual = hidden_states + hidden_states = self.input_layernorm(hidden_states) + # Self Attention + hidden_states, _ = self.self_attn( + hidden_states=hidden_states, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + use_cache=use_cache, + cache_position=cache_position, + position_embeddings=position_embeddings, + **kwargs, + ) + hidden_states = residual + hidden_states + + # Fully Connected + residual = hidden_states + hidden_states = self.post_attention_layernorm(hidden_states) + hidden_states = self.mlp(hidden_states) + hidden_states = residual + hidden_states + return hidden_states + + +@auto_docstring +class Ernie4_5PreTrainedModel(PreTrainedModel): + config: PaddleOCRVLConfig + base_model_prefix = "model" + supports_gradient_checkpointing = True + _no_split_modules = ["Ernie4_5DecoderLayer"] + _skip_keys_device_placement = ["past_key_values"] + _supports_flash_attn = True + _supports_sdpa = True + _supports_flex_attn = True + + _can_compile_fullgraph = True + _supports_attention_backend = True + _can_record_outputs = { + "hidden_states": Ernie4_5DecoderLayer, + "attentions": Ernie4_5Attention, + } + + +@auto_docstring +class Ernie4_5Model(Ernie4_5PreTrainedModel): + def __init__(self, config: PaddleOCRVLConfig): + super().__init__(config) + self.padding_idx = config.pad_token_id + self.vocab_size = config.vocab_size + + self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx) + self.layers = nn.ModuleList( + [Ernie4_5DecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)] + ) + self.norm = Ernie4_5RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.rotary_emb = Ernie4_5RotaryEmbedding(config=config) + self.gradient_checkpointing = False + + # Initialize weights and apply final processing + self.post_init() + + @check_model_inputs() + @auto_docstring + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[Cache] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + cache_position: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + **kwargs: Unpack[TransformersKwargs], + ) -> BaseModelOutputWithPast: + if (input_ids is None) ^ (inputs_embeds is not None): + raise ValueError("You must specify exactly one of input_ids or inputs_embeds") + + if inputs_embeds is None: + inputs_embeds: torch.Tensor = self.embed_tokens(input_ids) + + if use_cache and past_key_values is None: + past_key_values = DynamicCache(config=self.config) + + if cache_position is None: + past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 + cache_position: torch.Tensor = torch.arange( + past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device + ) + + if position_ids is None: + position_ids = cache_position.unsqueeze(0) + + causal_mask = create_causal_mask( + config=self.config, + input_embeds=inputs_embeds, + attention_mask=attention_mask, + cache_position=cache_position, + past_key_values=past_key_values, + position_ids=position_ids, + ) + + hidden_states = inputs_embeds + position_embeddings = self.rotary_emb(hidden_states, position_ids=position_ids) + + for decoder_layer in self.layers[: self.config.num_hidden_layers]: + hidden_states = decoder_layer( + hidden_states, + attention_mask=causal_mask, + position_embeddings=position_embeddings, + position_ids=position_ids, + past_key_values=past_key_values, + cache_position=cache_position, + **kwargs, + ) + + hidden_states = self.norm(hidden_states) + return BaseModelOutputWithPast( + last_hidden_state=hidden_states, + past_key_values=past_key_values, + ) + + +def _trunc_normal_(tensor, mean, std, a, b): + # Cut & paste from PyTorch official master until it's in a few official releases - RW + # Method based on https://people.sc.fsu.edu/~jburkardt/presentations/truncated_normal.pdf + def norm_cdf(x): + # Computes standard normal cumulative distribution function + return (1.0 + math.erf(x / math.sqrt(2.0))) / 2.0 + + if (mean < a - 2 * std) or (mean > b + 2 * std): + logger.warning_once( + "mean is more than 2 std from [a, b] in nn.init.trunc_normal_. " + "The distribution of values may be incorrect.", + stacklevel=2, + ) + + # Values are generated by using a truncated uniform distribution and + # then using the inverse CDF for the normal distribution. + # Get upper and lower cdf values + l = norm_cdf((a - mean) / std) + u = norm_cdf((b - mean) / std) + + # Uniformly fill tensor with values from [l, u], then translate to + # [2l-1, 2u-1]. + tensor.uniform_(2 * l - 1, 2 * u - 1) + + # Use inverse cdf transform for normal distribution to get truncated + # standard normal + tensor.erfinv_() + + # Transform to proper mean, std + tensor.mul_(std * math.sqrt(2.0)) + tensor.add_(mean) + + # Clamp to ensure it's in the proper range + tensor.clamp_(min=a, max=b) + + +def trunc_normal_tf_( + tensor: torch.Tensor, + mean: float = 0.0, + std: float = 1.0, + a: float = -2.0, + b: float = 2.0, +) -> torch.Tensor: + """Fills the input Tensor with values drawn from a truncated + normal distribution. The values are effectively drawn from the + normal distribution :math:`\\mathcal{N}(\text{mean}, \text{std}^2)` + with values outside :math:`[a, b]` redrawn until they are within + the bounds. The method used for generating the random values works + best when :math:`a \\leq \text{mean} \\leq b`. + + NOTE: this 'tf' variant behaves closer to Tensorflow / JAX impl where the + bounds [a, b] are applied when sampling the normal distribution with mean=0, std=1.0 + and the result is subsequently scaled and shifted by the mean and std args. + + Args: + tensor: an n-dimensional `torch.Tensor` + mean: the mean of the normal distribution + std: the standard deviation of the normal distribution + a: the minimum cutoff value + b: the maximum cutoff value + """ + with torch.no_grad(): + _trunc_normal_(tensor, 0, 1.0, a, b) + tensor.mul_(std).add_(mean) + + +def variance_scaling_(tensor, scale=1.0, mode="fan_in", distribution="normal"): + fan_in, fan_out = _calculate_fan_in_and_fan_out(tensor) + if mode == "fan_in": + denom = fan_in + elif mode == "fan_out": + denom = fan_out + elif mode == "fan_avg": + denom = (fan_in + fan_out) / 2 + + variance = scale / denom + + if distribution == "truncated_normal": + # constant is stddev of standard normal truncated to (-2, 2) + trunc_normal_tf_(tensor, std=math.sqrt(variance) / 0.87962566103423978) + elif distribution == "normal": + with torch.no_grad(): + tensor.normal_(std=math.sqrt(variance)) + elif distribution == "uniform": + bound = math.sqrt(3 * variance) + with torch.no_grad(): + tensor.uniform_(-bound, bound) + else: + raise ValueError(f"invalid distribution {distribution}") + + +def lecun_normal_(tensor): + variance_scaling_(tensor, mode="fan_in", distribution="truncated_normal") + + +def default_flax_embed_init(tensor): + variance_scaling_(tensor, mode="fan_in", distribution="normal") + + +@auto_docstring +class PaddleOCRPreTrainedModel(PreTrainedModel): + config_class = PaddleOCRVLConfig + base_model_prefix = "PaddleOCR" + supports_gradient_checkpointing = True + + _no_split_modules = [ + "PaddleOCRTextEmbeddings", + "PaddleOCREncoderLayer", + "PaddleOCRVisionEmbeddings", + "PaddleOCRMultiheadAttentionPoolingHead", + ] + _supports_flash_attn_2 = True + _supports_sdpa = True + + def _init_weights(self, module): + """Initialize the weights""" + if isinstance(module, PaddleOCRVisionEmbeddings): + width = ( + self.config.vision_config.hidden_size + if isinstance(self.config, PaddleOCRVLConfig) + else self.config.hidden_size + ) + nn.init.normal_(module.position_embedding.weight, std=1 / np.sqrt(width)) + elif isinstance(module, nn.Embedding): + default_flax_embed_init(module.weight) + elif isinstance(module, PaddleOCRAttention): + nn.init.xavier_uniform_(module.q_proj.weight) + nn.init.xavier_uniform_(module.k_proj.weight) + nn.init.xavier_uniform_(module.v_proj.weight) + nn.init.xavier_uniform_(module.out_proj.weight) + nn.init.zeros_(module.q_proj.bias) + nn.init.zeros_(module.k_proj.bias) + nn.init.zeros_(module.v_proj.bias) + nn.init.zeros_(module.out_proj.bias) + elif isinstance(module, PaddleOCRMLP): + nn.init.xavier_uniform_(module.fc1.weight) + nn.init.xavier_uniform_(module.fc2.weight) + nn.init.normal_(module.fc1.bias, std=1e-6) + nn.init.normal_(module.fc2.bias, std=1e-6) + elif isinstance(module, PaddleOCRMultiheadAttentionPoolingHead): + nn.init.xavier_uniform_(module.probe.data) + nn.init.xavier_uniform_(module.attention.in_proj_weight.data) + nn.init.zeros_(module.attention.in_proj_bias.data) + elif isinstance(module, (nn.Linear, nn.Conv2d)): + lecun_normal_(module.weight) + if module.bias is not None: + nn.init.zeros_(module.bias) + elif isinstance(module, nn.LayerNorm): + module.bias.data.zero_() + module.weight.data.fill_(1.0) + + +@auto_docstring( + custom_intro=""" + The vision model from PaddleOCR without any head or projection on top. + """ +) +class PaddleOCRVisionModel(PaddleOCRPreTrainedModel): + config: PaddleOCRVisionConfig + main_input_name = "pixel_values" + input_modalities = "image" + + def __init__(self, config: PaddleOCRVisionConfig): + super().__init__(config) + + self.vision_model = PaddleOCRVisionTransformer(config) + + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self) -> nn.Module: + return self.vision_model.embeddings.patch_embedding + + @check_model_inputs(tie_last_hidden_states=False) + @auto_docstring + def forward( + self, + pixel_values, + position_ids: Optional[torch.Tensor] = None, + image_grid_thw: Optional[list[Union[tuple[int, int, int], list[tuple[int, int, int]]]]] = None, + ) -> BaseModelOutputWithPooling: + """ + Args: + pixel_values (`torch.FloatTensor` of shape `(batch_size, sequence_length, image_channels, patch_size, patch_size)`): + The tensors corresponding to the input images. + position_ids (`torch.LongTensor` of shape `sequence_length`): + The position ids of the image. + image_grid_thw (`List[Tuple]`, *optional*): + The temporal, height and width of feature shape of each image in LLM. + """ + return self.vision_model( + pixel_values=pixel_values, + position_ids=position_ids, + image_grid_thw=image_grid_thw, + ) + + +class PaddleOCRVisionEmbeddings(nn.Module): + def __init__(self, config: PaddleOCRVisionConfig): + super().__init__() + self.config = config + self.embed_dim = config.hidden_size + self.image_size = config.image_size + self.patch_size = config.patch_size + + self.patch_embedding = nn.Conv2d( + in_channels=config.num_channels, + out_channels=self.embed_dim, + kernel_size=self.patch_size, + stride=self.patch_size, + padding="valid", + ) + + self.num_patches = (self.image_size // self.patch_size) ** 2 + self.num_positions = self.num_patches + self.position_embedding = nn.Embedding(self.num_positions, self.embed_dim) + self.register_buffer("position_ids", torch.arange(self.num_positions).expand((1, -1)), persistent=False) + self.packing_position_embedding = nn.Embedding(32768, self.embed_dim) + + def interpolate_pos_encoding( + self, embeddings: torch.Tensor, height: int, width: int, is_after_patchify: bool = False + ) -> torch.Tensor: + """ + This method allows to interpolate the pre-trained position encodings, to be able to use the model on higher resolution + images. This method is also adapted to support torch.jit tracing and no class embeddings. + + Adapted from: + - https://github.com/facebookresearch/dino/blob/de9ee3df6cf39fac952ab558447af1fa1365362a/vision_transformer.py#L174-L194, and + - https://github.com/facebookresearch/dinov2/blob/e1277af2ba9496fbadf7aec6eba56e8d882d1e35/dinov2/models/vision_transformer.py#L179-L211 + """ + num_positions = self.position_embedding.weight.shape[0] + + patch_pos_embed = self.position_embedding.weight.unsqueeze(0) + + dim = embeddings.shape[-1] + + if is_after_patchify: + new_height = height + new_width = width + else: + new_height = height // self.patch_size + new_width = width // self.patch_size + + sqrt_num_positions = torch_int(num_positions**0.5) + patch_pos_embed = patch_pos_embed.reshape(1, sqrt_num_positions, sqrt_num_positions, dim) + patch_pos_embed = patch_pos_embed.permute(0, 3, 1, 2) + + patch_pos_embed = nn.functional.interpolate( + patch_pos_embed, + size=(new_height, new_width), + mode="bilinear", + align_corners=False, + ) + + patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim) + return patch_pos_embed + + def forward( + self, + pixel_values: torch.FloatTensor, + position_ids: Optional[torch.Tensor] = None, + image_grid_thw: Optional[list[Union[tuple[int, int, int], list[tuple[int, int, int]]]]] = None, + ) -> torch.Tensor: + """ + Args: + pixel_values (`torch.FloatTensor` of shape `(batch_size, sequence_length, image_channels, patch_size, patch_size)`): + The tensors corresponding to the input images. + position_ids (`torch.LongTensor` of shape `sequence_length`): + The position ids of the image. + image_grid_thw (`List[Tuple]`, *optional*): + The temporal, height and width of feature shape of each image in LLM. + """ + if pixel_values.dim() == 5: + assert position_ids is not None + + batch_size, squence_len, channel, height, width = pixel_values.shape + target_dtype = self.patch_embedding.weight.dtype + pixel_values = rearrange(pixel_values, "b l c h w -> (b l) c h w") + patch_embeds = self.patch_embedding(pixel_values.to(dtype=target_dtype)) # shape = [*, width, grid, grid] + embeddings = patch_embeds.flatten(-2).squeeze(-1) + embeddings = rearrange(embeddings, "(b l) d -> b l d", b=batch_size, l=squence_len) + + if image_grid_thw is not None: + flatten_image_grid_thw = self.flatten_list(image_grid_thw) + assert batch_size == 1 + start = 0 + assert sum([np.prod(x) for x in flatten_image_grid_thw]) == embeddings.shape[1], ( + flatten_image_grid_thw, + embeddings.shape, + ) + embeddings = embeddings.squeeze(0) + tmp_embeddings = [] + for image_grid in image_grid_thw: + t, h, w = image_grid + end = start + t * h * w + image_embeddings = embeddings[start:end, :] + position_embedding = ( + self.interpolate_pos_encoding(image_embeddings, h, w, True).squeeze(0).repeat(t, 1) + ) + image_embeddings = image_embeddings + position_embedding + tmp_embeddings.append(image_embeddings) + start = end + embeddings = torch.concat(tmp_embeddings, dim=0).unsqueeze(0) + else: + embeddings = embeddings + self.packing_position_embedding(position_ids) + return embeddings + else: + raise NotImplementedError(str(pixel_values.shape)) + + @staticmethod + def flatten_list(image_grid_thw): + tmp_image_grid_thw = [] + for image_grid in image_grid_thw: + if isinstance(image_grid, list): + tmp_image_grid_thw.extend(image_grid) + else: + tmp_image_grid_thw.append(image_grid) + return tmp_image_grid_thw + + +def apply_rotary_pos_emb_vision( + q: torch.Tensor, k: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor +) -> tuple[torch.Tensor, torch.Tensor]: + orig_q_dtype = q.dtype + orig_k_dtype = k.dtype + q, k = q.float(), k.float() + cos, sin = cos.unsqueeze(-2).float(), sin.unsqueeze(-2).float() + q_embed = (q * cos) + (rotate_half(q) * sin) + k_embed = (k * cos) + (rotate_half(k) * sin) + q_embed = q_embed.to(orig_q_dtype) + k_embed = k_embed.to(orig_k_dtype) + return q_embed, k_embed + + +class PaddleOCRAttention(nn.Module): + """Multi-headed attention from 'Attention Is All You Need' paper""" + + def __init__(self, config: PaddleOCRVisionConfig): + super().__init__() + self.config = config + self.embed_dim = config.hidden_size + self.num_heads = config.num_attention_heads + self.head_dim = self.embed_dim // self.num_heads + if self.head_dim * self.num_heads != self.embed_dim: + raise ValueError( + f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim} and `num_heads`:" + f" {self.num_heads})." + ) + self.scale = self.head_dim**-0.5 + self.dropout = config.attention_dropout + self.is_causal = False + + self.k_proj = nn.Linear(self.embed_dim, self.embed_dim) + self.v_proj = nn.Linear(self.embed_dim, self.embed_dim) + self.q_proj = nn.Linear(self.embed_dim, self.embed_dim) + self.out_proj = nn.Linear(self.embed_dim, self.embed_dim) + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + rope_emb: Optional[tuple[torch.Tensor, torch.Tensor]] = None, + ) -> tuple[torch.Tensor, Optional[torch.Tensor]]: + """Input shape: Batch x Time x Channel""" + batch_size, seq_length, embed_dim = hidden_states.shape + + queries = self.q_proj(hidden_states) + keys = self.k_proj(hidden_states) + values = self.v_proj(hidden_states) + + cos, sin = rope_emb + queries = queries.view(batch_size, seq_length, self.num_heads, self.head_dim) + keys = keys.view(batch_size, seq_length, self.num_heads, self.head_dim) + values = values.view(batch_size, seq_length, self.num_heads, self.head_dim).transpose(1, 2) + queries, keys = apply_rotary_pos_emb_vision(queries, keys, cos, sin) + queries = queries.transpose(1, 2) + keys = keys.transpose(1, 2) + + attention_interface: Callable = eager_attention_forward + if self.config._attn_implementation != "eager": + attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] + + attn_output, attn_weights = attention_interface( + self, + queries, + keys, + values, + attention_mask, + is_causal=self.is_causal, + scaling=self.scale, + dropout=0.0 if not self.training else self.dropout, + ) + + attn_output = attn_output.reshape(batch_size, seq_length, embed_dim).contiguous() + attn_output = self.out_proj(attn_output) + + return attn_output, attn_weights + + +class PaddleOCRMLP(nn.Module): + def __init__(self, config): + super().__init__() + self.config = config + self.activation_fn = ACT2FN[config.hidden_act] + self.fc1 = nn.Linear(config.hidden_size, config.intermediate_size) + self.fc2 = nn.Linear(config.intermediate_size, config.hidden_size) + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + hidden_states = self.fc1(hidden_states) + hidden_states = self.activation_fn(hidden_states) + hidden_states = self.fc2(hidden_states) + return hidden_states + + +class PaddleOCRMultiheadAttentionPoolingHead(nn.Module): + """Multihead Attention Pooling.""" + + def __init__(self, config: PaddleOCRVisionConfig): + super().__init__() + + self.probe = nn.Parameter(torch.randn(1, 1, config.hidden_size)) + self.attention = torch.nn.MultiheadAttention(config.hidden_size, config.num_attention_heads, batch_first=True) + self.layernorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) + self.mlp = PaddleOCRMLP(config) + + def forward(self, hidden_state): + batch_size = hidden_state.shape[0] + probe = self.probe.repeat(batch_size, 1, 1) + + hidden_state = self.attention(probe, hidden_state, hidden_state)[0] + + residual = hidden_state + hidden_state = self.layernorm(hidden_state) + hidden_state = residual + self.mlp(hidden_state) + + return hidden_state[:, 0] + + +class PaddleOCREncoderLayer(GradientCheckpointingLayer): + def __init__(self, config: PaddleOCRVisionConfig): + super().__init__() + self.embed_dim = config.hidden_size + self.layer_norm1 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps) + self.self_attn = PaddleOCRAttention(config) + self.layer_norm2 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps) + self.mlp = PaddleOCRMLP(config) + + @auto_docstring + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: torch.Tensor, + **kwargs: Unpack[TransformersKwargs], + ) -> torch.FloatTensor: + residual = hidden_states + + hidden_states = self.layer_norm1(hidden_states) + hidden_states, _ = self.self_attn( + hidden_states=hidden_states, + attention_mask=attention_mask, + **kwargs, + ) + hidden_states = residual + hidden_states + + residual = hidden_states + hidden_states = self.layer_norm2(hidden_states) + hidden_states = self.mlp(hidden_states) + hidden_states = residual + hidden_states + + return hidden_states + + +class PaddleOCREncoder(nn.Module): + """ + Transformer encoder consisting of `config.num_hidden_layers` self attention layers. Each layer is a + [`PaddleOCREncoderLayer`]. + + Args: + config: PaddleOCRConfig + """ + + def __init__(self, config: PaddleOCRVisionConfig): + super().__init__() + self.config = config + self.layers = nn.ModuleList([PaddleOCREncoderLayer(config) for _ in range(config.num_hidden_layers)]) + self.gradient_checkpointing = False + embed_dim = config.hidden_size + num_heads = config.num_attention_heads + head_dim = embed_dim // num_heads + self.rotary_pos_emb = PaddleOCRRotaryEmbedding(head_dim // 2) + + # Ignore copy + @auto_docstring + def forward( + self, + inputs_embeds, + attention_mask: Optional[torch.Tensor] = None, + image_grid_thw: Optional[list[Union[tuple[int, int, int], list[tuple[int, int, int]]]]] = None, + ) -> BaseModelOutput: + """ + Args: + attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*): + The attention_mask used in forward function shape [batch_size X sequence_length] if not None. + image_grid_thw (`List[Tuple]`, *optional*): + The temporal, height and width of feature shape of each image in LLM. + """ + device = inputs_embeds.device + hidden_states = inputs_embeds + attention_mask = attention_mask.to(inputs_embeds.dtype) if attention_mask is not None else None + flatten_image_grid_thw = self.flatten_list(image_grid_thw) + assert sum([np.prod(x) for x in flatten_image_grid_thw]) == hidden_states.shape[1], ( + flatten_image_grid_thw, + hidden_states.shape, + ) + + split_hids = [] + split_wids = [] + for t, h, w in flatten_image_grid_thw: + image_pids = torch.arange(t * h * w, device=device) % (h * w) + sample_hids = image_pids // w + sample_wids = image_pids % w + split_hids.append(sample_hids) + split_wids.append(sample_wids) + width_position_ids = torch.concat(split_wids, dim=0) + height_position_ids = torch.concat(split_hids, dim=0) + + pids = torch.stack([height_position_ids, width_position_ids], dim=-1) + max_grid_size = pids.max() + 1 + rope_emb_max_grid = self.rotary_pos_emb(max_grid_size) + rope_emb = rope_emb_max_grid[pids].flatten(1) + rope_emb = rope_emb.repeat(1, 2) + rope_emb = (rope_emb.cos(), rope_emb.sin()) + + for encoder_layer in self.layers: + hidden_states = encoder_layer( + hidden_states, + attention_mask, + rope_emb=rope_emb, + ) + + return BaseModelOutput( + last_hidden_state=hidden_states, + ) + + @staticmethod + def flatten_list(image_grid_thw): + tmp_image_grid_thw = [] + for image_grid in image_grid_thw: + if isinstance(image_grid, list): + tmp_image_grid_thw.extend(image_grid) + else: + tmp_image_grid_thw.append(image_grid) + return tmp_image_grid_thw + + +class PaddleOCRVisionTransformer(PaddleOCRPreTrainedModel): + _can_record_outputs = { + "hidden_states": PaddleOCREncoderLayer, + "attentions": PaddleOCRAttention, + } + + def __init__(self, config: PaddleOCRVisionConfig): + super().__init__(config) + self.config = config + embed_dim = config.hidden_size + + self.embeddings = PaddleOCRVisionEmbeddings(config) + self.encoder = PaddleOCREncoder(config) + self.post_layernorm = nn.LayerNorm(embed_dim, eps=config.layer_norm_eps) + self.use_head = True if not hasattr(config, "vision_use_head") else config.vision_use_head + if self.use_head: + self.head = PaddleOCRMultiheadAttentionPoolingHead(config) + + @check_model_inputs(tie_last_hidden_states=False) + @auto_docstring + def forward( + self, + pixel_values, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.Tensor] = None, + image_grid_thw: Optional[list[Union[tuple[int, int, int], list[tuple[int, int, int]]]]] = None, + ) -> BaseModelOutputWithPooling: + """ + Args: + pixel_values (`torch.FloatTensor` of shape `(batch_size, sequence_length, patch_size * patch_size * image_channels)`): + The tensors corresponding to the input images. + attention_mask (`torch.Tensor`, *optional*): + The attention_mask used in forward function shape [batch_size X sequence_length] if not None. + position_ids (`torch.LongTensor` of shape `sequence_length`): + The position ids of the image. + image_grid_thw (`List[Tuple]`, *optional*): + The temporal, height and width of feature shape of each image in LLM. + """ + hidden_states = self.embeddings(pixel_values, position_ids=position_ids, image_grid_thw=image_grid_thw) + + encoder_outputs: BaseModelOutput = self.encoder( + inputs_embeds=hidden_states, + attention_mask=attention_mask, + image_grid_thw=image_grid_thw, + ) + + last_hidden_state = encoder_outputs.last_hidden_state + last_hidden_state = self.post_layernorm(last_hidden_state) + + return BaseModelOutputWithPooling( + last_hidden_state=last_hidden_state, + pooler_output=None, + hidden_states=encoder_outputs.hidden_states, + attentions=encoder_outputs.attentions, + ) + + +@dataclass +class PaddleOCRVLCausalLMOutputWithPast(ModelOutput): + loss: Optional[torch.FloatTensor] = None + logits: torch.FloatTensor = None + past_key_values: Optional[list[torch.FloatTensor]] = None + hidden_states: Optional[tuple[torch.FloatTensor]] = None + attentions: Optional[tuple[torch.FloatTensor]] = None + rope_deltas: Optional[torch.LongTensor] = None + + +class PaddleOCRVLForConditionalGeneration(Ernie4_5PreTrainedModel, GenerationMixin): + def __init__(self, config): + super().__init__(config) + self.mlp_AR = Projector(config, config.vision_config) + self.visual = PaddleOCRVisionModel(config.vision_config) + self.model = Ernie4_5Model(config) + self.vocab_size = config.vocab_size + self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False) + self.rope_deltas = None + + self.post_init() + + def get_input_embeddings(self): + return self.model.embed_tokens + + def set_input_embeddings(self, value): + self.model.embed_tokens = value + + def get_output_embeddings(self): + return self.lm_head + + def set_output_embeddings(self, new_embeddings): + self.lm_head = new_embeddings + + def set_decoder(self, decoder): + self.model = decoder + + def get_decoder(self): + return self.model + + def get_rope_index( + self, + input_ids: Optional[torch.LongTensor] = None, + image_grid_thw: Optional[torch.LongTensor] = None, + video_grid_thw: Optional[torch.LongTensor] = None, + second_per_grid_ts: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + ) -> tuple[torch.Tensor, torch.Tensor]: + """ + Calculate the 3D rope index based on image and video's temporal, height and width in LLM. + + Explanation: + Each embedding sequence contains vision embedding and text embedding or just contains text embedding. + + For pure text embedding sequence, the rotary position embedding has no difference with modern LLMs. + Examples: + input_ids: [T T T T T], here T is for text. + temporal position_ids: [0, 1, 2, 3, 4] + height position_ids: [0, 1, 2, 3, 4] + width position_ids: [0, 1, 2, 3, 4] + + For vision and text embedding sequence, we calculate 3D rotary position embedding for vision part + and 1D rotary position embedding for text part. + Examples: + Temporal (Time): 3 patches, representing different segments of the video in time. + Height: 2 patches, dividing each frame vertically. + Width: 2 patches, dividing each frame horizontally. + We also have some important parameters: + fps (Frames Per Second): The video's frame rate, set to 1. This means one frame is processed each second. + tokens_per_second: This is a crucial parameter. It dictates how many "time-steps" or "temporal tokens" are conceptually packed into a one-second interval of the video. In this case, we have 25 tokens per second. So each second of the video will be represented with 25 separate time points. It essentially defines the temporal granularity. + temporal_patch_size: The number of frames that compose one temporal patch. Here, it's 2 frames. + interval: The step size for the temporal position IDs, calculated as tokens_per_second * temporal_patch_size / fps. In this case, 25 * 2 / 1 = 50. This means that each temporal patch will be have a difference of 50 in the temporal position IDs. + input_ids: [V V V V V V V V V V V V T T T T T], here V is for vision. + vision temporal position_ids: [0, 0, 0, 0, 50, 50, 50, 50, 100, 100, 100, 100] + vision height position_ids: [0, 0, 1, 1, 0, 0, 1, 1, 0, 0, 1, 1] + vision width position_ids: [0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1] + text temporal position_ids: [101, 102, 103, 104, 105] + text height position_ids: [101, 102, 103, 104, 105] + text width position_ids: [101, 102, 103, 104, 105] + Here we calculate the text start position_ids as the max vision position_ids plus 1. + + Args: + input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`): + Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide + it. + image_grid_thw (`torch.LongTensor` of shape `(num_images, 3)`, *optional*): + The temporal, height and width of feature shape of each image in LLM. + video_grid_thw (`torch.LongTensor` of shape `(num_videos, 3)`, *optional*): + The temporal, height and width of feature shape of each video in LLM. + second_per_grid_ts (`torch.Tensor` of shape `(num_videos)`, *optional*): + The time interval (in seconds) for each grid along the temporal dimension in the 3D position IDs. + attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*): + Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: + + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + + Returns: + position_ids (`torch.LongTensor` of shape `(3, batch_size, sequence_length)`) + mrope_position_deltas (`torch.Tensor` of shape `(batch_size)`) + """ + spatial_merge_size = self.config.vision_config.spatial_merge_size + image_token_id = self.config.image_token_id + video_token_id = self.config.video_token_id + vision_start_token_id = self.config.vision_start_token_id + mrope_position_deltas = [] + if input_ids is not None and (image_grid_thw is not None or video_grid_thw is not None): + total_input_ids = input_ids + if attention_mask is None: + attention_mask = torch.ones_like(total_input_ids) + position_ids = torch.ones( + 3, + input_ids.shape[0], + input_ids.shape[1], + dtype=input_ids.dtype, + device=input_ids.device, + ) + image_index, video_index = 0, 0 + attention_mask = attention_mask.to(total_input_ids.device) + for i, input_ids in enumerate(total_input_ids): + input_ids = input_ids[attention_mask[i] == 1] + image_nums, video_nums = 0, 0 + vision_start_indices = torch.argwhere(input_ids == vision_start_token_id).squeeze(1) + vision_tokens = input_ids[vision_start_indices + 1] + image_nums = (vision_tokens == image_token_id).sum() + video_nums = (vision_tokens == video_token_id).sum() + input_tokens = input_ids.tolist() + llm_pos_ids_list: list = [] + st = 0 + remain_images, remain_videos = image_nums, video_nums + for _ in range(image_nums + video_nums): + if image_token_id in input_tokens and remain_images > 0: + ed_image = input_tokens.index(image_token_id, st) + else: + ed_image = len(input_tokens) + 1 + if video_token_id in input_tokens and remain_videos > 0: + ed_video = input_tokens.index(video_token_id, st) + else: + ed_video = len(input_tokens) + 1 + if ed_image < ed_video: + t, h, w = ( + image_grid_thw[image_index][0], + image_grid_thw[image_index][1], + image_grid_thw[image_index][2], + ) + second_per_grid_t = 0 + image_index += 1 + remain_images -= 1 + ed = ed_image + + else: + t, h, w = ( + video_grid_thw[video_index][0], + video_grid_thw[video_index][1], + video_grid_thw[video_index][2], + ) + if second_per_grid_ts is not None: + second_per_grid_t = second_per_grid_ts[video_index] + else: + second_per_grid_t = 1.0 + video_index += 1 + remain_videos -= 1 + ed = ed_video + llm_grid_t, llm_grid_h, llm_grid_w = ( + t.item(), + h.item() // spatial_merge_size, + w.item() // spatial_merge_size, + ) + text_len = ed - st + + st_idx = llm_pos_ids_list[-1].max() + 1 if len(llm_pos_ids_list) > 0 else 0 + llm_pos_ids_list.append(torch.arange(text_len).view(1, -1).expand(3, -1) + st_idx) + + if torch.is_tensor(second_per_grid_t): + second_per_grid_t = second_per_grid_t.detach().item() + range_tensor = torch.arange(llm_grid_t).view(-1, 1) + expanded_range = range_tensor.expand(-1, llm_grid_h * llm_grid_w) + + time_tensor = expanded_range * second_per_grid_t * self.config.vision_config.tokens_per_second + + time_tensor_long = time_tensor.long() + t_index = time_tensor_long.flatten() + + h_index = torch.arange(llm_grid_h).view(1, -1, 1).expand(llm_grid_t, -1, llm_grid_w).flatten() + w_index = torch.arange(llm_grid_w).view(1, 1, -1).expand(llm_grid_t, llm_grid_h, -1).flatten() + llm_pos_ids_list.append(torch.stack([t_index, h_index, w_index]) + text_len + st_idx) + st = ed + llm_grid_t * llm_grid_h * llm_grid_w + + if st < len(input_tokens): + st_idx = llm_pos_ids_list[-1].max() + 1 if len(llm_pos_ids_list) > 0 else 0 + text_len = len(input_tokens) - st + llm_pos_ids_list.append(torch.arange(text_len).view(1, -1).expand(3, -1) + st_idx) + + llm_positions = torch.cat(llm_pos_ids_list, dim=1).reshape(3, -1) + position_ids[..., i, attention_mask[i] == 1] = llm_positions.to(position_ids.device) + mrope_position_deltas.append(llm_positions.max() + 1 - len(total_input_ids[i])) + mrope_position_deltas = torch.tensor(mrope_position_deltas, device=input_ids.device).unsqueeze(1) + return position_ids, mrope_position_deltas + else: + if attention_mask is not None: + position_ids = attention_mask.long().cumsum(-1) - 1 + position_ids.masked_fill_(attention_mask == 0, 1) + position_ids = position_ids.unsqueeze(0).expand(3, -1, -1).to(attention_mask.device) + max_position_ids = position_ids.max(0, keepdim=False)[0].max(-1, keepdim=True)[0] + mrope_position_deltas = max_position_ids + 1 - attention_mask.shape[-1] + else: + position_ids = ( + torch.arange(input_ids.shape[1], device=input_ids.device) + .view(1, 1, -1) + .expand(3, input_ids.shape[0], -1) + ) + mrope_position_deltas = torch.zeros( + [input_ids.shape[0], 1], + device=input_ids.device, + dtype=input_ids.dtype, + ) + + return position_ids, mrope_position_deltas + + def forward( + self, + input_ids: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[list[torch.FloatTensor]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + return_dict: Optional[bool] = None, + pixel_values: Optional[torch.Tensor] = None, + pixel_values_videos: Optional[torch.FloatTensor] = None, + image_grid_thw: Optional[torch.LongTensor] = None, + video_grid_thw: Optional[torch.LongTensor] = None, + rope_deltas: Optional[torch.LongTensor] = None, + cache_position: Optional[torch.LongTensor] = None, + second_per_grid_ts: Optional[torch.Tensor] = None, + **kwargs, + ) -> Union[tuple, PaddleOCRVLCausalLMOutputWithPast]: + r""" + Returns: + """ + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + if inputs_embeds is None: + inputs_embeds = self.model.embed_tokens(input_ids) + if pixel_values is not None: + pixel_values = pixel_values.type(self.visual.dtype) + pixel_values = pixel_values.unsqueeze(0) + ppocr_position_ids = [] + image_grid_hws = [] + + for idx, thw in enumerate(image_grid_thw): + thw_tuple = tuple(thw.detach().cpu().numpy().tolist()) + numel = np.prod(thw_tuple) + image_grid_hws.append(thw_tuple) + image_position_ids = torch.arange(numel) % np.prod(thw_tuple[1:]) + ppocr_position_ids.append(image_position_ids) + + ppocr_position_ids = torch.concat(ppocr_position_ids, dim=0).to(pixel_values.device) + + vision_outputs = self.visual( + pixel_values=pixel_values, + image_grid_thw=image_grid_hws, + position_ids=ppocr_position_ids, + ) + image_embeds = vision_outputs.last_hidden_state + + image_embeds = self.mlp_AR(image_embeds, image_grid_thw) + + n_image_tokens = (input_ids == self.config.image_token_id).sum().item() + # image_embeds is a list of tensor, each tensor is a image feature,I want to concat them all into a tensor + image_embeds = torch.cat(image_embeds, dim=0) + n_image_features = image_embeds.shape[0] + if n_image_tokens != n_image_features: + raise ValueError( + f"Image features and image tokens do not match: tokens: {n_image_tokens}, features {n_image_features}" + ) + + mask = input_ids == self.config.image_token_id + mask_unsqueezed = mask.unsqueeze(-1) + mask_expanded = mask_unsqueezed.expand_as(inputs_embeds) + image_mask = mask_expanded.to(inputs_embeds.device) + + image_embeds = image_embeds.to(inputs_embeds.device, inputs_embeds.dtype) + + inputs_embeds = inputs_embeds.masked_scatter(image_mask, image_embeds) + + if attention_mask is not None: + attention_mask = attention_mask.to(inputs_embeds.device) + # position_ids = None + # if we get 4D attention mask we cannot calculate rope deltas anymore. TODO @raushan fixme + if position_ids is None and (attention_mask is None or attention_mask.ndim == 2): + # calculate RoPE index once per generation in the pre-fill stage only + if ( + (cache_position is not None and cache_position[0] == 0) + or self.rope_deltas is None + or (past_key_values is None or past_key_values.get_seq_length() == 0) + ): + position_ids, rope_deltas = self.get_rope_index( + input_ids, + image_grid_thw, + video_grid_thw, + second_per_grid_ts, + attention_mask, + ) + self.rope_deltas = rope_deltas + # then use the prev pre-calculated rope-deltas to get the correct position ids + else: + batch_size, seq_length, _ = inputs_embeds.shape + delta = ( + (cache_position[0] + self.rope_deltas).to(inputs_embeds.device) + if cache_position is not None + else 0 + ) + position_ids = torch.arange(seq_length, device=inputs_embeds.device) + position_ids = position_ids.view(1, -1).expand(batch_size, -1) + if cache_position is not None: # otherwise `deltas` is an int `0` + delta = delta.repeat_interleave(batch_size // delta.shape[0], dim=0) + position_ids = position_ids.add(delta) + position_ids = position_ids.unsqueeze(0).expand(3, -1, -1) + + outputs = self.model( + input_ids=None, + position_ids=position_ids, + attention_mask=attention_mask, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + return_dict=return_dict, + **kwargs, + ) + + hidden_states = outputs[0] + logits = self.lm_head(hidden_states) + + loss = None + if labels is not None: + # Upcast to float if we need to compute the loss to avoid potential precision issues + logits = logits.float() + # Shift so that tokens < n predict n + shift_logits = logits[..., :-1, :].contiguous() + shift_labels = labels[..., 1:].contiguous() + # Flatten the tokens + loss_fct = CrossEntropyLoss() + shift_logits = shift_logits.view(-1, self.config.vocab_size) + shift_labels = shift_labels.view(-1) + # Enable model parallelism + shift_labels = shift_labels.to(shift_logits.device) + loss = loss_fct(shift_logits, shift_labels) + + if not return_dict: + output = (logits,) + outputs[1:] + return (loss,) + output if loss is not None else output + + return PaddleOCRVLCausalLMOutputWithPast( + loss=loss, + logits=logits, + past_key_values=outputs.past_key_values, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + rope_deltas=self.rope_deltas, + ) + + def prepare_inputs_for_generation( + self, + input_ids, + past_key_values=None, + attention_mask=None, + inputs_embeds=None, + cache_position=None, + position_ids=None, + use_cache=True, + pixel_values=None, + pixel_values_videos=None, + image_grid_thw=None, + video_grid_thw=None, + second_per_grid_ts=None, + **kwargs, + ): + # Overwritten -- in specific circumstances we don't want to forward image inputs to the model + + model_inputs = super().prepare_inputs_for_generation( + input_ids, + past_key_values=past_key_values, + attention_mask=attention_mask, + inputs_embeds=inputs_embeds, + cache_position=cache_position, + position_ids=position_ids, + pixel_values=pixel_values, + pixel_values_videos=pixel_values_videos, + image_grid_thw=image_grid_thw, + video_grid_thw=video_grid_thw, + second_per_grid_ts=second_per_grid_ts, + use_cache=use_cache, + **kwargs, + ) + + model_inputs["position_ids"] = None + + if cache_position[0] != 0: + model_inputs["pixel_values"] = None + model_inputs["pixel_values_videos"] = None + + return model_inputs + + def _get_image_nums_and_video_nums( + self, + input_ids: Optional[torch.LongTensor], + ) -> tuple[torch.Tensor, torch.Tensor]: + """ + Get the number of images and videos for each sample to calculate the separation length of the sample tensor. + These parameters are not passed through the processor to avoid unpredictable impacts from interface modifications. + + Args: + input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`): + Indices of input sequence tokens in the vocabulary. + + Returns: + image_nums (`torch.LongTensor` of shape `(batch_size, num_images_sample)`) + video_nums (`torch.LongTensor` of shape `(batch_size, num_videos_sample)`) + """ + image_token_id = self.config.image_token_id + video_token_id = self.config.video_token_id + vision_start_token_id = self.config.vision_start_token_id + + vision_start_mask = input_ids == vision_start_token_id + vision_first_mask = torch.roll(vision_start_mask, shifts=1, dims=1) + image_mask = input_ids == image_token_id + video_mask = input_ids == video_token_id + image_nums = torch.sum(vision_first_mask & image_mask, dim=1) + video_nums = torch.sum(vision_first_mask & video_mask, dim=1) + + return image_nums, video_nums + + def _expand_inputs_for_generation( + self, + expand_size: int = 1, + is_encoder_decoder: bool = False, + input_ids: Optional[torch.LongTensor] = None, + **model_kwargs, + ) -> tuple[torch.LongTensor, dict[str, Any]]: + # Overwritten -- Support for expanding tensors without a batch size dimension + # e.g., pixel_values, image_grid_thw, pixel_values_videos, video_grid_thw, second_per_grid_t + # pixel_values.shape[0] is sum(seqlen_images for samples) + # image_grid_thw.shape[0] is sum(num_images for samples) + + if expand_size == 1: + return input_ids, model_kwargs + + visual_keys = [ + "pixel_values", + "image_grid_thw", + "pixel_values_videos", + "video_grid_thw", + "second_per_grid_ts", + ] + + def _expand_dict_for_generation_visual(dict_to_expand): + image_grid_thw = model_kwargs.get("image_grid_thw", None) + video_grid_thw = model_kwargs.get("video_grid_thw", None) + image_nums, video_nums = self._get_image_nums_and_video_nums(input_ids) + + def _repeat_interleave_samples(x, lengths, repeat_times): + samples = torch.split(x, lengths) + repeat_args = [repeat_times] + [1] * (x.dim() - 1) + result = torch.cat([sample.repeat(*repeat_args) for sample in samples], dim=0) + return result + + for key in dict_to_expand: + if key == "pixel_values": + # split images into samples + samples = torch.split(image_grid_thw, list(image_nums)) + # compute the sequence length of images for each sample + lengths = [torch.prod(sample, dim=1).sum() for sample in samples] + dict_to_expand[key] = _repeat_interleave_samples( + dict_to_expand[key], lengths=lengths, repeat_times=expand_size + ) + elif key == "image_grid_thw": + # get the num of images for each sample + lengths = list(image_nums) + dict_to_expand[key] = _repeat_interleave_samples( + dict_to_expand[key], lengths=lengths, repeat_times=expand_size + ) + elif key == "pixel_values_videos": + samples = torch.split(video_grid_thw, list(video_nums)) + lengths = [torch.prod(sample, dim=1).sum() for sample in samples] + dict_to_expand[key] = _repeat_interleave_samples( + dict_to_expand[key], lengths=lengths, repeat_times=expand_size + ) + elif key == "video_grid_thw": + lengths = list(video_nums) + dict_to_expand[key] = _repeat_interleave_samples( + dict_to_expand[key], lengths=lengths, repeat_times=expand_size + ) + elif key == "second_per_grid_ts": + if not isinstance(dict_to_expand[key], list): + raise TypeError( + f"Expected value for key '{key}' to be a list, but got {type(dict_to_expand[key])} instead." + ) + tensor = torch.tensor(dict_to_expand[key]) + lengths = list(video_nums) + tensor = _repeat_interleave_samples(tensor, lengths=lengths, repeat_times=expand_size) + dict_to_expand[key] = tensor.tolist() + return dict_to_expand + + def _expand_dict_for_generation(dict_to_expand): + for key in dict_to_expand: + if ( + key != "cache_position" + and dict_to_expand[key] is not None + and isinstance(dict_to_expand[key], torch.Tensor) + and key not in visual_keys + ): + dict_to_expand[key] = dict_to_expand[key].repeat_interleave(expand_size, dim=0) + return dict_to_expand + + # input_ids is required for expanding visual inputs + # If input_ids is unavailable, visual inputs will not be used; therefore, there is no need to expand visual inputs. + if input_ids is not None and input_ids.numel() != 0: + model_kwargs = _expand_dict_for_generation_visual(model_kwargs) + + if input_ids is not None: + input_ids = input_ids.repeat_interleave(expand_size, dim=0) + + model_kwargs = _expand_dict_for_generation(model_kwargs) + + if is_encoder_decoder: + if model_kwargs.get("encoder_outputs") is None: + raise ValueError("If `is_encoder_decoder` is True, make sure that `encoder_outputs` is defined.") + model_kwargs["encoder_outputs"] = _expand_dict_for_generation(model_kwargs["encoder_outputs"]) + + return input_ids, model_kwargs + + +__all__ = ["PaddleOCRVLForConditionalGeneration"] diff --git a/src/transformers/models/paddleocr_vl/modular_paddleocr_vl.py b/src/transformers/models/paddleocr_vl/modular_paddleocr_vl.py new file mode 100644 index 000000000000..e0935c4f57f9 --- /dev/null +++ b/src/transformers/models/paddleocr_vl/modular_paddleocr_vl.py @@ -0,0 +1,1933 @@ +# Copyright 2025 The PaddlePaddle Team and The HuggingFace Inc. team. All rights reserved. +# +# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX +# and OPT implementations in this library. It has been modified from its +# original forms to accommodate minor architectural differences compared +# to GPT-NeoX and OPT used by the Meta AI team that trained the model. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import math +from collections.abc import Callable +from dataclasses import dataclass +from typing import Any, Optional, Union + +import numpy as np +import torch +from einops import rearrange +from torch import nn +from torch.nn import CrossEntropyLoss +from torch.nn.init import _calculate_fan_in_and_fan_out + +from ...activations import GELUActivation +from ...cache_utils import Cache +from ...configuration_utils import PretrainedConfig +from ...generation import GenerationMixin +from ...image_processing_utils import BatchFeature +from ...image_transforms import convert_to_rgb, resize, to_channel_dimension_format +from ...image_utils import ( + OPENAI_CLIP_MEAN, + OPENAI_CLIP_STD, + ChannelDimension, + ImageInput, + PILImageResampling, + get_image_size, + infer_channel_dimension_format, + is_scaled_image, + make_list_of_images, + to_numpy_array, +) +from ...modeling_outputs import BaseModelOutput, BaseModelOutputWithPooling, ModelOutput +from ...modeling_rope_utils import dynamic_rope_update, rope_config_validation +from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel +from ...models.qwen2_vl.image_processing_qwen2_vl import Qwen2VLImageProcessor +from ...processing_utils import ( + ProcessingKwargs, + ProcessorMixin, + Unpack, +) +from ...tokenization_utils_base import PreTokenizedInput, TextInput +from ...utils import TransformersKwargs, auto_docstring, logging, torch_int +from ..llama.modeling_llama import ( + LlamaAttention, + LlamaDecoderLayer, + LlamaMLP, + LlamaModel, + LlamaRMSNorm, + LlamaRotaryEmbedding, +) +from ..siglip.modeling_siglip import ( + SiglipAttention, + SiglipEncoder, + SiglipEncoderLayer, + SiglipMLP, + SiglipMultiheadAttentionPoolingHead, + SiglipVisionEmbeddings, + SiglipVisionModel, + SiglipVisionTransformer, +) + + +logger = logging.get_logger(__name__) + + +def smart_resize( + height: int, + width: int, + factor: int = 28, + min_pixels: int = 28 * 28 * 130, + max_pixels: int = 28 * 28 * 1280, +): + if height < factor: + width = round((width * factor) / height) + height = factor + + if width < factor: + height = round((height * factor) / width) + width = factor + + if max(height, width) / min(height, width) > 200: + raise ValueError( + f"absolute aspect ratio must be smaller than 200, got {max(height, width) / min(height, width)}" + ) + h_bar = round(height / factor) * factor + w_bar = round(width / factor) * factor + if h_bar * w_bar > max_pixels: + beta = math.sqrt((height * width) / max_pixels) + h_bar = math.floor(height / beta / factor) * factor + w_bar = math.floor(width / beta / factor) * factor + elif h_bar * w_bar < min_pixels: + beta = math.sqrt(min_pixels / (height * width)) + h_bar = math.ceil(height * beta / factor) * factor + w_bar = math.ceil(width * beta / factor) * factor + return h_bar, w_bar + + +class PaddleOCRVLImageProcessor(Qwen2VLImageProcessor): + r""" + Constructs a PaddleOCRVL image processor that dynamically resizes images based on the original images. + + Args: + do_resize (`bool`, *optional*, defaults to `True`): + Whether to resize the image's (height, width) dimensions. + size (`Dict[str, int]`, *optional*, defaults to `None`): + Size of the image after resizing. `shortest_edge` and `longest_edge` keys must be present. + resample (`PILImageResampling`, *optional*, defaults to `Resampling.BICUBIC`): + Resampling filter to use when resizing the image. + do_rescale (`bool`, *optional*, defaults to `True`): + Whether to rescale the image by the specified scale `rescale_factor`. + rescale_factor (`int` or `float`, *optional*, defaults to `1/255`): + Scale factor to use if rescaling the image. + do_normalize (`bool`, *optional*, defaults to `True`): + Whether to normalize the image. + image_mean (`float` or `List[float]`, *optional*, defaults to `[0.48145466, 0.4578275, 0.40821073]`): + Mean to use if normalizing the image. This is a float or list of floats for each channel in the image. + image_std (`float` or `List[float]`, *optional*, defaults to `[0.26862954, 0.26130258, 0.27577711]`): + Standard deviation to use if normalizing the image. This is a float or list of floats for each channel in the image. + do_convert_rgb (`bool`, *optional*, defaults to `True`): + Whether to convert the image to RGB. + min_pixels (`int`, *optional*, defaults to `28 * 28 * 130`): + The min pixels of the image to resize the image. + max_pixels (`int`, *optional*, defaults to `28 * 28 * 1670`): + The max pixels of the image to resize the image. + patch_size (`int`, *optional*, defaults to 14): + The spacial patch size of the vision encoder. + temporal_patch_size (`int`, *optional*, defaults to 2): + The temporal patch size of the vision encoder. + merge_size (`int`, *optional*, defaults to 2): + The merge size of the vision encoder to llm encoder. + """ + + model_input_names = [ + "pixel_values", + "image_grid_thw", + ] + + def __init__( + self, + do_resize: bool = True, + size: Optional[dict[str, int]] = None, + resample: PILImageResampling = PILImageResampling.BICUBIC, + do_rescale: bool = True, + rescale_factor: Union[int, float] = 1 / 255, + do_normalize: bool = True, + image_mean: Optional[Union[float, list[float]]] = None, + image_std: Optional[Union[float, list[float]]] = None, + do_convert_rgb: bool = True, + min_pixels: int = 28 * 28 * 130, + max_pixels: int = 28 * 28 * 1280, + patch_size: int = 14, + temporal_patch_size: int = 1, + merge_size: int = 2, + **kwargs, + ) -> None: + super().__init__(**kwargs) + self.do_resize = do_resize + self.resample = resample + self.do_rescale = do_rescale + self.rescale_factor = rescale_factor + self.do_normalize = do_normalize + self.image_mean = image_mean if image_mean is not None else OPENAI_CLIP_MEAN + self.image_std = image_std if image_std is not None else OPENAI_CLIP_STD + self.min_pixels = min_pixels + self.max_pixels = max_pixels + self.patch_size = patch_size + self.temporal_patch_size = temporal_patch_size + self.merge_size = merge_size + self.do_convert_rgb = do_convert_rgb + self.size = size + + def _preprocess( + self, + images: ImageInput, + do_resize: Optional[bool] = None, + size: Optional[dict[str, int]] = None, + resample: PILImageResampling = None, + do_rescale: Optional[bool] = None, + rescale_factor: Optional[float] = None, + do_normalize: Optional[bool] = None, + image_mean: Optional[Union[float, list[float]]] = None, + image_std: Optional[Union[float, list[float]]] = None, + patch_size: Optional[int] = None, + temporal_patch_size: Optional[int] = None, + merge_size: Optional[int] = None, + do_convert_rgb: Optional[bool] = None, + data_format: Optional[ChannelDimension] = ChannelDimension.FIRST, + input_data_format: Optional[Union[str, ChannelDimension]] = None, + ): + """ + Preprocess an image or batch of images. Copy of the `preprocess` method from `CLIPImageProcessor`. + Args: + images (`ImageInput`): + Image or batch of images to preprocess. Expects pixel values ranging from 0 to 255. If pixel values range from 0 to 1, set `do_rescale=False`. + do_resize (`bool`, *optional*, defaults to `self.do_resize`): + Whether to resize the image. + size (`Dict[str, int]`, *optional*, defaults to `self.size`): + Size of the image after resizing. `shortest_edge` and `longest_edge` keys must be present. + resample (`PILImageResampling`, *optional*, defaults to `self.resample`): + Resampling filter to use if resizing the image. This can be one of the `PILImageResampling` enums. + do_rescale (`bool`, *optional*, defaults to `self.do_rescale`): + Whether to rescale the image. + rescale_factor (`float`, *optional*, defaults to `self.rescale_factor`): + Scale factor to use if rescaling the image. + do_normalize (`bool`, *optional*, defaults to `self.do_normalize`): + Whether to normalize the image. + image_mean (`float` or `List[float]`, *optional*, defaults to `self.image_mean`): + Mean to use if normalizing the image. Can be a float or a list of floats corresponding to the number of channels in the image. + image_std (`float` or `List[float]`, *optional*, defaults to `self.image_std`): + Standard deviation to use if normalizing the image. Can be a float or a list of floats corresponding to the number of channels in the image. + patch_size (`int`, *optional*, defaults to `self.patch_size`): + The spatial patch size of the vision encoder. + temporal_patch_size (`int`, *optional*, defaults to `self.temporal_patch_size`): + The temporal patch size of the vision encoder. + merge_size (`int`, *optional*, defaults to `self.merge_size`): + The merge size of the vision encoder to llm encoder. + do_convert_rgb (`bool`, *optional*, defaults to `self.do_convert_rgb`): + Whether to convert the image to RGB. + data_format (`ChannelDimension`, *optional*, defaults to `ChannelDimension.FIRST`): + The channel dimension format for the output image. Can be one of: + - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format. + - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format. + - Unset: Use the channel dimension format of the input image. + input_data_format (`ChannelDimension` or `str`, *optional*): + The channel dimension format for the input image. Can be one of: + - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format. + - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format. + - `"none"` or `ChannelDimension.NONE`: image in (height, width) format. - `"none"` or `ChannelDimension.NONE`: image in (height, width) format. + """ + images = make_list_of_images(images) + + if do_convert_rgb: + images = [convert_to_rgb(image) for image in images] + + # All transformations expect numpy arrays. + images = [to_numpy_array(image) for image in images] + + if is_scaled_image(images[0]) and do_rescale: + logger.warning_once( + "It looks like you are trying to rescale already rescaled images. If the input" + " images have pixel values between 0 and 1, set `do_rescale=False` to avoid rescaling them again." + ) + if input_data_format is None: + # We assume that all images have the same channel dimension format. + input_data_format = infer_channel_dimension_format(images[0]) + + height, width = get_image_size(images[0], channel_dim=input_data_format) + resized_height, resized_width = height, width + processed_images = [] + + for image in images: + if do_resize: + resized_height, resized_width = smart_resize( + height, + width, + factor=self.patch_size * self.merge_size, + min_pixels=self.min_pixels, + max_pixels=self.max_pixels, + ) + image = resize( + image, + size=(resized_height, resized_width), + resample=resample, + input_data_format=input_data_format, + ) + + if do_rescale: + image = self.rescale(image, scale=rescale_factor, input_data_format=input_data_format) + + if do_normalize: + image = self.normalize( + image=image, + mean=image_mean, + std=image_std, + input_data_format=input_data_format, + ) + image = to_channel_dimension_format(image, data_format, input_channel_dim=input_data_format) + processed_images.append(image) + + patches = np.array(processed_images) + if data_format == ChannelDimension.LAST: + patches = patches.transpose(0, 3, 1, 2) + if patches.shape[0] == 1: + patches = np.tile(patches, (self.temporal_patch_size, 1, 1, 1)) + + channel = patches.shape[1] + grid_t = patches.shape[0] // self.temporal_patch_size + grid_h, grid_w = ( + resized_height // self.patch_size, + resized_width // self.patch_size, + ) + patches = patches.reshape( + grid_t, + self.temporal_patch_size, + channel, + grid_h, + self.patch_size, + grid_w, + self.patch_size, + ) + patches = patches.transpose(0, 3, 5, 2, 1, 4, 6) + assert self.temporal_patch_size == 1 + flatten_patches = patches.reshape(grid_t * grid_h * grid_w, channel, self.patch_size, self.patch_size) + return flatten_patches, (grid_t, grid_h, grid_w) + + +class PaddleOCRVLProcessorKwargs(ProcessingKwargs, total=False): + _defaults = { + "text_kwargs": { + "padding": False, + }, + } + + +class PaddleOCRVLProcessor(ProcessorMixin): + r""" + [`PaddleOCRVLProcessor`] offers all the functionalities of [`PaddleOCRVLImageProcessor`] and [`Qwen2TokenizerFast`]. See the + [`~PaddleOCRVLProcessor.__call__`] and [`~PaddleOCRVLProcessor.decode`] for more information. + Args: + image_processor ([`PaddleOCRVLImageProcessor`], *optional*): + The image processor is a required input. + tokenizer ([`Qwen2TokenizerFast`], *optional*): + The tokenizer is a required input. + chat_template (`str`, *optional*): A Jinja template which will be used to convert lists of messages + in a chat into a tokenizable string. + """ + + attributes = ["image_processor", "tokenizer"] + valid_kwargs = [ + "chat_template", + "image_std", + "min_pixels", + "image_mean", + "merge_size", + "image_processor_type", + "temporal_patch_size", + "patch_size", + "max_pixels", + ] + + image_processor_class = "AutoImageProcessor" + tokenizer_class = "AutoTokenizer" + + def __init__(self, image_processor=None, tokenizer=None, chat_template=None, **kwargs): + self.image_token = "<|IMAGE_PLACEHOLDER|>" if not hasattr(tokenizer, "image_token") else tokenizer.image_token + super().__init__(image_processor, tokenizer, chat_template=chat_template) + + def __call__( + self, + images: ImageInput = None, + text: Union[TextInput, PreTokenizedInput, list[TextInput], list[PreTokenizedInput]] = None, + **kwargs: Unpack[PaddleOCRVLProcessorKwargs], + ) -> BatchFeature: + """ + Main method to prepare for the model one or several sequences(s) and image(s). This method forwards the `text` + and `kwargs` arguments to Qwen2TokenizerFast's [`~Qwen2TokenizerFast.__call__`] if `text` is not `None` to encode + the text. To prepare the vision inputs, this method forwards the `vision_infos` and `kwrags` arguments to + PaddleOCRVLImageProcessor's [`~PaddleOCRVLImageProcessor.__call__`] if `vision_infos` is not `None`. + + Args: + images (`PIL.Image.Image`, `np.ndarray`, `torch.Tensor`, `List[PIL.Image.Image]`, `List[np.ndarray]`, `List[torch.Tensor]`): + The image or batch of images to be prepared. Each image can be a PIL image, NumPy array or PyTorch + tensor. Both channels-first and channels-last formats are supported. + text (`str`, `List[str]`, `List[List[str]]`): + The sequence or batch of sequences to be encoded. Each sequence can be a string or a list of strings + (pretokenized string). If the sequences are provided as list of strings (pretokenized), you must set + `is_split_into_words=True` (to lift the ambiguity with a batch of sequences). + return_tensors (`str` or [`~utils.TensorType`], *optional*): + If set, will return tensors of a particular framework. Acceptable values are: + - `'tf'`: Return TensorFlow `tf.constant` objects. + - `'pt'`: Return PyTorch `torch.Tensor` objects. + - `'np'`: Return NumPy `np.ndarray` objects. + - `'jax'`: Return JAX `jnp.ndarray` objects. + + Returns: + [`BatchFeature`]: A [`BatchFeature`] with the following fields: + + - **input_ids** -- List of token ids to be fed to a model. Returned when `text` is not `None`. + - **attention_mask** -- List of indices specifying which tokens should be attended to by the model (when + `return_attention_mask=True` or if *"attention_mask"* is in `self.model_input_names` and if `text` is not + `None`). + - **pixel_values** -- Pixel values to be fed to a model. Returned when `images` is not `None`. + - **image_grid_thw** -- List of image 3D grid in LLM. Returned when `images` is not `None`. + """ + output_kwargs = self._merge_kwargs( + PaddleOCRVLProcessorKwargs, + tokenizer_init_kwargs=self.tokenizer.init_kwargs, + **kwargs, + ) + + if images is not None: + image_inputs = self.image_processor(images=images, return_tensors="pt") + image_inputs["pixel_values"] = image_inputs["pixel_values"] + image_grid_thw = image_inputs["image_grid_thw"] + + else: + image_inputs = {} + image_grid_thw = None + + if not isinstance(text, list): + text = [text] + + if image_grid_thw is not None: + index = 0 + for i in range(len(text)): + while self.image_token in text[i]: + text[i] = text[i].replace( + self.image_token, + "<|placeholder|>" + * ( + image_grid_thw[index].prod() + // self.image_processor.merge_size + // self.image_processor.merge_size + ), + 1, + ) + index += 1 + text[i] = text[i].replace("<|placeholder|>", self.image_token) + + text_inputs = self.tokenizer(text, **output_kwargs["text_kwargs"]) + + return BatchFeature(data={**text_inputs, **image_inputs}) + + +class PaddleOCRVisionConfig(PretrainedConfig): + model_type = "paddleocr_vl" + base_config_key = "vision_config" + + def __init__( + self, + hidden_size=768, + intermediate_size=3072, + num_hidden_layers=12, + num_attention_heads=12, + num_channels=3, + image_size=224, + patch_size=14, + hidden_act="gelu_pytorch_tanh", + layer_norm_eps=1e-6, + attention_dropout=0.0, + spatial_merge_size=2, + temporal_patch_size=2, + tokens_per_second=2, + **kwargs, + ): + super().__init__(**kwargs) + + self.hidden_size = hidden_size + self.intermediate_size = intermediate_size + self.num_hidden_layers = num_hidden_layers + self.num_attention_heads = num_attention_heads + self.num_channels = num_channels + self.patch_size = patch_size + self.image_size = image_size + self.attention_dropout = attention_dropout + self.layer_norm_eps = layer_norm_eps + self.hidden_act = hidden_act + self.spatial_merge_size = spatial_merge_size + self.temporal_patch_size = temporal_patch_size + self.tokens_per_second = tokens_per_second + + +class PaddleOCRVLConfig(PretrainedConfig): + """ + Configuration class. + + This class stores the configuration of an Ernie model, defining the model architecture. + It inherits from PretrainedConfig and can be used to control model outputs. + """ + + model_type = "paddleocr_vl" + keys_to_ignore_at_inference = ["past_key_values"] + sub_configs = {"vision_config": PaddleOCRVisionConfig} + + # Default tensor parallel plan for base model `Qwen3` + base_model_tp_plan = { + "layers.*.self_attn.q_proj": "colwise", + "layers.*.self_attn.k_proj": "colwise", + "layers.*.self_attn.v_proj": "colwise", + "layers.*.self_attn.o_proj": "rowwise", + "layers.*.mlp.gate_proj": "colwise", + "layers.*.mlp.up_proj": "colwise", + "layers.*.mlp.down_proj": "rowwise", + } + base_model_pp_plan = { + "embed_tokens": (["input_ids"], ["inputs_embeds"]), + "layers": (["hidden_states", "attention_mask"], ["hidden_states"]), + "norm": (["hidden_states"], ["hidden_states"]), + } + + def __init__( + self, + vocab_size=32000, + hidden_size=768, + intermediate_size=11008, + max_position_embeddings=32768, + num_hidden_layers=2, + num_attention_heads=2, + image_token_id=101304, + video_token_id=101305, + vision_start_token_id=101306, + rms_norm_eps=1e-6, + use_cache=False, + use_flash_attention=False, + pad_token_id=0, + bos_token_id=1, + eos_token_id=2, + head_dim=128, + hidden_act="silu", + use_bias=False, + rope_theta=10000, + weight_share_add_bias=True, + ignored_index=-100, + attention_probs_dropout_prob=0.0, + hidden_dropout_prob=0.0, + compression_ratio: float = 1.0, + num_key_value_heads=None, + max_sequence_length=None, + tie_word_embeddings=False, + vision_config=None, + rope_scaling=None, + **kwargs, + ): + """ + Initialize configuration with default or specified parameters. + + Args: + vocab_size (int): Size of the vocabulary (number of unique tokens) + hidden_size (int): Dimensionality of the encoder layers and the pooler layer + intermediate_size (int): Dimensionality of the "intermediate" (feed-forward) layer + max_position_embeddings (int): Maximum sequence length the model can handle + num_hidden_layers (int): Number of hidden layers in the Transformer encoder + num_attention_heads (int): Number of attention heads for each attention layer + rms_norm_eps (float): The epsilon used by the RMS normalization layers + use_cache (bool): Whether to use caching for faster generation (decoding) + use_flash_attention (bool): Whether to use FlashAttention for optimized attention computation + pad_token_id (int): Token ID used for padding sequences + bos_token_id (int): Token ID used for beginning-of-sequence + eos_token_id (int): Token ID used for end-of-sequence + use_bias (bool): Whether to use bias terms in linear layers + rope_theta (float): The base period of the RoPE embeddings + weight_share_add_bias (bool): Whether to share bias weights in certain layers + ignored_index (int): Target value that is ignored during loss computation + attention_probs_dropout_prob (float): Dropout probability for attention weights + hidden_dropout_prob (float): Dropout probability for hidden layers + compression_ratio (float): Ratio for KV cache compression (1.0 = no compression) + num_key_value_heads (int): Number of key/value heads (for Grouped Query Attention) + max_sequence_length (int): Maximum sequence length for positional embeddings + **kwargs: Additional keyword arguments passed to parent class + """ + + # Set default for tied embeddings if not specified. + super().__init__( + pad_token_id=pad_token_id, + bos_token_id=bos_token_id, + eos_token_id=eos_token_id, + **kwargs, + ) + if isinstance(vision_config, dict): + self.vision_config = self.sub_configs["vision_config"](**vision_config) + elif vision_config is None: + self.vision_config = self.sub_configs["vision_config"]() + self.vocab_size = vocab_size + self.hidden_size = hidden_size + self.intermediate_size = intermediate_size + self.max_position_embeddings = max_position_embeddings + self.num_hidden_layers = num_hidden_layers + self.num_attention_heads = num_attention_heads + self.rms_norm_eps = rms_norm_eps + self.use_cache = use_cache + self.use_flash_attention = use_flash_attention + self.pad_token_id = pad_token_id + self.bos_token_id = bos_token_id + self.eos_token_id = eos_token_id + self.image_token_id = image_token_id + self.video_token_id = video_token_id + self.vision_start_token_id = vision_start_token_id + self.head_dim = head_dim + self.hidden_act = hidden_act + self.sliding_window = None + self.hidden_size = hidden_size + self.use_bias = use_bias + self.weight_share_add_bias = weight_share_add_bias + self.rope_theta = rope_theta + self.ignored_index = ignored_index + self.attention_probs_dropout_prob = attention_probs_dropout_prob + self.hidden_dropout_prob = hidden_dropout_prob + self.compression_ratio = compression_ratio + self.num_key_value_heads = num_key_value_heads + self.max_sequence_length = max_sequence_length + self.rope_scaling = rope_scaling + if self.rope_scaling is not None and "type" in self.rope_scaling: + if self.rope_scaling["type"] == "mrope": + self.rope_scaling["type"] = "default" + self.rope_scaling["rope_type"] = self.rope_scaling["type"] + rope_config_validation(self, ignore_keys={"mrope_section"}) + super().__init__(tie_word_embeddings=tie_word_embeddings, **kwargs) + + +class Projector(nn.Module): + def __init__(self, text_config: PaddleOCRVLConfig, vision_config: PaddleOCRVisionConfig): + super().__init__() + self.text_config = text_config + self.vision_config = vision_config + self.merge_kernel_size = (2, 2) + + self.hidden_size = self.vision_config.hidden_size * self.merge_kernel_size[0] * self.merge_kernel_size[1] + + self.pre_norm = torch.nn.LayerNorm(self.vision_config.hidden_size, eps=1e-05) + self.linear_1 = nn.Linear(self.hidden_size, self.hidden_size, bias=True) + self.act = GELUActivation() + self.linear_2 = nn.Linear(self.hidden_size, self.text_config.hidden_size, bias=True) + + def forward(self, image_features: torch.Tensor, image_grid_thw: list[tuple[int, int, int]]) -> torch.Tensor: + m1, m2 = self.merge_kernel_size + + processed_features = [] + for image_feature, image_grid in zip(image_features, image_grid_thw): + image_feature = self.pre_norm(image_feature) + t, h, w = image_grid + + image_feature = rearrange( + image_feature, "(t h p1 w p2) d -> (t h w) (p1 p2 d)", t=t, h=h // m1, p1=m1, w=w // m2, p2=m2 + ) + hidden_states = self.linear_1(image_feature) + hidden_states = self.act(hidden_states) + hidden_states = self.linear_2(hidden_states) + processed_features.append(hidden_states) + + return processed_features + + +def apply_multimodal_rotary_pos_emb(q, k, cos, sin, mrope_section, unsqueeze_dim=1): + """Applies Rotary Position Embedding with Multimodal Sections to the query and key tensors (https://qwenlm.github.io/blog/qwen2-vl/). + + Explanation: + Multimodal 3D rotary position embedding is an extension to 1D rotary position embedding. The input embedding + sequence contains vision (images / videos) embedding and text embedding or just contains text embedding. For + vision embedding part, we apply rotary position embedding on temporal, height and width dimension separately. + Here we split the channel dimension to 3 chunks for the temporal, height and width rotary position embedding. + For text embedding part, we just apply 1D rotary position embedding. The three rotary position index (temporal, + height and width) of text embedding is always the same, so the text embedding rotary position embedding has no + difference with modern LLMs. + + Args: + q (`torch.Tensor`): The query tensor. + k (`torch.Tensor`): The key tensor. + cos (`torch.Tensor`): The cosine part of the rotary embedding. + sin (`torch.Tensor`): The sine part of the rotary embedding. + position_ids (`torch.Tensor`): + The position indices of the tokens corresponding to the query and key tensors. For example, this can be + used to pass offsetted position ids when working with a KV-cache. + mrope_section(`List(int)`): + Multimodal rope section is for channel dimension of temporal, height and width in rope calculation. + unsqueeze_dim (`int`, *optional*, defaults to 1): + The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and + sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note + that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and + k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes + cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have + the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2. + Returns: + `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding. + """ + mrope_section = mrope_section * 2 + cos = torch.cat([m[i % 3] for i, m in enumerate(cos.split(mrope_section, dim=-1))], dim=-1).unsqueeze( + unsqueeze_dim + ) + sin = torch.cat([m[i % 3] for i, m in enumerate(sin.split(mrope_section, dim=-1))], dim=-1).unsqueeze( + unsqueeze_dim + ) + + q_embed = (q * cos) + (rotate_half(q) * sin) + k_embed = (k * cos) + (rotate_half(k) * sin) + return q_embed, k_embed + + +class PaddleOCRRotaryEmbedding(nn.Module): + def __init__(self, dim: int, theta: float = 10000.0) -> None: + super().__init__() + inv_freq = 1.0 / (theta ** (torch.arange(0, dim, 2, dtype=torch.float) / dim)) + self.register_buffer("inv_freq", inv_freq, persistent=False) + + def forward(self, seqlen: int) -> torch.Tensor: + seq = torch.arange(seqlen, device=self.inv_freq.device, dtype=self.inv_freq.dtype) + freqs = torch.outer(seq, self.inv_freq) + return freqs + + +class Ernie4_5RotaryEmbedding(LlamaRotaryEmbedding): + def __init__(self, config: PaddleOCRVLConfig, device=None): + super().__init__() + + @staticmethod + def compute_default_rope_parameters( + config: Optional[PaddleOCRVLConfig] = None, + device: Optional["torch.device"] = None, + seq_len: Optional[int] = None, + ) -> tuple["torch.Tensor", float]: + base = config.rope_parameters["rope_theta"] + dim = getattr(config, "head_dim", None) or config.hidden_size // config.num_attention_heads + + attention_factor = 1.0 # Unused in this type of RoPE + + # Compute the inverse frequencies + inv_freq = 1.0 / ( + base ** (torch.arange(0, dim, 2, dtype=torch.int64).to(device=device, dtype=torch.float) / dim) + ) + return inv_freq, attention_factor + + @torch.no_grad() + @dynamic_rope_update # power user: used with advanced RoPE types (e.g. dynamic rope) + def forward(self, x, position_ids): + inv_freq_expanded = self.inv_freq[None, None, :, None].float().expand(3, position_ids.shape[1], -1, 1) + position_ids_expanded = position_ids[:, :, None, :].float() # shape (3, bs, 1, positions) + # Force float32 (see https://github.com/huggingface/transformers/pull/29285) + device_type = x.device.type + device_type = device_type if isinstance(device_type, str) and device_type != "mps" else "cpu" + with torch.autocast(device_type=device_type, enabled=False): + freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(2, 3) + emb = torch.cat((freqs, freqs), dim=-1) + cos = emb.cos() + sin = emb.sin() + + # Advanced RoPE types (e.g. yarn) apply a post-processing scaling factor, equivalent to scaling attention + cos = cos * self.attention_scaling + sin = sin * self.attention_scaling + + return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype) + + +class Ernie4_5MLP(LlamaMLP): + def __init__(self, config: PaddleOCRVLConfig): + super().__init__(config) + + self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=config.use_bias) + self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=config.use_bias) + self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=config.use_bias) + + +class Ernie4_5Attention(LlamaAttention): + def __init__(self, config: PaddleOCRVLConfig, layer_idx: int): + super().__init__(config, layer_idx) + + self.attention_dropout = 0.0 + self.rope_scaling = config.rope_scaling + + self.q_proj = nn.Linear(config.hidden_size, config.num_attention_heads * self.head_dim, bias=config.use_bias) + self.k_proj = nn.Linear(config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.use_bias) + self.v_proj = nn.Linear(config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.use_bias) + self.o_proj = nn.Linear(config.num_attention_heads * self.head_dim, config.hidden_size, bias=config.use_bias) + + def forward( + self, + hidden_states: torch.Tensor, + position_embeddings: tuple[torch.Tensor, torch.Tensor], + attention_mask: Optional[torch.Tensor], + past_key_value: Optional[Cache] = None, + cache_position: Optional[torch.LongTensor] = None, + **kwargs: Unpack[TransformersKwargs], + ) -> tuple[torch.Tensor, torch.Tensor]: + input_shape = hidden_states.shape[:-1] + hidden_shape = (*input_shape, -1, self.head_dim) + + query_states = self.q_proj(hidden_states).view(hidden_shape).transpose(1, 2) + key_states = self.k_proj(hidden_states).view(hidden_shape).transpose(1, 2) + value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2) + + if "position_ids" in kwargs and kwargs["position_ids"] is not None: + position_ids = kwargs["position_ids"] + if position_ids.dim() == 3 and position_ids.shape[0] > 1: + kwargs["position_ids"] = position_ids[0:1] + + cos, sin = position_embeddings + query_states, key_states = apply_multimodal_rotary_pos_emb( + query_states, key_states, cos, sin, self.rope_scaling["mrope_section"] + ) + + if past_key_value is not None: + # sin and cos are specific to RoPE models; cache_position needed for the static cache + cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} + key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) + + attention_interface: Callable = eager_attention_forward + if self.config._attn_implementation != "eager": + attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] + + attn_output, attn_weights = attention_interface( + self, + query_states, + key_states, + value_states, + attention_mask, + dropout=0.0 if not self.training else self.attention_dropout, + scaling=self.scaling, + **kwargs, + ) + + attn_output = attn_output.reshape(*input_shape, -1).contiguous() + attn_output = self.o_proj(attn_output) + return attn_output, attn_weights + + +class Ernie4_5RMSNorm(LlamaRMSNorm): + pass + + +class Ernie4_5DecoderLayer(LlamaDecoderLayer): + def __init__(self, config: PaddleOCRVLConfig, layer_idx: int): + super().__init__(config) + + self.self_attn = Ernie4_5Attention(config=config, layer_idx=layer_idx) + + self.mlp = Ernie4_5MLP(config) + self.input_layernorm = Ernie4_5RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.post_attention_layernorm = Ernie4_5RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + + +@auto_docstring +class Ernie4_5PreTrainedModel(PreTrainedModel): + config: PaddleOCRVLConfig + base_model_prefix = "model" + supports_gradient_checkpointing = True + _no_split_modules = ["Ernie4_5DecoderLayer"] + _skip_keys_device_placement = ["past_key_values"] + _supports_flash_attn = True + _supports_sdpa = True + _supports_flex_attn = True + + _can_compile_fullgraph = True + _supports_attention_backend = True + _can_record_outputs = { + "hidden_states": Ernie4_5DecoderLayer, + "attentions": Ernie4_5Attention, + } + + +class Ernie4_5Model(LlamaModel): + def __init__(self, config: PaddleOCRVLConfig): + super().__init__(config) + self.layers = nn.ModuleList( + [Ernie4_5DecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)] + ) + self.norm = Ernie4_5RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.rotary_emb = Ernie4_5RotaryEmbedding(config=config) + + +def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor: + """ + This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch, + num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim) + """ + batch, num_key_value_heads, slen, head_dim = hidden_states.shape + if n_rep == 1: + return hidden_states + hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim) + return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim) + + +def rotate_half(x): + """Rotates half the hidden dims of the input.""" + x1 = x[..., : x.shape[-1] // 2] + x2 = x[..., x.shape[-1] // 2 :] + return torch.cat((-x2, x1), dim=-1) + + +def apply_rotary_pos_emb_vision( + q: torch.Tensor, k: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor +) -> tuple[torch.Tensor, torch.Tensor]: + orig_q_dtype = q.dtype + orig_k_dtype = k.dtype + q, k = q.float(), k.float() + cos, sin = cos.unsqueeze(-2).float(), sin.unsqueeze(-2).float() + q_embed = (q * cos) + (rotate_half(q) * sin) + k_embed = (k * cos) + (rotate_half(k) * sin) + q_embed = q_embed.to(orig_q_dtype) + k_embed = k_embed.to(orig_k_dtype) + return q_embed, k_embed + + +def eager_attention_forward( + module: nn.Module, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + attention_mask: Optional[torch.Tensor], + scaling: float, + dropout: float = 0.0, + **kwargs: Unpack[TransformersKwargs], +): + key_states = repeat_kv(key, module.num_key_value_groups) + value_states = repeat_kv(value, module.num_key_value_groups) + + attn_weights = torch.matmul(query, key_states.transpose(2, 3)) * scaling + if attention_mask is not None: + causal_mask = attention_mask[:, :, :, : key_states.shape[-2]] + attn_weights = attn_weights + causal_mask + + attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype) + attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training) + attn_output = torch.matmul(attn_weights, value_states) + attn_output = attn_output.transpose(1, 2).contiguous() + + return attn_output, attn_weights + + +def paddleocr_vl_eager_attention_forward( + module: nn.Module, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + attention_mask: Optional[torch.Tensor], + scaling: float, + dropout: float = 0.0, + **kwargs, +): + attn_weights = torch.matmul(query, key.transpose(-1, -2)) * scaling + if attention_mask is not None: + attn_weights = attn_weights + attention_mask + + attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype) + attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training) + + attn_output = torch.matmul(attn_weights, value) + attn_output = attn_output.transpose(1, 2).contiguous() + + return attn_output, attn_weights + + +def _trunc_normal_(tensor, mean, std, a, b): + # Cut & paste from PyTorch official master until it's in a few official releases - RW + # Method based on https://people.sc.fsu.edu/~jburkardt/presentations/truncated_normal.pdf + def norm_cdf(x): + # Computes standard normal cumulative distribution function + return (1.0 + math.erf(x / math.sqrt(2.0))) / 2.0 + + if (mean < a - 2 * std) or (mean > b + 2 * std): + logger.warning_once( + "mean is more than 2 std from [a, b] in nn.init.trunc_normal_. " + "The distribution of values may be incorrect.", + stacklevel=2, + ) + + # Values are generated by using a truncated uniform distribution and + # then using the inverse CDF for the normal distribution. + # Get upper and lower cdf values + l = norm_cdf((a - mean) / std) + u = norm_cdf((b - mean) / std) + + # Uniformly fill tensor with values from [l, u], then translate to + # [2l-1, 2u-1]. + tensor.uniform_(2 * l - 1, 2 * u - 1) + + # Use inverse cdf transform for normal distribution to get truncated + # standard normal + tensor.erfinv_() + + # Transform to proper mean, std + tensor.mul_(std * math.sqrt(2.0)) + tensor.add_(mean) + + # Clamp to ensure it's in the proper range + tensor.clamp_(min=a, max=b) + + +def trunc_normal_tf_( + tensor: torch.Tensor, + mean: float = 0.0, + std: float = 1.0, + a: float = -2.0, + b: float = 2.0, +) -> torch.Tensor: + """Fills the input Tensor with values drawn from a truncated + normal distribution. The values are effectively drawn from the + normal distribution :math:`\\mathcal{N}(\text{mean}, \text{std}^2)` + with values outside :math:`[a, b]` redrawn until they are within + the bounds. The method used for generating the random values works + best when :math:`a \\leq \text{mean} \\leq b`. + + NOTE: this 'tf' variant behaves closer to Tensorflow / JAX impl where the + bounds [a, b] are applied when sampling the normal distribution with mean=0, std=1.0 + and the result is subsequently scaled and shifted by the mean and std args. + + Args: + tensor: an n-dimensional `torch.Tensor` + mean: the mean of the normal distribution + std: the standard deviation of the normal distribution + a: the minimum cutoff value + b: the maximum cutoff value + """ + with torch.no_grad(): + _trunc_normal_(tensor, 0, 1.0, a, b) + tensor.mul_(std).add_(mean) + + +def variance_scaling_(tensor, scale=1.0, mode="fan_in", distribution="normal"): + fan_in, fan_out = _calculate_fan_in_and_fan_out(tensor) + if mode == "fan_in": + denom = fan_in + elif mode == "fan_out": + denom = fan_out + elif mode == "fan_avg": + denom = (fan_in + fan_out) / 2 + + variance = scale / denom + + if distribution == "truncated_normal": + # constant is stddev of standard normal truncated to (-2, 2) + trunc_normal_tf_(tensor, std=math.sqrt(variance) / 0.87962566103423978) + elif distribution == "normal": + with torch.no_grad(): + tensor.normal_(std=math.sqrt(variance)) + elif distribution == "uniform": + bound = math.sqrt(3 * variance) + with torch.no_grad(): + tensor.uniform_(-bound, bound) + else: + raise ValueError(f"invalid distribution {distribution}") + + +def lecun_normal_(tensor): + variance_scaling_(tensor, mode="fan_in", distribution="truncated_normal") + + +def default_flax_embed_init(tensor): + variance_scaling_(tensor, mode="fan_in", distribution="normal") + + +@auto_docstring +class PaddleOCRPreTrainedModel(PreTrainedModel): + config_class = PaddleOCRVLConfig + base_model_prefix = "PaddleOCR" + supports_gradient_checkpointing = True + + _no_split_modules = [ + "PaddleOCRTextEmbeddings", + "PaddleOCREncoderLayer", + "PaddleOCRVisionEmbeddings", + "PaddleOCRMultiheadAttentionPoolingHead", + ] + _supports_flash_attn_2 = True + _supports_sdpa = True + + def _init_weights(self, module): + """Initialize the weights""" + if isinstance(module, PaddleOCRVisionEmbeddings): + width = ( + self.config.vision_config.hidden_size + if isinstance(self.config, PaddleOCRVLConfig) + else self.config.hidden_size + ) + nn.init.normal_(module.position_embedding.weight, std=1 / np.sqrt(width)) + elif isinstance(module, nn.Embedding): + default_flax_embed_init(module.weight) + elif isinstance(module, PaddleOCRAttention): + nn.init.xavier_uniform_(module.q_proj.weight) + nn.init.xavier_uniform_(module.k_proj.weight) + nn.init.xavier_uniform_(module.v_proj.weight) + nn.init.xavier_uniform_(module.out_proj.weight) + nn.init.zeros_(module.q_proj.bias) + nn.init.zeros_(module.k_proj.bias) + nn.init.zeros_(module.v_proj.bias) + nn.init.zeros_(module.out_proj.bias) + elif isinstance(module, PaddleOCRMLP): + nn.init.xavier_uniform_(module.fc1.weight) + nn.init.xavier_uniform_(module.fc2.weight) + nn.init.normal_(module.fc1.bias, std=1e-6) + nn.init.normal_(module.fc2.bias, std=1e-6) + elif isinstance(module, PaddleOCRMultiheadAttentionPoolingHead): + nn.init.xavier_uniform_(module.probe.data) + nn.init.xavier_uniform_(module.attention.in_proj_weight.data) + nn.init.zeros_(module.attention.in_proj_bias.data) + elif isinstance(module, (nn.Linear, nn.Conv2d)): + lecun_normal_(module.weight) + if module.bias is not None: + nn.init.zeros_(module.bias) + elif isinstance(module, nn.LayerNorm): + module.bias.data.zero_() + module.weight.data.fill_(1.0) + + +class PaddleOCRVisionModel(SiglipVisionModel): + def __init__(self, config: PaddleOCRVisionConfig): + super().__init__(config) + + def forward( + self, + pixel_values, + position_ids: Optional[torch.Tensor] = None, + image_grid_thw: Optional[list[Union[tuple[int, int, int], list[tuple[int, int, int]]]]] = None, + ) -> BaseModelOutputWithPooling: + """ + Args: + pixel_values (`torch.FloatTensor` of shape `(batch_size, sequence_length, image_channels, patch_size, patch_size)`): + The tensors corresponding to the input images. + position_ids (`torch.LongTensor` of shape `sequence_length`): + The position ids of the image. + image_grid_thw (`List[Tuple]`, *optional*): + The temporal, height and width of feature shape of each image in LLM. + """ + return self.vision_model( + pixel_values=pixel_values, + position_ids=position_ids, + image_grid_thw=image_grid_thw, + ) + + +class PaddleOCRVisionEmbeddings(SiglipVisionEmbeddings): + def __init__(self, config: PaddleOCRVisionConfig): + super().__init__() + self.packing_position_embedding = nn.Embedding(32768, self.embed_dim) + + @staticmethod + def flatten_list(image_grid_thw): + tmp_image_grid_thw = [] + for image_grid in image_grid_thw: + if isinstance(image_grid, list): + tmp_image_grid_thw.extend(image_grid) + else: + tmp_image_grid_thw.append(image_grid) + return tmp_image_grid_thw + + def interpolate_pos_encoding( + self, embeddings: torch.Tensor, height: int, width: int, is_after_patchify: bool = False + ) -> torch.Tensor: + num_positions = self.position_embedding.weight.shape[0] + + patch_pos_embed = self.position_embedding.weight.unsqueeze(0) + + dim = embeddings.shape[-1] + + if is_after_patchify: + new_height = height + new_width = width + else: + new_height = height // self.patch_size + new_width = width // self.patch_size + + sqrt_num_positions = torch_int(num_positions**0.5) + patch_pos_embed = patch_pos_embed.reshape(1, sqrt_num_positions, sqrt_num_positions, dim) + patch_pos_embed = patch_pos_embed.permute(0, 3, 1, 2) + + patch_pos_embed = nn.functional.interpolate( + patch_pos_embed, + size=(new_height, new_width), + mode="bilinear", + align_corners=False, + ) + + patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim) + return patch_pos_embed + + def forward( + self, + pixel_values: torch.FloatTensor, + position_ids: Optional[torch.Tensor] = None, + image_grid_thw: Optional[list[Union[tuple[int, int, int], list[tuple[int, int, int]]]]] = None, + ) -> torch.Tensor: + """ + Args: + pixel_values (`torch.FloatTensor` of shape `(batch_size, sequence_length, image_channels, patch_size, patch_size)`): + The tensors corresponding to the input images. + position_ids (`torch.LongTensor` of shape `sequence_length`): + The position ids of the image. + image_grid_thw (`List[Tuple]`, *optional*): + The temporal, height and width of feature shape of each image in LLM. + """ + if pixel_values.dim() == 5: + assert position_ids is not None + + batch_size, squence_len, channel, height, width = pixel_values.shape + target_dtype = self.patch_embedding.weight.dtype + pixel_values = rearrange(pixel_values, "b l c h w -> (b l) c h w") + patch_embeds = self.patch_embedding(pixel_values.to(dtype=target_dtype)) # shape = [*, width, grid, grid] + embeddings = patch_embeds.flatten(-2).squeeze(-1) + embeddings = rearrange(embeddings, "(b l) d -> b l d", b=batch_size, l=squence_len) + + if image_grid_thw is not None: + flatten_image_grid_thw = self.flatten_list(image_grid_thw) + assert batch_size == 1 + start = 0 + assert sum([np.prod(x) for x in flatten_image_grid_thw]) == embeddings.shape[1], ( + flatten_image_grid_thw, + embeddings.shape, + ) + embeddings = embeddings.squeeze(0) + tmp_embeddings = [] + for image_grid in image_grid_thw: + t, h, w = image_grid + end = start + t * h * w + image_embeddings = embeddings[start:end, :] + position_embedding = ( + self.interpolate_pos_encoding(image_embeddings, h, w, True).squeeze(0).repeat(t, 1) + ) + image_embeddings = image_embeddings + position_embedding + tmp_embeddings.append(image_embeddings) + start = end + embeddings = torch.concat(tmp_embeddings, dim=0).unsqueeze(0) + else: + embeddings = embeddings + self.packing_position_embedding(position_ids) + return embeddings + else: + raise NotImplementedError(str(pixel_values.shape)) + + +class PaddleOCRAttention(SiglipAttention): + def __init__(self, config: PaddleOCRVisionConfig): + super().__init__() + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + rope_emb: Optional[tuple[torch.Tensor, torch.Tensor]] = None, + ) -> tuple[torch.Tensor, Optional[torch.Tensor]]: + batch_size, seq_length, embed_dim = hidden_states.shape + + queries = self.q_proj(hidden_states) + keys = self.k_proj(hidden_states) + values = self.v_proj(hidden_states) + + cos, sin = rope_emb + queries = queries.view(batch_size, seq_length, self.num_heads, self.head_dim) + keys = keys.view(batch_size, seq_length, self.num_heads, self.head_dim) + values = values.view(batch_size, seq_length, self.num_heads, self.head_dim).transpose(1, 2) + queries, keys = apply_rotary_pos_emb_vision(queries, keys, cos, sin) + queries = queries.transpose(1, 2) + keys = keys.transpose(1, 2) + + attention_interface: Callable = eager_attention_forward + if self.config._attn_implementation != "eager": + attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] + + attn_output, attn_weights = attention_interface( + self, + queries, + keys, + values, + attention_mask, + is_causal=self.is_causal, + scaling=self.scale, + dropout=0.0 if not self.training else self.dropout, + ) + + attn_output = attn_output.reshape(batch_size, seq_length, embed_dim).contiguous() + attn_output = self.out_proj(attn_output) + + return attn_output, attn_weights + + +class PaddleOCRMLP(SiglipMLP): + pass + + +class PaddleOCRMultiheadAttentionPoolingHead(SiglipMultiheadAttentionPoolingHead): + pass + + +class PaddleOCREncoderLayer(SiglipEncoderLayer): + def __init__(self, config: PaddleOCRVisionConfig): + super().__init__() + + +class PaddleOCREncoder(SiglipEncoder): + def __init__(self, config: PaddleOCRVisionConfig): + super().__init__() + embed_dim = config.hidden_size + num_heads = config.num_attention_heads + head_dim = embed_dim // num_heads + self.rotary_pos_emb = PaddleOCRRotaryEmbedding(head_dim // 2) + + @staticmethod + def flatten_list(image_grid_thw): + tmp_image_grid_thw = [] + for image_grid in image_grid_thw: + if isinstance(image_grid, list): + tmp_image_grid_thw.extend(image_grid) + else: + tmp_image_grid_thw.append(image_grid) + return tmp_image_grid_thw + + def forward( + self, + inputs_embeds, + attention_mask: Optional[torch.Tensor] = None, + image_grid_thw: Optional[list[Union[tuple[int, int, int], list[tuple[int, int, int]]]]] = None, + ) -> BaseModelOutput: + """ + Args: + attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*): + The attention_mask used in forward function shape [batch_size X sequence_length] if not None. + image_grid_thw (`List[Tuple]`, *optional*): + The temporal, height and width of feature shape of each image in LLM. + """ + device = inputs_embeds.device + hidden_states = inputs_embeds + attention_mask = attention_mask.to(inputs_embeds.dtype) if attention_mask is not None else None + flatten_image_grid_thw = self.flatten_list(image_grid_thw) + assert sum([np.prod(x) for x in flatten_image_grid_thw]) == hidden_states.shape[1], ( + flatten_image_grid_thw, + hidden_states.shape, + ) + + split_hids = [] + split_wids = [] + for t, h, w in flatten_image_grid_thw: + image_pids = torch.arange(t * h * w, device=device) % (h * w) + sample_hids = image_pids // w + sample_wids = image_pids % w + split_hids.append(sample_hids) + split_wids.append(sample_wids) + width_position_ids = torch.concat(split_wids, dim=0) + height_position_ids = torch.concat(split_hids, dim=0) + + pids = torch.stack([height_position_ids, width_position_ids], dim=-1) + max_grid_size = pids.max() + 1 + rope_emb_max_grid = self.rotary_pos_emb(max_grid_size) + rope_emb = rope_emb_max_grid[pids].flatten(1) + rope_emb = rope_emb.repeat(1, 2) + rope_emb = (rope_emb.cos(), rope_emb.sin()) + + for encoder_layer in self.layers: + hidden_states = encoder_layer( + hidden_states, + attention_mask, + rope_emb=rope_emb, + ) + + return BaseModelOutput( + last_hidden_state=hidden_states, + ) + + +class PaddleOCRVisionTransformer(SiglipVisionTransformer): + def __init__(self, config: PaddleOCRVisionConfig): + super().__init__() + + def forward( + self, + pixel_values, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.Tensor] = None, + image_grid_thw: Optional[list[Union[tuple[int, int, int], list[tuple[int, int, int]]]]] = None, + ) -> BaseModelOutputWithPooling: + """ + Args: + pixel_values (`torch.FloatTensor` of shape `(batch_size, sequence_length, patch_size * patch_size * image_channels)`): + The tensors corresponding to the input images. + attention_mask (`torch.Tensor`, *optional*): + The attention_mask used in forward function shape [batch_size X sequence_length] if not None. + position_ids (`torch.LongTensor` of shape `sequence_length`): + The position ids of the image. + image_grid_thw (`List[Tuple]`, *optional*): + The temporal, height and width of feature shape of each image in LLM. + """ + hidden_states = self.embeddings(pixel_values, position_ids=position_ids, image_grid_thw=image_grid_thw) + + encoder_outputs: BaseModelOutput = self.encoder( + inputs_embeds=hidden_states, + attention_mask=attention_mask, + image_grid_thw=image_grid_thw, + ) + + last_hidden_state = encoder_outputs.last_hidden_state + last_hidden_state = self.post_layernorm(last_hidden_state) + + return BaseModelOutputWithPooling( + last_hidden_state=last_hidden_state, + pooler_output=None, + hidden_states=encoder_outputs.hidden_states, + attentions=encoder_outputs.attentions, + ) + + +@dataclass +class PaddleOCRVLCausalLMOutputWithPast(ModelOutput): + loss: Optional[torch.FloatTensor] = None + logits: torch.FloatTensor = None + past_key_values: Optional[list[torch.FloatTensor]] = None + hidden_states: Optional[tuple[torch.FloatTensor]] = None + attentions: Optional[tuple[torch.FloatTensor]] = None + rope_deltas: Optional[torch.LongTensor] = None + + +class PaddleOCRVLForConditionalGeneration(Ernie4_5PreTrainedModel, GenerationMixin): + def __init__(self, config): + super().__init__(config) + self.mlp_AR = Projector(config, config.vision_config) + self.visual = PaddleOCRVisionModel(config.vision_config) + self.model = Ernie4_5Model(config) + self.vocab_size = config.vocab_size + self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False) + self.rope_deltas = None + + self.post_init() + + def get_input_embeddings(self): + return self.model.embed_tokens + + def set_input_embeddings(self, value): + self.model.embed_tokens = value + + def get_output_embeddings(self): + return self.lm_head + + def set_output_embeddings(self, new_embeddings): + self.lm_head = new_embeddings + + def set_decoder(self, decoder): + self.model = decoder + + def get_decoder(self): + return self.model + + def get_rope_index( + self, + input_ids: Optional[torch.LongTensor] = None, + image_grid_thw: Optional[torch.LongTensor] = None, + video_grid_thw: Optional[torch.LongTensor] = None, + second_per_grid_ts: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + ) -> tuple[torch.Tensor, torch.Tensor]: + """ + Calculate the 3D rope index based on image and video's temporal, height and width in LLM. + + Explanation: + Each embedding sequence contains vision embedding and text embedding or just contains text embedding. + + For pure text embedding sequence, the rotary position embedding has no difference with modern LLMs. + Examples: + input_ids: [T T T T T], here T is for text. + temporal position_ids: [0, 1, 2, 3, 4] + height position_ids: [0, 1, 2, 3, 4] + width position_ids: [0, 1, 2, 3, 4] + + For vision and text embedding sequence, we calculate 3D rotary position embedding for vision part + and 1D rotary position embedding for text part. + Examples: + Temporal (Time): 3 patches, representing different segments of the video in time. + Height: 2 patches, dividing each frame vertically. + Width: 2 patches, dividing each frame horizontally. + We also have some important parameters: + fps (Frames Per Second): The video's frame rate, set to 1. This means one frame is processed each second. + tokens_per_second: This is a crucial parameter. It dictates how many "time-steps" or "temporal tokens" are conceptually packed into a one-second interval of the video. In this case, we have 25 tokens per second. So each second of the video will be represented with 25 separate time points. It essentially defines the temporal granularity. + temporal_patch_size: The number of frames that compose one temporal patch. Here, it's 2 frames. + interval: The step size for the temporal position IDs, calculated as tokens_per_second * temporal_patch_size / fps. In this case, 25 * 2 / 1 = 50. This means that each temporal patch will be have a difference of 50 in the temporal position IDs. + input_ids: [V V V V V V V V V V V V T T T T T], here V is for vision. + vision temporal position_ids: [0, 0, 0, 0, 50, 50, 50, 50, 100, 100, 100, 100] + vision height position_ids: [0, 0, 1, 1, 0, 0, 1, 1, 0, 0, 1, 1] + vision width position_ids: [0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1] + text temporal position_ids: [101, 102, 103, 104, 105] + text height position_ids: [101, 102, 103, 104, 105] + text width position_ids: [101, 102, 103, 104, 105] + Here we calculate the text start position_ids as the max vision position_ids plus 1. + + Args: + input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`): + Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide + it. + image_grid_thw (`torch.LongTensor` of shape `(num_images, 3)`, *optional*): + The temporal, height and width of feature shape of each image in LLM. + video_grid_thw (`torch.LongTensor` of shape `(num_videos, 3)`, *optional*): + The temporal, height and width of feature shape of each video in LLM. + second_per_grid_ts (`torch.Tensor` of shape `(num_videos)`, *optional*): + The time interval (in seconds) for each grid along the temporal dimension in the 3D position IDs. + attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*): + Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: + + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + + Returns: + position_ids (`torch.LongTensor` of shape `(3, batch_size, sequence_length)`) + mrope_position_deltas (`torch.Tensor` of shape `(batch_size)`) + """ + spatial_merge_size = self.config.vision_config.spatial_merge_size + image_token_id = self.config.image_token_id + video_token_id = self.config.video_token_id + vision_start_token_id = self.config.vision_start_token_id + mrope_position_deltas = [] + if input_ids is not None and (image_grid_thw is not None or video_grid_thw is not None): + total_input_ids = input_ids + if attention_mask is None: + attention_mask = torch.ones_like(total_input_ids) + position_ids = torch.ones( + 3, + input_ids.shape[0], + input_ids.shape[1], + dtype=input_ids.dtype, + device=input_ids.device, + ) + image_index, video_index = 0, 0 + attention_mask = attention_mask.to(total_input_ids.device) + for i, input_ids in enumerate(total_input_ids): + input_ids = input_ids[attention_mask[i] == 1] + image_nums, video_nums = 0, 0 + vision_start_indices = torch.argwhere(input_ids == vision_start_token_id).squeeze(1) + vision_tokens = input_ids[vision_start_indices + 1] + image_nums = (vision_tokens == image_token_id).sum() + video_nums = (vision_tokens == video_token_id).sum() + input_tokens = input_ids.tolist() + llm_pos_ids_list: list = [] + st = 0 + remain_images, remain_videos = image_nums, video_nums + for _ in range(image_nums + video_nums): + if image_token_id in input_tokens and remain_images > 0: + ed_image = input_tokens.index(image_token_id, st) + else: + ed_image = len(input_tokens) + 1 + if video_token_id in input_tokens and remain_videos > 0: + ed_video = input_tokens.index(video_token_id, st) + else: + ed_video = len(input_tokens) + 1 + if ed_image < ed_video: + t, h, w = ( + image_grid_thw[image_index][0], + image_grid_thw[image_index][1], + image_grid_thw[image_index][2], + ) + second_per_grid_t = 0 + image_index += 1 + remain_images -= 1 + ed = ed_image + + else: + t, h, w = ( + video_grid_thw[video_index][0], + video_grid_thw[video_index][1], + video_grid_thw[video_index][2], + ) + if second_per_grid_ts is not None: + second_per_grid_t = second_per_grid_ts[video_index] + else: + second_per_grid_t = 1.0 + video_index += 1 + remain_videos -= 1 + ed = ed_video + llm_grid_t, llm_grid_h, llm_grid_w = ( + t.item(), + h.item() // spatial_merge_size, + w.item() // spatial_merge_size, + ) + text_len = ed - st + + st_idx = llm_pos_ids_list[-1].max() + 1 if len(llm_pos_ids_list) > 0 else 0 + llm_pos_ids_list.append(torch.arange(text_len).view(1, -1).expand(3, -1) + st_idx) + + if torch.is_tensor(second_per_grid_t): + second_per_grid_t = second_per_grid_t.detach().item() + range_tensor = torch.arange(llm_grid_t).view(-1, 1) + expanded_range = range_tensor.expand(-1, llm_grid_h * llm_grid_w) + + time_tensor = expanded_range * second_per_grid_t * self.config.vision_config.tokens_per_second + + time_tensor_long = time_tensor.long() + t_index = time_tensor_long.flatten() + + h_index = torch.arange(llm_grid_h).view(1, -1, 1).expand(llm_grid_t, -1, llm_grid_w).flatten() + w_index = torch.arange(llm_grid_w).view(1, 1, -1).expand(llm_grid_t, llm_grid_h, -1).flatten() + llm_pos_ids_list.append(torch.stack([t_index, h_index, w_index]) + text_len + st_idx) + st = ed + llm_grid_t * llm_grid_h * llm_grid_w + + if st < len(input_tokens): + st_idx = llm_pos_ids_list[-1].max() + 1 if len(llm_pos_ids_list) > 0 else 0 + text_len = len(input_tokens) - st + llm_pos_ids_list.append(torch.arange(text_len).view(1, -1).expand(3, -1) + st_idx) + + llm_positions = torch.cat(llm_pos_ids_list, dim=1).reshape(3, -1) + position_ids[..., i, attention_mask[i] == 1] = llm_positions.to(position_ids.device) + mrope_position_deltas.append(llm_positions.max() + 1 - len(total_input_ids[i])) + mrope_position_deltas = torch.tensor(mrope_position_deltas, device=input_ids.device).unsqueeze(1) + return position_ids, mrope_position_deltas + else: + if attention_mask is not None: + position_ids = attention_mask.long().cumsum(-1) - 1 + position_ids.masked_fill_(attention_mask == 0, 1) + position_ids = position_ids.unsqueeze(0).expand(3, -1, -1).to(attention_mask.device) + max_position_ids = position_ids.max(0, keepdim=False)[0].max(-1, keepdim=True)[0] + mrope_position_deltas = max_position_ids + 1 - attention_mask.shape[-1] + else: + position_ids = ( + torch.arange(input_ids.shape[1], device=input_ids.device) + .view(1, 1, -1) + .expand(3, input_ids.shape[0], -1) + ) + mrope_position_deltas = torch.zeros( + [input_ids.shape[0], 1], + device=input_ids.device, + dtype=input_ids.dtype, + ) + + return position_ids, mrope_position_deltas + + def forward( + self, + input_ids: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[list[torch.FloatTensor]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + return_dict: Optional[bool] = None, + pixel_values: Optional[torch.Tensor] = None, + pixel_values_videos: Optional[torch.FloatTensor] = None, + image_grid_thw: Optional[torch.LongTensor] = None, + video_grid_thw: Optional[torch.LongTensor] = None, + rope_deltas: Optional[torch.LongTensor] = None, + cache_position: Optional[torch.LongTensor] = None, + second_per_grid_ts: Optional[torch.Tensor] = None, + **kwargs, + ) -> Union[tuple, PaddleOCRVLCausalLMOutputWithPast]: + r""" + Returns: + """ + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + if inputs_embeds is None: + inputs_embeds = self.model.embed_tokens(input_ids) + if pixel_values is not None: + pixel_values = pixel_values.type(self.visual.dtype) + pixel_values = pixel_values.unsqueeze(0) + ppocr_position_ids = [] + image_grid_hws = [] + + for idx, thw in enumerate(image_grid_thw): + thw_tuple = tuple(thw.detach().cpu().numpy().tolist()) + numel = np.prod(thw_tuple) + image_grid_hws.append(thw_tuple) + image_position_ids = torch.arange(numel) % np.prod(thw_tuple[1:]) + ppocr_position_ids.append(image_position_ids) + + ppocr_position_ids = torch.concat(ppocr_position_ids, dim=0).to(pixel_values.device) + + vision_outputs = self.visual( + pixel_values=pixel_values, + image_grid_thw=image_grid_hws, + position_ids=ppocr_position_ids, + ) + image_embeds = vision_outputs.last_hidden_state + + image_embeds = self.mlp_AR(image_embeds, image_grid_thw) + + n_image_tokens = (input_ids == self.config.image_token_id).sum().item() + # image_embeds is a list of tensor, each tensor is a image feature,I want to concat them all into a tensor + image_embeds = torch.cat(image_embeds, dim=0) + n_image_features = image_embeds.shape[0] + if n_image_tokens != n_image_features: + raise ValueError( + f"Image features and image tokens do not match: tokens: {n_image_tokens}, features {n_image_features}" + ) + + mask = input_ids == self.config.image_token_id + mask_unsqueezed = mask.unsqueeze(-1) + mask_expanded = mask_unsqueezed.expand_as(inputs_embeds) + image_mask = mask_expanded.to(inputs_embeds.device) + + image_embeds = image_embeds.to(inputs_embeds.device, inputs_embeds.dtype) + + inputs_embeds = inputs_embeds.masked_scatter(image_mask, image_embeds) + + if attention_mask is not None: + attention_mask = attention_mask.to(inputs_embeds.device) + # position_ids = None + # if we get 4D attention mask we cannot calculate rope deltas anymore. TODO @raushan fixme + if position_ids is None and (attention_mask is None or attention_mask.ndim == 2): + # calculate RoPE index once per generation in the pre-fill stage only + if ( + (cache_position is not None and cache_position[0] == 0) + or self.rope_deltas is None + or (past_key_values is None or past_key_values.get_seq_length() == 0) + ): + position_ids, rope_deltas = self.get_rope_index( + input_ids, + image_grid_thw, + video_grid_thw, + second_per_grid_ts, + attention_mask, + ) + self.rope_deltas = rope_deltas + # then use the prev pre-calculated rope-deltas to get the correct position ids + else: + batch_size, seq_length, _ = inputs_embeds.shape + delta = ( + (cache_position[0] + self.rope_deltas).to(inputs_embeds.device) + if cache_position is not None + else 0 + ) + position_ids = torch.arange(seq_length, device=inputs_embeds.device) + position_ids = position_ids.view(1, -1).expand(batch_size, -1) + if cache_position is not None: # otherwise `deltas` is an int `0` + delta = delta.repeat_interleave(batch_size // delta.shape[0], dim=0) + position_ids = position_ids.add(delta) + position_ids = position_ids.unsqueeze(0).expand(3, -1, -1) + + outputs = self.model( + input_ids=None, + position_ids=position_ids, + attention_mask=attention_mask, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + return_dict=return_dict, + **kwargs, + ) + + hidden_states = outputs[0] + logits = self.lm_head(hidden_states) + + loss = None + if labels is not None: + # Upcast to float if we need to compute the loss to avoid potential precision issues + logits = logits.float() + # Shift so that tokens < n predict n + shift_logits = logits[..., :-1, :].contiguous() + shift_labels = labels[..., 1:].contiguous() + # Flatten the tokens + loss_fct = CrossEntropyLoss() + shift_logits = shift_logits.view(-1, self.config.vocab_size) + shift_labels = shift_labels.view(-1) + # Enable model parallelism + shift_labels = shift_labels.to(shift_logits.device) + loss = loss_fct(shift_logits, shift_labels) + + if not return_dict: + output = (logits,) + outputs[1:] + return (loss,) + output if loss is not None else output + + return PaddleOCRVLCausalLMOutputWithPast( + loss=loss, + logits=logits, + past_key_values=outputs.past_key_values, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + rope_deltas=self.rope_deltas, + ) + + def prepare_inputs_for_generation( + self, + input_ids, + past_key_values=None, + attention_mask=None, + inputs_embeds=None, + cache_position=None, + position_ids=None, + use_cache=True, + pixel_values=None, + pixel_values_videos=None, + image_grid_thw=None, + video_grid_thw=None, + second_per_grid_ts=None, + **kwargs, + ): + # Overwritten -- in specific circumstances we don't want to forward image inputs to the model + + model_inputs = super().prepare_inputs_for_generation( + input_ids, + past_key_values=past_key_values, + attention_mask=attention_mask, + inputs_embeds=inputs_embeds, + cache_position=cache_position, + position_ids=position_ids, + pixel_values=pixel_values, + pixel_values_videos=pixel_values_videos, + image_grid_thw=image_grid_thw, + video_grid_thw=video_grid_thw, + second_per_grid_ts=second_per_grid_ts, + use_cache=use_cache, + **kwargs, + ) + + model_inputs["position_ids"] = None + + if cache_position[0] != 0: + model_inputs["pixel_values"] = None + model_inputs["pixel_values_videos"] = None + + return model_inputs + + def _get_image_nums_and_video_nums( + self, + input_ids: Optional[torch.LongTensor], + ) -> tuple[torch.Tensor, torch.Tensor]: + """ + Get the number of images and videos for each sample to calculate the separation length of the sample tensor. + These parameters are not passed through the processor to avoid unpredictable impacts from interface modifications. + + Args: + input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`): + Indices of input sequence tokens in the vocabulary. + + Returns: + image_nums (`torch.LongTensor` of shape `(batch_size, num_images_sample)`) + video_nums (`torch.LongTensor` of shape `(batch_size, num_videos_sample)`) + """ + image_token_id = self.config.image_token_id + video_token_id = self.config.video_token_id + vision_start_token_id = self.config.vision_start_token_id + + vision_start_mask = input_ids == vision_start_token_id + vision_first_mask = torch.roll(vision_start_mask, shifts=1, dims=1) + image_mask = input_ids == image_token_id + video_mask = input_ids == video_token_id + image_nums = torch.sum(vision_first_mask & image_mask, dim=1) + video_nums = torch.sum(vision_first_mask & video_mask, dim=1) + + return image_nums, video_nums + + def _expand_inputs_for_generation( + self, + expand_size: int = 1, + is_encoder_decoder: bool = False, + input_ids: Optional[torch.LongTensor] = None, + **model_kwargs, + ) -> tuple[torch.LongTensor, dict[str, Any]]: + # Overwritten -- Support for expanding tensors without a batch size dimension + # e.g., pixel_values, image_grid_thw, pixel_values_videos, video_grid_thw, second_per_grid_t + # pixel_values.shape[0] is sum(seqlen_images for samples) + # image_grid_thw.shape[0] is sum(num_images for samples) + + if expand_size == 1: + return input_ids, model_kwargs + + visual_keys = [ + "pixel_values", + "image_grid_thw", + "pixel_values_videos", + "video_grid_thw", + "second_per_grid_ts", + ] + + def _expand_dict_for_generation_visual(dict_to_expand): + image_grid_thw = model_kwargs.get("image_grid_thw", None) + video_grid_thw = model_kwargs.get("video_grid_thw", None) + image_nums, video_nums = self._get_image_nums_and_video_nums(input_ids) + + def _repeat_interleave_samples(x, lengths, repeat_times): + samples = torch.split(x, lengths) + repeat_args = [repeat_times] + [1] * (x.dim() - 1) + result = torch.cat([sample.repeat(*repeat_args) for sample in samples], dim=0) + return result + + for key in dict_to_expand: + if key == "pixel_values": + # split images into samples + samples = torch.split(image_grid_thw, list(image_nums)) + # compute the sequence length of images for each sample + lengths = [torch.prod(sample, dim=1).sum() for sample in samples] + dict_to_expand[key] = _repeat_interleave_samples( + dict_to_expand[key], lengths=lengths, repeat_times=expand_size + ) + elif key == "image_grid_thw": + # get the num of images for each sample + lengths = list(image_nums) + dict_to_expand[key] = _repeat_interleave_samples( + dict_to_expand[key], lengths=lengths, repeat_times=expand_size + ) + elif key == "pixel_values_videos": + samples = torch.split(video_grid_thw, list(video_nums)) + lengths = [torch.prod(sample, dim=1).sum() for sample in samples] + dict_to_expand[key] = _repeat_interleave_samples( + dict_to_expand[key], lengths=lengths, repeat_times=expand_size + ) + elif key == "video_grid_thw": + lengths = list(video_nums) + dict_to_expand[key] = _repeat_interleave_samples( + dict_to_expand[key], lengths=lengths, repeat_times=expand_size + ) + elif key == "second_per_grid_ts": + if not isinstance(dict_to_expand[key], list): + raise TypeError( + f"Expected value for key '{key}' to be a list, but got {type(dict_to_expand[key])} instead." + ) + tensor = torch.tensor(dict_to_expand[key]) + lengths = list(video_nums) + tensor = _repeat_interleave_samples(tensor, lengths=lengths, repeat_times=expand_size) + dict_to_expand[key] = tensor.tolist() + return dict_to_expand + + def _expand_dict_for_generation(dict_to_expand): + for key in dict_to_expand: + if ( + key != "cache_position" + and dict_to_expand[key] is not None + and isinstance(dict_to_expand[key], torch.Tensor) + and key not in visual_keys + ): + dict_to_expand[key] = dict_to_expand[key].repeat_interleave(expand_size, dim=0) + return dict_to_expand + + # input_ids is required for expanding visual inputs + # If input_ids is unavailable, visual inputs will not be used; therefore, there is no need to expand visual inputs. + if input_ids is not None and input_ids.numel() != 0: + model_kwargs = _expand_dict_for_generation_visual(model_kwargs) + + if input_ids is not None: + input_ids = input_ids.repeat_interleave(expand_size, dim=0) + + model_kwargs = _expand_dict_for_generation(model_kwargs) + + if is_encoder_decoder: + if model_kwargs.get("encoder_outputs") is None: + raise ValueError("If `is_encoder_decoder` is True, make sure that `encoder_outputs` is defined.") + model_kwargs["encoder_outputs"] = _expand_dict_for_generation(model_kwargs["encoder_outputs"]) + + return input_ids, model_kwargs + + +__all__ = [ + "PaddleOCRVLForConditionalGeneration", + "PaddleOCRVLConfig", + "PaddleOCRVisionConfig", + "PaddleOCRVLImageProcessor", + "PaddleOCRVLProcessor", +] diff --git a/src/transformers/models/paddleocr_vl/processing_paddleocr_vl.py b/src/transformers/models/paddleocr_vl/processing_paddleocr_vl.py new file mode 100644 index 000000000000..aa2c58981ea0 --- /dev/null +++ b/src/transformers/models/paddleocr_vl/processing_paddleocr_vl.py @@ -0,0 +1,152 @@ +# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨 +# This file was automatically generated from src/transformers/models/paddleocr_vl/modular_paddleocr_vl.py. +# Do NOT edit this file manually as any edits will be overwritten by the generation of +# the file from the modular. If any change should be done, please apply the change to the +# modular_paddleocr_vl.py file directly. One of our CI enforces this. +# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨 +# Copyright 2025 The PaddlePaddle Team and The HuggingFace Inc. team. All rights reserved. +# +# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX +# and OPT implementations in this library. It has been modified from its +# original forms to accommodate minor architectural differences compared +# to GPT-NeoX and OPT used by the Meta AI team that trained the model. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Union + +from ...image_processing_utils import BatchFeature +from ...image_utils import ImageInput +from ...processing_utils import ProcessingKwargs, ProcessorMixin, Unpack +from ...tokenization_utils_base import PreTokenizedInput, TextInput + + +class PaddleOCRVLProcessorKwargs(ProcessingKwargs, total=False): + _defaults = { + "text_kwargs": { + "padding": False, + }, + } + + +class PaddleOCRVLProcessor(ProcessorMixin): + r""" + [`PaddleOCRVLProcessor`] offers all the functionalities of [`PaddleOCRVLImageProcessor`] and [`Qwen2TokenizerFast`]. See the + [`~PaddleOCRVLProcessor.__call__`] and [`~PaddleOCRVLProcessor.decode`] for more information. + Args: + image_processor ([`PaddleOCRVLImageProcessor`], *optional*): + The image processor is a required input. + tokenizer ([`Qwen2TokenizerFast`], *optional*): + The tokenizer is a required input. + chat_template (`str`, *optional*): A Jinja template which will be used to convert lists of messages + in a chat into a tokenizable string. + """ + + attributes = ["image_processor", "tokenizer"] + valid_kwargs = [ + "chat_template", + "image_std", + "min_pixels", + "image_mean", + "merge_size", + "image_processor_type", + "temporal_patch_size", + "patch_size", + "max_pixels", + ] + + image_processor_class = "AutoImageProcessor" + tokenizer_class = "AutoTokenizer" + + def __init__(self, image_processor=None, tokenizer=None, chat_template=None, **kwargs): + self.image_token = "<|IMAGE_PLACEHOLDER|>" if not hasattr(tokenizer, "image_token") else tokenizer.image_token + super().__init__(image_processor, tokenizer, chat_template=chat_template) + + def __call__( + self, + images: ImageInput = None, + text: Union[TextInput, PreTokenizedInput, list[TextInput], list[PreTokenizedInput]] = None, + **kwargs: Unpack[PaddleOCRVLProcessorKwargs], + ) -> BatchFeature: + """ + Main method to prepare for the model one or several sequences(s) and image(s). This method forwards the `text` + and `kwargs` arguments to Qwen2TokenizerFast's [`~Qwen2TokenizerFast.__call__`] if `text` is not `None` to encode + the text. To prepare the vision inputs, this method forwards the `vision_infos` and `kwrags` arguments to + PaddleOCRVLImageProcessor's [`~PaddleOCRVLImageProcessor.__call__`] if `vision_infos` is not `None`. + + Args: + images (`PIL.Image.Image`, `np.ndarray`, `torch.Tensor`, `List[PIL.Image.Image]`, `List[np.ndarray]`, `List[torch.Tensor]`): + The image or batch of images to be prepared. Each image can be a PIL image, NumPy array or PyTorch + tensor. Both channels-first and channels-last formats are supported. + text (`str`, `List[str]`, `List[List[str]]`): + The sequence or batch of sequences to be encoded. Each sequence can be a string or a list of strings + (pretokenized string). If the sequences are provided as list of strings (pretokenized), you must set + `is_split_into_words=True` (to lift the ambiguity with a batch of sequences). + return_tensors (`str` or [`~utils.TensorType`], *optional*): + If set, will return tensors of a particular framework. Acceptable values are: + - `'tf'`: Return TensorFlow `tf.constant` objects. + - `'pt'`: Return PyTorch `torch.Tensor` objects. + - `'np'`: Return NumPy `np.ndarray` objects. + - `'jax'`: Return JAX `jnp.ndarray` objects. + + Returns: + [`BatchFeature`]: A [`BatchFeature`] with the following fields: + + - **input_ids** -- List of token ids to be fed to a model. Returned when `text` is not `None`. + - **attention_mask** -- List of indices specifying which tokens should be attended to by the model (when + `return_attention_mask=True` or if *"attention_mask"* is in `self.model_input_names` and if `text` is not + `None`). + - **pixel_values** -- Pixel values to be fed to a model. Returned when `images` is not `None`. + - **image_grid_thw** -- List of image 3D grid in LLM. Returned when `images` is not `None`. + """ + output_kwargs = self._merge_kwargs( + PaddleOCRVLProcessorKwargs, + tokenizer_init_kwargs=self.tokenizer.init_kwargs, + **kwargs, + ) + + if images is not None: + image_inputs = self.image_processor(images=images, return_tensors="pt") + image_inputs["pixel_values"] = image_inputs["pixel_values"] + image_grid_thw = image_inputs["image_grid_thw"] + + else: + image_inputs = {} + image_grid_thw = None + + if not isinstance(text, list): + text = [text] + + if image_grid_thw is not None: + index = 0 + for i in range(len(text)): + while self.image_token in text[i]: + text[i] = text[i].replace( + self.image_token, + "<|placeholder|>" + * ( + image_grid_thw[index].prod() + // self.image_processor.merge_size + // self.image_processor.merge_size + ), + 1, + ) + index += 1 + text[i] = text[i].replace("<|placeholder|>", self.image_token) + + text_inputs = self.tokenizer(text, **output_kwargs["text_kwargs"]) + + return BatchFeature(data={**text_inputs, **image_inputs}) + + +__all__ = ["PaddleOCRVLProcessor"] From 0fe7b78ac0488096207e1843482cdacf3e311d63 Mon Sep 17 00:00:00 2001 From: zhangyue66 Date: Fri, 21 Nov 2025 13:20:47 +0800 Subject: [PATCH 02/19] refactor --- src/transformers/modeling_utils.py | 1 + .../configuration_paddleocr_vl.py | 142 +- .../image_processing_paddleocr_vl.py | 4 +- .../image_processing_paddleocr_vl_fast.py | 202 +++ .../paddleocr_vl/modeling_paddleocr_vl.py | 1226 ++++++------- .../paddleocr_vl/modular_paddleocr_vl.py | 1537 ++++++----------- .../paddleocr_vl/processing_paddleocr_vl.py | 27 +- 7 files changed, 1446 insertions(+), 1693 deletions(-) create mode 100644 src/transformers/models/paddleocr_vl/image_processing_paddleocr_vl_fast.py diff --git a/src/transformers/modeling_utils.py b/src/transformers/modeling_utils.py index 960373ba102a..a23366208007 100644 --- a/src/transformers/modeling_utils.py +++ b/src/transformers/modeling_utils.py @@ -203,6 +203,7 @@ def is_local_dist_rank_0(): "qwen2_5_vl", "videollava", "vipllava", + "paddleocrvl", ] diff --git a/src/transformers/models/paddleocr_vl/configuration_paddleocr_vl.py b/src/transformers/models/paddleocr_vl/configuration_paddleocr_vl.py index dc72f06943c3..80ba67984b19 100644 --- a/src/transformers/models/paddleocr_vl/configuration_paddleocr_vl.py +++ b/src/transformers/models/paddleocr_vl/configuration_paddleocr_vl.py @@ -23,12 +23,14 @@ # See the License for the specific language governing permissions and # limitations under the License. -from ...configuration_utils import PretrainedConfig -from ...modeling_rope_utils import rope_config_validation +from typing import Optional +from ...configuration_utils import PreTrainedConfig, PretrainedConfig +from ...modeling_rope_utils import RopeParameters, rope_config_validation -class PaddleOCRVisionConfig(PretrainedConfig): - model_type = "paddleocr_vl" + +class PaddleOCRVLVisionConfig(PretrainedConfig): + model_type = "paddleocr_vl_vision" base_config_key = "vision_config" def __init__( @@ -65,7 +67,7 @@ def __init__( self.tokens_per_second = tokens_per_second -class PaddleOCRVLConfig(PretrainedConfig): +class PaddleOCRVLTextConfig(PretrainedConfig): """ Configuration class. @@ -73,11 +75,8 @@ class PaddleOCRVLConfig(PretrainedConfig): It inherits from PretrainedConfig and can be used to control model outputs. """ - model_type = "paddleocr_vl" - keys_to_ignore_at_inference = ["past_key_values"] - sub_configs = {"vision_config": PaddleOCRVisionConfig} + model_type = "paddleocr_vl_text" - # Default tensor parallel plan for base model `Qwen3` base_model_tp_plan = { "layers.*.self_attn.q_proj": "colwise", "layers.*.self_attn.k_proj": "colwise", @@ -101,9 +100,6 @@ def __init__( max_position_embeddings=32768, num_hidden_layers=2, num_attention_heads=2, - image_token_id=101304, - video_token_id=101305, - vision_start_token_id=101306, rms_norm_eps=1e-6, use_cache=False, use_flash_attention=False, @@ -122,8 +118,7 @@ def __init__( num_key_value_heads=None, max_sequence_length=None, tie_word_embeddings=False, - vision_config=None, - rope_scaling=None, + rope_parameters: Optional[RopeParameters | dict[str, RopeParameters]] = None, **kwargs, ): """ @@ -161,10 +156,6 @@ def __init__( eos_token_id=eos_token_id, **kwargs, ) - if isinstance(vision_config, dict): - self.vision_config = self.sub_configs["vision_config"](**vision_config) - elif vision_config is None: - self.vision_config = self.sub_configs["vision_config"]() self.vocab_size = vocab_size self.hidden_size = hidden_size self.intermediate_size = intermediate_size @@ -177,9 +168,6 @@ def __init__( self.pad_token_id = pad_token_id self.bos_token_id = bos_token_id self.eos_token_id = eos_token_id - self.image_token_id = image_token_id - self.video_token_id = video_token_id - self.vision_start_token_id = vision_start_token_id self.head_dim = head_dim self.hidden_act = hidden_act self.sliding_window = None @@ -193,13 +181,113 @@ def __init__( self.compression_ratio = compression_ratio self.num_key_value_heads = num_key_value_heads self.max_sequence_length = max_sequence_length - self.rope_scaling = rope_scaling - if self.rope_scaling is not None and "type" in self.rope_scaling: - if self.rope_scaling["type"] == "mrope": - self.rope_scaling["type"] = "default" - self.rope_scaling["rope_type"] = self.rope_scaling["type"] + # Try to set `rope_scaling` if available, otherwise use `rope_parameters` + rope_scaling = kwargs.pop("rope_scaling", None) + self.rope_parameters = rope_scaling or rope_parameters + if self.rope_parameters is not None and self.rope_parameters["rope_type"] == "mrope": + self.rope_parameters["rope_type"] = "default" rope_config_validation(self, ignore_keys={"mrope_section"}) super().__init__(tie_word_embeddings=tie_word_embeddings, **kwargs) -__all__ = ["PaddleOCRVLConfig", "PaddleOCRVisionConfig"] +class PaddleOCRVLConfig(PreTrainedConfig): + r""" + This is the configuration class to store the configuration of a [`PaddleOCRVLModel`]. It is used to instantiate a + Qwen2-VL model according to the specified arguments, defining the model architecture. Instantiating a configuration + with the defaults will yield a similar configuration to that of + Qwen2-VL-7B-Instruct [Qwen/Qwen2-VL-7B-Instruct](https://huggingface.co/Qwen/Qwen2-VL-7B-Instruct). + + Configuration objects inherit from [`PreTrainedConfig`] and can be used to control the model outputs. Read the + documentation from [`PreTrainedConfig`] for more information. + + + Args: + text_config (`Union[PreTrainedConfig, dict]`, *optional*, defaults to `PaddleOCRVLTextConfig`): + The config object or dictionary of the text backbone. + vision_config (`Union[PreTrainedConfig, dict]`, *optional*, defaults to `PaddleOCRVLVisionConfig`): + The config object or dictionary of the vision backbone. + image_token_id (`int`, *optional*, defaults to 151655): + The image token index to encode the image prompt. + video_token_id (`int`, *optional*, defaults to 151656): + The video token index to encode the image prompt. + vision_start_token_id (`int`, *optional*, defaults to 151652): + The token index to denote start of vision input. + vision_end_token_id (`int`, *optional*, defaults to 151653): + The token index to denote end of vision input. + + ```python + >>> from transformers import PaddleOCRVLForConditionalGeneration, PaddleOCRVLConfig + + >>> # Initializing a PaddleOCRVL style configuration + >>> configuration = PaddleOCRVLConfig() + + >>> # Initializing a model from the Qwen2-VL-7B style configuration + >>> model = PaddleOCRVLForConditionalGeneration(configuration) + + >>> # Accessing the model configuration + >>> configuration = model.config + ```""" + + model_type = "paddleocr_vl" + sub_configs = {"vision_config": PaddleOCRVLVisionConfig, "text_config": PaddleOCRVLTextConfig} + keys_to_ignore_at_inference = ["past_key_values"] + + def __init__( + self, + text_config=None, + vision_config=None, + image_token_id=151655, + video_token_id=151656, + vision_start_token_id=151652, + vision_end_token_id=151653, + **kwargs, + ): + # We need to init super() here so that it does not reset values + # that are in text config to the BaseClass defaults. The Base + # config has many text related defaults and not all defaults are same as for `PaddleOCRVLTextConfig` + super().__init__(**kwargs) + + if isinstance(vision_config, dict): + self.vision_config = self.sub_configs["vision_config"](**vision_config) + elif vision_config is None: + self.vision_config = self.sub_configs["vision_config"]() + + if isinstance(text_config, dict): + self.text_config = self.sub_configs["text_config"](**text_config) + elif text_config is None: + # For BC use all kwargs to init `TextConfig` + self.text_config = self.sub_configs["text_config"](**kwargs) + + self.image_token_id = image_token_id + self.video_token_id = video_token_id + self.vision_start_token_id = vision_start_token_id + self.vision_end_token_id = vision_end_token_id + + # Attention implementation to use. It sets it recursively on sub-configs so we call it again in the end + self._attn_implementation = kwargs.pop("attn_implementation", None) + + def __setattr__(self, key, value): + if ( + (text_config := super().__getattribute__("__dict__").get("text_config")) is not None + and key not in ["_name_or_path", "model_type", "dtype", "_attn_implementation_internal"] + and key in text_config.__dict__ + ): + setattr(text_config, key, value) + else: + super().__setattr__(key, value) + + def __getattribute__(self, key): + if "text_config" in super().__getattribute__("__dict__") and key not in [ + "_name_or_path", + "model_type", + "dtype", + "_attn_implementation_internal", + ]: + text_config = super().__getattribute__("text_config") + if key in text_config.__dict__: + return getattr(text_config, key) + + return super().__getattribute__(key) + + +__all__ = ["PaddleOCRVLConfig", "PaddleOCRVLTextConfig"] diff --git a/src/transformers/models/paddleocr_vl/image_processing_paddleocr_vl.py b/src/transformers/models/paddleocr_vl/image_processing_paddleocr_vl.py index d2b7dad85fa1..8cf26f0479e0 100644 --- a/src/transformers/models/paddleocr_vl/image_processing_paddleocr_vl.py +++ b/src/transformers/models/paddleocr_vl/image_processing_paddleocr_vl.py @@ -174,8 +174,8 @@ def __init__( size["shortest_edge"] = min_pixels if max_pixels is not None: size["longest_edge"] = max_pixels - self.min_pixels = min_pixels - self.max_pixels = max_pixels + self.min_pixels = size["shortest_edge"] + self.max_pixels = size["longest_edge"] self.size = size self.do_resize = do_resize self.resample = resample diff --git a/src/transformers/models/paddleocr_vl/image_processing_paddleocr_vl_fast.py b/src/transformers/models/paddleocr_vl/image_processing_paddleocr_vl_fast.py new file mode 100644 index 000000000000..1a7cb510eea5 --- /dev/null +++ b/src/transformers/models/paddleocr_vl/image_processing_paddleocr_vl_fast.py @@ -0,0 +1,202 @@ +# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨 +# This file was automatically generated from src/transformers/models/paddleocr_vl/modular_paddleocr_vl.py. +# Do NOT edit this file manually as any edits will be overwritten by the generation of +# the file from the modular. If any change should be done, please apply the change to the +# modular_paddleocr_vl.py file directly. One of our CI enforces this. +# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨 +# Copyright 2025 The PaddlePaddle Team and The HuggingFace Inc. team. All rights reserved. +# +# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX +# and OPT implementations in this library. It has been modified from its +# original forms to accommodate minor architectural differences compared +# to GPT-NeoX and OPT used by the Meta AI team that trained the model. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import math +from typing import Optional, Union + +import torch +import torch.nn.functional as F + +from ...image_processing_utils import BatchFeature +from ...image_processing_utils_fast import BaseImageProcessorFast, group_images_by_shape, reorder_images +from ...image_utils import OPENAI_CLIP_MEAN, OPENAI_CLIP_STD, PILImageResampling, SizeDict +from ...utils import TensorType + + +def smart_resize( + height: int, + width: int, + factor: int = 28, + min_pixels: int = 28 * 28 * 130, + max_pixels: int = 28 * 28 * 1280, +): + if height < factor: + width = round((width * factor) / height) + height = factor + + if width < factor: + height = round((height * factor) / width) + width = factor + + if max(height, width) / min(height, width) > 200: + raise ValueError( + f"absolute aspect ratio must be smaller than 200, got {max(height, width) / min(height, width)}" + ) + h_bar = round(height / factor) * factor + w_bar = round(width / factor) * factor + if h_bar * w_bar > max_pixels: + beta = math.sqrt((height * width) / max_pixels) + h_bar = math.floor(height / beta / factor) * factor + w_bar = math.floor(width / beta / factor) * factor + elif h_bar * w_bar < min_pixels: + beta = math.sqrt(min_pixels / (height * width)) + h_bar = math.ceil(height * beta / factor) * factor + w_bar = math.ceil(width * beta / factor) * factor + return h_bar, w_bar + + +class PaddleOCRVLImageProcessorFast(BaseImageProcessorFast): + def __init__( + self, + do_resize: bool = True, + size: Optional[dict[str, int]] = None, + resample: PILImageResampling = PILImageResampling.BICUBIC, + do_rescale: bool = True, + rescale_factor: Union[int, float] = 1 / 255, + do_normalize: bool = True, + image_mean: Optional[Union[float, list[float]]] = None, + image_std: Optional[Union[float, list[float]]] = None, + do_convert_rgb: bool = True, + min_pixels: int = 28 * 28 * 130, + max_pixels: int = 28 * 28 * 1280, + patch_size: int = 14, + temporal_patch_size: int = 1, + merge_size: int = 2, + **kwargs, + ) -> None: + super().__init__(**kwargs) + if size is not None and ("shortest_edge" not in size or "longest_edge" not in size): + raise ValueError("size must contain 'shortest_edge' and 'longest_edge' keys.") + else: + size = {"shortest_edge": 56 * 56, "longest_edge": 28 * 28 * 1280} + # backward compatibility: override size with min_pixels and max_pixels if they are provided + if min_pixels is not None: + size["shortest_edge"] = min_pixels + if max_pixels is not None: + size["longest_edge"] = max_pixels + self.min_pixels = size["shortest_edge"] + self.max_pixels = size["longest_edge"] + self.size = size + self.do_resize = do_resize + self.resample = resample + self.do_rescale = do_rescale + self.rescale_factor = rescale_factor + self.do_normalize = do_normalize + self.image_mean = image_mean if image_mean is not None else OPENAI_CLIP_MEAN + self.image_std = image_std if image_std is not None else OPENAI_CLIP_STD + self.patch_size = patch_size + self.temporal_patch_size = temporal_patch_size + self.merge_size = merge_size + self.do_convert_rgb = do_convert_rgb + + def _preprocess( + self, + images: list["torch.Tensor"], + do_resize: bool, + size: SizeDict, + interpolation: Optional["F.InterpolationMode"], + do_rescale: bool, + rescale_factor: float, + do_normalize: bool, + image_mean: Optional[Union[float, list[float]]], + image_std: Optional[Union[float, list[float]]], + disable_grouping: Optional[bool], + return_tensors: Optional[Union[str, TensorType]], + **kwargs, + ): + grouped_images, grouped_images_index = group_images_by_shape(images, disable_grouping=disable_grouping) + resized_images_grouped = {} + for shape, stacked_images in grouped_images.items(): + height, width = stacked_images.shape[-2:] + if do_resize: + resized_height, resized_width = smart_resize( + height, + width, + factor=self.patch_size * self.merge_size, + min_pixels=self.min_pixels, + max_pixels=self.max_pixels, + ) + stacked_images = self.resize( + image=stacked_images, + size=SizeDict(height=resized_height, width=resized_width), + interpolation=interpolation, + ) + resized_images_grouped[shape] = stacked_images + resized_images = reorder_images(resized_images_grouped, grouped_images_index) + + # Group images by size for further processing + # Needed in case do_resize is False, or resize returns images with different sizes + grouped_images, grouped_images_index = group_images_by_shape(resized_images, disable_grouping=disable_grouping) + processed_images_grouped = {} + processed_grids = {} + for shape, stacked_images in grouped_images.items(): + resized_height, resized_width = stacked_images.shape[-2:] + # Fused rescale and normalize + patches = self.rescale_and_normalize( + stacked_images, do_rescale, rescale_factor, do_normalize, image_mean, image_std + ) + + if patches.ndim == 4: + # add a temporal dimension if we have images + patches = patches.unsqueeze(1) + if patches.shape[1] % self.temporal_patch_size != 0: + repeats = patches[:, -1:].repeat(1, self.temporal_patch_size - 1, 1, 1, 1) + patches = torch.cat([patches, repeats], dim=1) + + batch_size, grid_t, channel = patches.shape[:3] + grid_t = grid_t // self.temporal_patch_size + grid_h, grid_w = ( + resized_height // self.patch_size, + resized_width // self.patch_size, + ) + patches = patches.view( + batch_size, + grid_t, + self.temporal_patch_size, + channel, + grid_h, + self.patch_size, + grid_w, + self.patch_size, + ) + patches = patches.permute(0, 1, 4, 6, 3, 2, 5, 7) + flatten_patches = patches.reshape( + batch_size, grid_t * grid_h * grid_w, channel, self.patch_size, self.patch_size + ) + + processed_images_grouped[shape] = flatten_patches + processed_grids[shape] = [[grid_t, grid_h, grid_w]] * batch_size + + processed_images = reorder_images(processed_images_grouped, grouped_images_index) + processed_grids = reorder_images(processed_grids, grouped_images_index) + pixel_values = torch.cat(processed_images, dim=0) + image_grid_thw = torch.tensor(processed_grids) + + return BatchFeature( + data={"pixel_values": pixel_values, "image_grid_thw": image_grid_thw}, tensor_type=return_tensors + ) + + +__all__ = ["PaddleOCRVLImageProcessorFast"] diff --git a/src/transformers/models/paddleocr_vl/modeling_paddleocr_vl.py b/src/transformers/models/paddleocr_vl/modeling_paddleocr_vl.py index 1585295c0c33..88a80d678a86 100644 --- a/src/transformers/models/paddleocr_vl/modeling_paddleocr_vl.py +++ b/src/transformers/models/paddleocr_vl/modeling_paddleocr_vl.py @@ -23,41 +23,37 @@ # See the License for the specific language governing permissions and # limitations under the License. -import math from collections.abc import Callable from dataclasses import dataclass from typing import Any, Optional, Union -import numpy as np import torch -from einops import rearrange from torch import nn -from torch.nn import CrossEntropyLoss -from torch.nn.init import _calculate_fan_in_and_fan_out from ...activations import ACT2FN, GELUActivation from ...cache_utils import Cache, DynamicCache from ...generation import GenerationMixin from ...integrations import use_kernel_forward_from_hub -from ...masking_utils import create_causal_mask +from ...masking_utils import create_bidirectional_mask, create_causal_mask +from ...modeling_flash_attention_utils import FlashAttentionKwargs from ...modeling_layers import GradientCheckpointingLayer from ...modeling_outputs import BaseModelOutput, BaseModelOutputWithPast, BaseModelOutputWithPooling, ModelOutput -from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update +from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel from ...processing_utils import Unpack -from ...utils import TransformersKwargs, auto_docstring, logging, torch_int +from ...utils import TransformersKwargs, auto_docstring, can_return_tuple, logging, torch_int from ...utils.generic import check_model_inputs -from .configuration_paddleocr_vl import PaddleOCRVisionConfig, PaddleOCRVLConfig +from .configuration_paddleocr_vl import PaddleOCRVLConfig, PaddleOCRVLTextConfig, PaddleOCRVLVisionConfig logger = logging.get_logger(__name__) -class Projector(nn.Module): - def __init__(self, text_config: PaddleOCRVLConfig, vision_config: PaddleOCRVisionConfig): +class PaddleOCRProjector(nn.Module): + def __init__(self, config: PaddleOCRVLConfig): super().__init__() - self.text_config = text_config - self.vision_config = vision_config + self.text_config = config.text_config + self.vision_config = config.vision_config self.merge_kernel_size = (2, 2) self.hidden_size = self.vision_config.hidden_size * self.merge_kernel_size[0] * self.merge_kernel_size[1] @@ -67,26 +63,32 @@ def __init__(self, text_config: PaddleOCRVLConfig, vision_config: PaddleOCRVisio self.act = GELUActivation() self.linear_2 = nn.Linear(self.hidden_size, self.text_config.hidden_size, bias=True) - def forward(self, image_features: torch.Tensor, image_grid_thw: list[tuple[int, int, int]]) -> torch.Tensor: + def forward(self, image_features: torch.Tensor, image_grid_thw: torch.Tensor) -> torch.Tensor: + image_features_chunks = image_features.split(image_grid_thw.prod(dim=1).tolist(), dim=0) m1, m2 = self.merge_kernel_size processed_features = [] - for image_feature, image_grid in zip(image_features, image_grid_thw): + for image_feature, image_grid in zip(image_features_chunks, image_grid_thw): image_feature = self.pre_norm(image_feature) t, h, w = image_grid + d = image_feature.shape[-1] + h_block = h // m1 + w_block = w // m2 + + image_feature = image_feature.reshape(t, h_block, m1, w_block, m2, d) + image_feature = image_feature.reshape((t * h_block * w_block), (m1 * m2 * d)) - image_feature = rearrange( - image_feature, "(t h p1 w p2) d -> (t h w) (p1 p2 d)", t=t, h=h // m1, p1=m1, w=w // m2, p2=m2 - ) hidden_states = self.linear_1(image_feature) hidden_states = self.act(hidden_states) hidden_states = self.linear_2(hidden_states) processed_features.append(hidden_states) - return processed_features + return torch.cat(processed_features, dim=0) -class PaddleOCRRotaryEmbedding(nn.Module): +class PaddleOCRVisionRotaryEmbedding(nn.Module): + inv_freq: torch.Tensor # fix linting for `register_buffer` + def __init__(self, dim: int, theta: float = 10000.0) -> None: super().__init__() inv_freq = 1.0 / (theta ** (torch.arange(0, dim, 2, dtype=torch.float) / dim)) @@ -98,7 +100,7 @@ def forward(self, seqlen: int) -> torch.Tensor: return freqs -class Ernie4_5RotaryEmbedding(nn.Module): +class PaddleOCRRotaryEmbedding(nn.Module): inv_freq: torch.Tensor # fix linting for `register_buffer` def __init__(self, config: PaddleOCRVLConfig, device=None): @@ -147,29 +149,25 @@ def compute_default_rope_parameters( ) return inv_freq, attention_factor - @torch.no_grad() - @dynamic_rope_update # power user: used with advanced RoPE types (e.g. dynamic rope) + # Ignore copy def forward(self, x, position_ids): + # In contrast to other models, PaddleOCR has different position ids for the grids + # So we expand the inv_freq to shape (3, ...) inv_freq_expanded = self.inv_freq[None, None, :, None].float().expand(3, position_ids.shape[1], -1, 1) position_ids_expanded = position_ids[:, :, None, :].float() # shape (3, bs, 1, positions) - # Force float32 (see https://github.com/huggingface/transformers/pull/29285) - device_type = x.device.type - device_type = device_type if isinstance(device_type, str) and device_type != "mps" else "cpu" - with torch.autocast(device_type=device_type, enabled=False): + + device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu" + with torch.autocast(device_type=device_type, enabled=False): # Force float32 freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(2, 3) emb = torch.cat((freqs, freqs), dim=-1) - cos = emb.cos() - sin = emb.sin() - - # Advanced RoPE types (e.g. yarn) apply a post-processing scaling factor, equivalent to scaling attention - cos = cos * self.attention_scaling - sin = sin * self.attention_scaling + cos = emb.cos() * self.attention_scaling + sin = emb.sin() * self.attention_scaling return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype) -class Ernie4_5MLP(nn.Module): - def __init__(self, config: PaddleOCRVLConfig): +class PaddleOCRMLP(nn.Module): + def __init__(self, config: PaddleOCRVLTextConfig): super().__init__() self.config = config self.hidden_size = config.hidden_size @@ -185,6 +183,51 @@ def forward(self, x): return down_proj +def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor: + """ + This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch, + num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim) + """ + batch, num_key_value_heads, slen, head_dim = hidden_states.shape + if n_rep == 1: + return hidden_states + hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim) + return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim) + + +def eager_attention_forward( + module: nn.Module, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + attention_mask: Optional[torch.Tensor], + scaling: float, + dropout: float = 0.0, + **kwargs, +): + key_states = repeat_kv(key, module.num_key_value_groups) + value_states = repeat_kv(value, module.num_key_value_groups) + + attn_weights = torch.matmul(query, key_states.transpose(2, 3)) * scaling + if attention_mask is not None: + causal_mask = attention_mask[:, :, :, : key_states.shape[-2]] + attn_weights = attn_weights + causal_mask + + attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype) + attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training) + attn_output = torch.matmul(attn_weights, value_states) + attn_output = attn_output.transpose(1, 2).contiguous() + + return attn_output, attn_weights + + +def rotate_half(x): + """Rotates half the hidden dims of the input.""" + x1 = x[..., : x.shape[-1] // 2] + x2 = x[..., x.shape[-1] // 2 :] + return torch.cat((-x2, x1), dim=-1) + + def apply_multimodal_rotary_pos_emb(q, k, cos, sin, mrope_section, unsqueeze_dim=1): """Applies Rotary Position Embedding with Multimodal Sections to the query and key tensors (https://qwenlm.github.io/blog/qwen2-vl/). @@ -230,101 +273,70 @@ def apply_multimodal_rotary_pos_emb(q, k, cos, sin, mrope_section, unsqueeze_dim return q_embed, k_embed -def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor: +class PaddleOCRAttention(nn.Module): """ - This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch, - num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim) + Multi-headed attention from 'Attention Is All You Need' paper. Modified to use sliding window attention: Longformer + and "Generating Long Sequences with Sparse Transformers". """ - batch, num_key_value_heads, slen, head_dim = hidden_states.shape - if n_rep == 1: - return hidden_states - hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim) - return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim) - - -def rotate_half(x): - """Rotates half the hidden dims of the input.""" - x1 = x[..., : x.shape[-1] // 2] - x2 = x[..., x.shape[-1] // 2 :] - return torch.cat((-x2, x1), dim=-1) - - -def eager_attention_forward( - module: nn.Module, - query: torch.Tensor, - key: torch.Tensor, - value: torch.Tensor, - attention_mask: Optional[torch.Tensor], - scaling: float, - dropout: float = 0.0, - **kwargs: Unpack[TransformersKwargs], -): - key_states = repeat_kv(key, module.num_key_value_groups) - value_states = repeat_kv(value, module.num_key_value_groups) - - attn_weights = torch.matmul(query, key_states.transpose(2, 3)) * scaling - if attention_mask is not None: - causal_mask = attention_mask[:, :, :, : key_states.shape[-2]] - attn_weights = attn_weights + causal_mask - - attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype) - attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training) - attn_output = torch.matmul(attn_weights, value_states) - attn_output = attn_output.transpose(1, 2).contiguous() - - return attn_output, attn_weights - -class Ernie4_5Attention(nn.Module): - """Multi-headed attention from 'Attention Is All You Need' paper""" - - def __init__(self, config: PaddleOCRVLConfig, layer_idx: int): + def __init__(self, config: PaddleOCRVLConfig, layer_idx: Optional[int] = None): super().__init__() self.config = config self.layer_idx = layer_idx - self.head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads) - self.num_key_value_groups = config.num_attention_heads // config.num_key_value_heads - self.scaling = self.head_dim**-0.5 + if layer_idx is None: + logger.warning_once( + f"Instantiating {self.__class__.__name__} without passing `layer_idx` is not recommended and will " + "to errors during the forward call, if caching is used. Please make sure to provide a `layer_idx` " + "when creating this class." + ) - self.attention_dropout = 0.0 + self.hidden_size = config.hidden_size + self.num_heads = config.num_attention_heads + self.head_dim = getattr(config, "head_dim", self.hidden_size // self.num_heads) + self.num_key_value_heads = config.num_key_value_heads + self.num_key_value_groups = self.num_heads // self.num_key_value_heads self.is_causal = True - self.q_proj = nn.Linear(config.hidden_size, config.num_attention_heads * self.head_dim, bias=config.use_bias) - self.k_proj = nn.Linear(config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.use_bias) - self.v_proj = nn.Linear(config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.use_bias) - self.o_proj = nn.Linear(config.num_attention_heads * self.head_dim, config.hidden_size, bias=config.use_bias) - self.rope_scaling = config.rope_scaling + self.attention_dropout = 0.0 + self.rope_parameters = config.rope_parameters + self.scaling = self.head_dim**-0.5 + self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=config.use_bias) + self.k_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=config.use_bias) + self.v_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=config.use_bias) + self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=config.use_bias) + self.layer_type = config.layer_types[layer_idx] if hasattr(config, "layer_types") else None + self.sliding_window = config.sliding_window if self.layer_type == "sliding_attention" else None def forward( self, hidden_states: torch.Tensor, - position_embeddings: tuple[torch.Tensor, torch.Tensor], - attention_mask: Optional[torch.Tensor], - past_key_value: Optional[Cache] = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[Cache] = None, + output_attentions: bool = False, + use_cache: bool = False, cache_position: Optional[torch.LongTensor] = None, - **kwargs: Unpack[TransformersKwargs], - ) -> tuple[torch.Tensor, torch.Tensor]: - input_shape = hidden_states.shape[:-1] - hidden_shape = (*input_shape, -1, self.head_dim) + position_embeddings: Optional[tuple[torch.Tensor, torch.Tensor]] = None, + **kwargs: Unpack[FlashAttentionKwargs], + ) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[tuple[torch.Tensor]]]: + bsz, q_len, _ = hidden_states.size() - query_states = self.q_proj(hidden_states).view(hidden_shape).transpose(1, 2) - key_states = self.k_proj(hidden_states).view(hidden_shape).transpose(1, 2) - value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2) + query_states = self.q_proj(hidden_states) + key_states = self.k_proj(hidden_states) + value_states = self.v_proj(hidden_states) - if "position_ids" in kwargs and kwargs["position_ids"] is not None: - position_ids = kwargs["position_ids"] - if position_ids.dim() == 3 and position_ids.shape[0] > 1: - kwargs["position_ids"] = position_ids[0:1] + query_states = query_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2) + key_states = key_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2) + value_states = value_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2) cos, sin = position_embeddings query_states, key_states = apply_multimodal_rotary_pos_emb( - query_states, key_states, cos, sin, self.rope_scaling["mrope_section"] + query_states, key_states, cos, sin, self.config.rope_parameters["mrope_section"] ) - if past_key_value is not None: - # sin and cos are specific to RoPE models; cache_position needed for the static cache - cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} - key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) + if past_key_values is not None: + cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} # Specific to RoPE models + key_states, value_states = past_key_values.update(key_states, value_states, self.layer_idx, cache_kwargs) attention_interface: Callable = eager_attention_forward if self.config._attn_implementation != "eager": @@ -338,19 +350,21 @@ def forward( attention_mask, dropout=0.0 if not self.training else self.attention_dropout, scaling=self.scaling, + sliding_window=self.sliding_window, + position_ids=position_ids, # pass positions for FA2 **kwargs, ) - attn_output = attn_output.reshape(*input_shape, -1).contiguous() + attn_output = attn_output.reshape(bsz, q_len, -1).contiguous() attn_output = self.o_proj(attn_output) return attn_output, attn_weights @use_kernel_forward_from_hub("RMSNorm") -class Ernie4_5RMSNorm(nn.Module): +class PaddleOCRRMSNorm(nn.Module): def __init__(self, hidden_size, eps=1e-6): """ - Ernie4_5RMSNorm is equivalent to T5LayerNorm + PaddleOCRRMSNorm is equivalent to T5LayerNorm """ super().__init__() self.weight = nn.Parameter(torch.ones(hidden_size)) @@ -367,16 +381,16 @@ def extra_repr(self): return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}" -class Ernie4_5DecoderLayer(GradientCheckpointingLayer): - def __init__(self, config: PaddleOCRVLConfig, layer_idx: int): +class PaddleOCRDecoderLayer(GradientCheckpointingLayer): + def __init__(self, config: PaddleOCRVLTextConfig, layer_idx: int): super().__init__() self.hidden_size = config.hidden_size - self.self_attn = Ernie4_5Attention(config=config, layer_idx=layer_idx) + self.self_attn = PaddleOCRAttention(config=config, layer_idx=layer_idx) - self.mlp = Ernie4_5MLP(config) - self.input_layernorm = Ernie4_5RMSNorm(config.hidden_size, eps=config.rms_norm_eps) - self.post_attention_layernorm = Ernie4_5RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.mlp = PaddleOCRMLP(config) + self.input_layernorm = PaddleOCRRMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.post_attention_layernorm = PaddleOCRRMSNorm(config.hidden_size, eps=config.rms_norm_eps) def forward( self, @@ -413,11 +427,11 @@ def forward( @auto_docstring -class Ernie4_5PreTrainedModel(PreTrainedModel): +class PaddleOCRVLPreTrainedModel(PreTrainedModel): config: PaddleOCRVLConfig base_model_prefix = "model" supports_gradient_checkpointing = True - _no_split_modules = ["Ernie4_5DecoderLayer"] + _no_split_modules = ["PaddleOCRDecoderLayer"] _skip_keys_device_placement = ["past_key_values"] _supports_flash_attn = True _supports_sdpa = True @@ -425,25 +439,21 @@ class Ernie4_5PreTrainedModel(PreTrainedModel): _can_compile_fullgraph = True _supports_attention_backend = True - _can_record_outputs = { - "hidden_states": Ernie4_5DecoderLayer, - "attentions": Ernie4_5Attention, - } @auto_docstring -class Ernie4_5Model(Ernie4_5PreTrainedModel): - def __init__(self, config: PaddleOCRVLConfig): +class PaddleOCRVLTextModel(PaddleOCRVLPreTrainedModel): + def __init__(self, config: PaddleOCRVLTextConfig): super().__init__(config) self.padding_idx = config.pad_token_id self.vocab_size = config.vocab_size self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx) self.layers = nn.ModuleList( - [Ernie4_5DecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)] + [PaddleOCRDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)] ) - self.norm = Ernie4_5RMSNorm(config.hidden_size, eps=config.rms_norm_eps) - self.rotary_emb = Ernie4_5RotaryEmbedding(config=config) + self.norm = PaddleOCRRMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.rotary_emb = PaddleOCRRotaryEmbedding(config=config) self.gradient_checkpointing = False # Initialize weights and apply final processing @@ -510,169 +520,12 @@ def forward( ) -def _trunc_normal_(tensor, mean, std, a, b): - # Cut & paste from PyTorch official master until it's in a few official releases - RW - # Method based on https://people.sc.fsu.edu/~jburkardt/presentations/truncated_normal.pdf - def norm_cdf(x): - # Computes standard normal cumulative distribution function - return (1.0 + math.erf(x / math.sqrt(2.0))) / 2.0 - - if (mean < a - 2 * std) or (mean > b + 2 * std): - logger.warning_once( - "mean is more than 2 std from [a, b] in nn.init.trunc_normal_. " - "The distribution of values may be incorrect.", - stacklevel=2, - ) - - # Values are generated by using a truncated uniform distribution and - # then using the inverse CDF for the normal distribution. - # Get upper and lower cdf values - l = norm_cdf((a - mean) / std) - u = norm_cdf((b - mean) / std) - - # Uniformly fill tensor with values from [l, u], then translate to - # [2l-1, 2u-1]. - tensor.uniform_(2 * l - 1, 2 * u - 1) - - # Use inverse cdf transform for normal distribution to get truncated - # standard normal - tensor.erfinv_() - - # Transform to proper mean, std - tensor.mul_(std * math.sqrt(2.0)) - tensor.add_(mean) - - # Clamp to ensure it's in the proper range - tensor.clamp_(min=a, max=b) - - -def trunc_normal_tf_( - tensor: torch.Tensor, - mean: float = 0.0, - std: float = 1.0, - a: float = -2.0, - b: float = 2.0, -) -> torch.Tensor: - """Fills the input Tensor with values drawn from a truncated - normal distribution. The values are effectively drawn from the - normal distribution :math:`\\mathcal{N}(\text{mean}, \text{std}^2)` - with values outside :math:`[a, b]` redrawn until they are within - the bounds. The method used for generating the random values works - best when :math:`a \\leq \text{mean} \\leq b`. - - NOTE: this 'tf' variant behaves closer to Tensorflow / JAX impl where the - bounds [a, b] are applied when sampling the normal distribution with mean=0, std=1.0 - and the result is subsequently scaled and shifted by the mean and std args. - - Args: - tensor: an n-dimensional `torch.Tensor` - mean: the mean of the normal distribution - std: the standard deviation of the normal distribution - a: the minimum cutoff value - b: the maximum cutoff value - """ - with torch.no_grad(): - _trunc_normal_(tensor, 0, 1.0, a, b) - tensor.mul_(std).add_(mean) - - -def variance_scaling_(tensor, scale=1.0, mode="fan_in", distribution="normal"): - fan_in, fan_out = _calculate_fan_in_and_fan_out(tensor) - if mode == "fan_in": - denom = fan_in - elif mode == "fan_out": - denom = fan_out - elif mode == "fan_avg": - denom = (fan_in + fan_out) / 2 - - variance = scale / denom - - if distribution == "truncated_normal": - # constant is stddev of standard normal truncated to (-2, 2) - trunc_normal_tf_(tensor, std=math.sqrt(variance) / 0.87962566103423978) - elif distribution == "normal": - with torch.no_grad(): - tensor.normal_(std=math.sqrt(variance)) - elif distribution == "uniform": - bound = math.sqrt(3 * variance) - with torch.no_grad(): - tensor.uniform_(-bound, bound) - else: - raise ValueError(f"invalid distribution {distribution}") - - -def lecun_normal_(tensor): - variance_scaling_(tensor, mode="fan_in", distribution="truncated_normal") - - -def default_flax_embed_init(tensor): - variance_scaling_(tensor, mode="fan_in", distribution="normal") - - -@auto_docstring -class PaddleOCRPreTrainedModel(PreTrainedModel): - config_class = PaddleOCRVLConfig - base_model_prefix = "PaddleOCR" - supports_gradient_checkpointing = True - - _no_split_modules = [ - "PaddleOCRTextEmbeddings", - "PaddleOCREncoderLayer", - "PaddleOCRVisionEmbeddings", - "PaddleOCRMultiheadAttentionPoolingHead", - ] - _supports_flash_attn_2 = True - _supports_sdpa = True - - def _init_weights(self, module): - """Initialize the weights""" - if isinstance(module, PaddleOCRVisionEmbeddings): - width = ( - self.config.vision_config.hidden_size - if isinstance(self.config, PaddleOCRVLConfig) - else self.config.hidden_size - ) - nn.init.normal_(module.position_embedding.weight, std=1 / np.sqrt(width)) - elif isinstance(module, nn.Embedding): - default_flax_embed_init(module.weight) - elif isinstance(module, PaddleOCRAttention): - nn.init.xavier_uniform_(module.q_proj.weight) - nn.init.xavier_uniform_(module.k_proj.weight) - nn.init.xavier_uniform_(module.v_proj.weight) - nn.init.xavier_uniform_(module.out_proj.weight) - nn.init.zeros_(module.q_proj.bias) - nn.init.zeros_(module.k_proj.bias) - nn.init.zeros_(module.v_proj.bias) - nn.init.zeros_(module.out_proj.bias) - elif isinstance(module, PaddleOCRMLP): - nn.init.xavier_uniform_(module.fc1.weight) - nn.init.xavier_uniform_(module.fc2.weight) - nn.init.normal_(module.fc1.bias, std=1e-6) - nn.init.normal_(module.fc2.bias, std=1e-6) - elif isinstance(module, PaddleOCRMultiheadAttentionPoolingHead): - nn.init.xavier_uniform_(module.probe.data) - nn.init.xavier_uniform_(module.attention.in_proj_weight.data) - nn.init.zeros_(module.attention.in_proj_bias.data) - elif isinstance(module, (nn.Linear, nn.Conv2d)): - lecun_normal_(module.weight) - if module.bias is not None: - nn.init.zeros_(module.bias) - elif isinstance(module, nn.LayerNorm): - module.bias.data.zero_() - module.weight.data.fill_(1.0) - - -@auto_docstring( - custom_intro=""" - The vision model from PaddleOCR without any head or projection on top. - """ -) -class PaddleOCRVisionModel(PaddleOCRPreTrainedModel): - config: PaddleOCRVisionConfig +class PaddleOCRVisionModel(PaddleOCRVLPreTrainedModel): + config: PaddleOCRVLVisionConfig main_input_name = "pixel_values" input_modalities = "image" - def __init__(self, config: PaddleOCRVisionConfig): + def __init__(self, config: PaddleOCRVLVisionConfig): super().__init__(config) self.vision_model = PaddleOCRVisionTransformer(config) @@ -680,35 +533,30 @@ def __init__(self, config: PaddleOCRVisionConfig): # Initialize weights and apply final processing self.post_init() - def get_input_embeddings(self) -> nn.Module: - return self.vision_model.embeddings.patch_embedding - - @check_model_inputs(tie_last_hidden_states=False) - @auto_docstring def forward( self, pixel_values, - position_ids: Optional[torch.Tensor] = None, + cu_seqlens, image_grid_thw: Optional[list[Union[tuple[int, int, int], list[tuple[int, int, int]]]]] = None, ) -> BaseModelOutputWithPooling: """ Args: pixel_values (`torch.FloatTensor` of shape `(batch_size, sequence_length, image_channels, patch_size, patch_size)`): The tensors corresponding to the input images. - position_ids (`torch.LongTensor` of shape `sequence_length`): - The position ids of the image. - image_grid_thw (`List[Tuple]`, *optional*): + cu_seqlens (`torch.Tensor` of shape `(num_images + 1,)`): + The cumulative sequence lengths of each image or video feature. + image_grid_thw (`torch.LongTensor` of shape `(num_images, 3)`, *optional*): The temporal, height and width of feature shape of each image in LLM. """ return self.vision_model( pixel_values=pixel_values, - position_ids=position_ids, + cu_seqlens=cu_seqlens, image_grid_thw=image_grid_thw, ) class PaddleOCRVisionEmbeddings(nn.Module): - def __init__(self, config: PaddleOCRVisionConfig): + def __init__(self, config: PaddleOCRVLVisionConfig): super().__init__() self.config = config self.embed_dim = config.hidden_size @@ -727,11 +575,8 @@ def __init__(self, config: PaddleOCRVisionConfig): self.num_positions = self.num_patches self.position_embedding = nn.Embedding(self.num_positions, self.embed_dim) self.register_buffer("position_ids", torch.arange(self.num_positions).expand((1, -1)), persistent=False) - self.packing_position_embedding = nn.Embedding(32768, self.embed_dim) - def interpolate_pos_encoding( - self, embeddings: torch.Tensor, height: int, width: int, is_after_patchify: bool = False - ) -> torch.Tensor: + def interpolate_pos_encoding(self, embeddings: torch.Tensor, height: int, width: int) -> torch.Tensor: """ This method allows to interpolate the pre-trained position encodings, to be able to use the model on higher resolution images. This method is also adapted to support torch.jit tracing and no class embeddings. @@ -746,12 +591,8 @@ def interpolate_pos_encoding( dim = embeddings.shape[-1] - if is_after_patchify: - new_height = height - new_width = width - else: - new_height = height // self.patch_size - new_width = width // self.patch_size + new_height = height + new_width = width sqrt_num_positions = torch_int(num_positions**0.5) patch_pos_embed = patch_pos_embed.reshape(1, sqrt_num_positions, sqrt_num_positions, dim) @@ -770,7 +611,6 @@ def interpolate_pos_encoding( def forward( self, pixel_values: torch.FloatTensor, - position_ids: Optional[torch.Tensor] = None, image_grid_thw: Optional[list[Union[tuple[int, int, int], list[tuple[int, int, int]]]]] = None, ) -> torch.Tensor: """ @@ -779,55 +619,33 @@ def forward( The tensors corresponding to the input images. position_ids (`torch.LongTensor` of shape `sequence_length`): The position ids of the image. - image_grid_thw (`List[Tuple]`, *optional*): + image_grid_thw (`torch.LongTensor` of shape `(num_images, 3)`, *optional*): The temporal, height and width of feature shape of each image in LLM. """ - if pixel_values.dim() == 5: - assert position_ids is not None - - batch_size, squence_len, channel, height, width = pixel_values.shape - target_dtype = self.patch_embedding.weight.dtype - pixel_values = rearrange(pixel_values, "b l c h w -> (b l) c h w") - patch_embeds = self.patch_embedding(pixel_values.to(dtype=target_dtype)) # shape = [*, width, grid, grid] - embeddings = patch_embeds.flatten(-2).squeeze(-1) - embeddings = rearrange(embeddings, "(b l) d -> b l d", b=batch_size, l=squence_len) - - if image_grid_thw is not None: - flatten_image_grid_thw = self.flatten_list(image_grid_thw) - assert batch_size == 1 - start = 0 - assert sum([np.prod(x) for x in flatten_image_grid_thw]) == embeddings.shape[1], ( - flatten_image_grid_thw, - embeddings.shape, - ) - embeddings = embeddings.squeeze(0) - tmp_embeddings = [] - for image_grid in image_grid_thw: - t, h, w = image_grid - end = start + t * h * w - image_embeddings = embeddings[start:end, :] - position_embedding = ( - self.interpolate_pos_encoding(image_embeddings, h, w, True).squeeze(0).repeat(t, 1) - ) - image_embeddings = image_embeddings + position_embedding - tmp_embeddings.append(image_embeddings) - start = end - embeddings = torch.concat(tmp_embeddings, dim=0).unsqueeze(0) - else: - embeddings = embeddings + self.packing_position_embedding(position_ids) - return embeddings - else: - raise NotImplementedError(str(pixel_values.shape)) - - @staticmethod - def flatten_list(image_grid_thw): - tmp_image_grid_thw = [] + batch_size, squence_len, channel, height, width = pixel_values.shape + target_dtype = self.patch_embedding.weight.dtype + pixel_values = pixel_values.reshape(batch_size * squence_len, channel, height, width) + patch_embeds = self.patch_embedding(pixel_values.to(dtype=target_dtype)) # shape = [*, width, grid, grid] + embeddings = patch_embeds.flatten(-2).squeeze(-1) + embeddings = embeddings.reshape(batch_size, squence_len, -1) + + assert batch_size == 1, ( + f"Batch size must be 1, but received {batch_size}. This model only processes one image at a time." + ) + start = 0 + embeddings = embeddings.squeeze(0) + tmp_embeddings = [] for image_grid in image_grid_thw: - if isinstance(image_grid, list): - tmp_image_grid_thw.extend(image_grid) - else: - tmp_image_grid_thw.append(image_grid) - return tmp_image_grid_thw + t, h, w = image_grid + end = start + t * h * w + image_embeddings = embeddings[start:end, :] + position_embedding = self.interpolate_pos_encoding(image_embeddings, h, w).squeeze(0).repeat(t, 1) + image_embeddings = image_embeddings + position_embedding + tmp_embeddings.append(image_embeddings) + start = end + embeddings = torch.concat(tmp_embeddings, dim=0) + + return embeddings def apply_rotary_pos_emb_vision( @@ -844,10 +662,10 @@ def apply_rotary_pos_emb_vision( return q_embed, k_embed -class PaddleOCRAttention(nn.Module): +class PaddleOCRVisionAttention(nn.Module): """Multi-headed attention from 'Attention Is All You Need' paper""" - def __init__(self, config: PaddleOCRVisionConfig): + def __init__(self, config: PaddleOCRVLVisionConfig): super().__init__() self.config = config self.embed_dim = config.hidden_size @@ -858,59 +676,99 @@ def __init__(self, config: PaddleOCRVisionConfig): f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim} and `num_heads`:" f" {self.num_heads})." ) - self.scale = self.head_dim**-0.5 - self.dropout = config.attention_dropout self.is_causal = False self.k_proj = nn.Linear(self.embed_dim, self.embed_dim) self.v_proj = nn.Linear(self.embed_dim, self.embed_dim) self.q_proj = nn.Linear(self.embed_dim, self.embed_dim) self.out_proj = nn.Linear(self.embed_dim, self.embed_dim) + self.num_key_value_groups = 1 + self.scaling = self.head_dim**-0.5 + self.attention_dropout = config.attention_dropout def forward( self, hidden_states: torch.Tensor, - attention_mask: Optional[torch.Tensor] = None, - rope_emb: Optional[tuple[torch.Tensor, torch.Tensor]] = None, + cu_seqlens: torch.Tensor, + position_embeddings: tuple[torch.Tensor, torch.Tensor], + **kwargs: Unpack[TransformersKwargs], ) -> tuple[torch.Tensor, Optional[torch.Tensor]]: - """Input shape: Batch x Time x Channel""" - batch_size, seq_length, embed_dim = hidden_states.shape + """ + Args: + hidden_states (`torch.Tensor`): + Input to the layer of shape `(seq_len, embed_dim)`. + cu_seqlens (`torch.Tensor` of shape `(num_images_or_videos + 1,)`): + The cumulative sequence lengths of each image or video feature. + position_embeddings (`tuple(torch.Tensor, torch.Tensor)` of shape `(num_patches, head_dim // 2)`): + The cosine and sine position embeddings for vision attention. + """ + seq_length = hidden_states.shape[0] + query_states = self.q_proj(hidden_states).view(seq_length, self.num_heads, self.head_dim) + key_states = self.k_proj(hidden_states).view(seq_length, self.num_heads, self.head_dim) + value_states = self.v_proj(hidden_states).view(seq_length, self.num_heads, self.head_dim) - queries = self.q_proj(hidden_states) - keys = self.k_proj(hidden_states) - values = self.v_proj(hidden_states) + cos, sin = position_embeddings + query_states, key_states = apply_rotary_pos_emb_vision(query_states, key_states, cos, sin) - cos, sin = rope_emb - queries = queries.view(batch_size, seq_length, self.num_heads, self.head_dim) - keys = keys.view(batch_size, seq_length, self.num_heads, self.head_dim) - values = values.view(batch_size, seq_length, self.num_heads, self.head_dim).transpose(1, 2) - queries, keys = apply_rotary_pos_emb_vision(queries, keys, cos, sin) - queries = queries.transpose(1, 2) - keys = keys.transpose(1, 2) + query_states = query_states.transpose(0, 1).unsqueeze(0) + key_states = key_states.transpose(0, 1).unsqueeze(0) + value_states = value_states.transpose(0, 1).unsqueeze(0) attention_interface: Callable = eager_attention_forward if self.config._attn_implementation != "eager": attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] - attn_output, attn_weights = attention_interface( - self, - queries, - keys, - values, - attention_mask, - is_causal=self.is_causal, - scaling=self.scale, - dropout=0.0 if not self.training else self.dropout, - ) + if self.config._attn_implementation == "flash_attention_2": + # Flash Attention 2: Use cu_seqlens for variable length attention + max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max() + attn_output, attn_weights = attention_interface( + self, + query_states, + key_states, + value_states, + attention_mask=None, + scaling=self.scaling, + dropout=0.0 if not self.training else self.attention_dropout, + cu_seq_lens_q=cu_seqlens, + cu_seq_lens_k=cu_seqlens, + max_length_q=max_seqlen, + max_length_k=max_seqlen, + is_causal=False, + **kwargs, + ) + else: + # Other implementations: Process each chunk separately + lengths = cu_seqlens[1:] - cu_seqlens[:-1] + splits = [ + torch.split(tensor, lengths.tolist(), dim=2) for tensor in (query_states, key_states, value_states) + ] + + attn_outputs, attn_weights = [], [] + for q, k, v in zip(*splits): + attn_output, attn_weight = attention_interface( + self, + q, + k, + v, + attention_mask=None, + scaling=self.scaling, + dropout=0.0 if not self.training else self.attention_dropout, + is_causal=False, + **kwargs, + ) + attn_outputs.append(attn_output) + attn_weights.append(attn_weight) + + attn_output = torch.cat(attn_outputs, dim=1) - attn_output = attn_output.reshape(batch_size, seq_length, embed_dim).contiguous() + attn_output = attn_output.reshape(seq_length, -1).contiguous() attn_output = self.out_proj(attn_output) return attn_output, attn_weights -class PaddleOCRMLP(nn.Module): - def __init__(self, config): +class PaddleOCRVisionMLP(nn.Module): + def __init__(self, config: PaddleOCRVLVisionConfig): super().__init__() self.config = config self.activation_fn = ACT2FN[config.hidden_act] @@ -924,16 +782,17 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: return hidden_states -class PaddleOCRMultiheadAttentionPoolingHead(nn.Module): +class PaddleOCRVisionMultiheadAttentionPoolingHead(nn.Module): """Multihead Attention Pooling.""" - def __init__(self, config: PaddleOCRVisionConfig): + def __init__(self, config: PaddleOCRVLVisionConfig): super().__init__() self.probe = nn.Parameter(torch.randn(1, 1, config.hidden_size)) self.attention = torch.nn.MultiheadAttention(config.hidden_size, config.num_attention_heads, batch_first=True) self.layernorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) - self.mlp = PaddleOCRMLP(config) + + self.mlp = PaddleOCRVisionMLP(config) def forward(self, hidden_state): batch_size = hidden_state.shape[0] @@ -948,28 +807,36 @@ def forward(self, hidden_state): return hidden_state[:, 0] -class PaddleOCREncoderLayer(GradientCheckpointingLayer): - def __init__(self, config: PaddleOCRVisionConfig): +class PaddleOCRVisionEncoderLayer(GradientCheckpointingLayer): + def __init__(self, config: PaddleOCRVLVisionConfig): super().__init__() self.embed_dim = config.hidden_size self.layer_norm1 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps) - self.self_attn = PaddleOCRAttention(config) + self.self_attn = PaddleOCRVisionAttention(config=config) self.layer_norm2 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps) - self.mlp = PaddleOCRMLP(config) + self.mlp = PaddleOCRVisionMLP(config=config) @auto_docstring def forward( self, hidden_states: torch.Tensor, - attention_mask: torch.Tensor, + cu_seqlens: torch.Tensor, + position_embeddings: tuple[torch.Tensor, torch.Tensor], **kwargs: Unpack[TransformersKwargs], - ) -> torch.FloatTensor: + ) -> torch.Tensor: + r""" + cu_seqlens (`torch.Tensor` of shape `(num_images_or_videos + 1,)`): + The cumulative sequence lengths of each image or video feature. + position_embeddings (`tuple(torch.Tensor, torch.Tensor)` of shape `(num_patches, head_dim // 2)`): + The cosine and sine position embeddings for vision attention. + """ residual = hidden_states hidden_states = self.layer_norm1(hidden_states) hidden_states, _ = self.self_attn( - hidden_states=hidden_states, - attention_mask=attention_mask, + hidden_states, + cu_seqlens=cu_seqlens, + position_embeddings=position_embeddings, **kwargs, ) hidden_states = residual + hidden_states @@ -982,52 +849,54 @@ def forward( return hidden_states -class PaddleOCREncoder(nn.Module): +class PaddleOCRVisionEncoder(nn.Module): """ Transformer encoder consisting of `config.num_hidden_layers` self attention layers. Each layer is a - [`PaddleOCREncoderLayer`]. + [`PaddleOCRVisionEncoderLayer`]. Args: - config: PaddleOCRConfig + config: PaddleOCRVisionConfig """ - def __init__(self, config: PaddleOCRVisionConfig): + def __init__(self, config: PaddleOCRVLVisionConfig): super().__init__() self.config = config - self.layers = nn.ModuleList([PaddleOCREncoderLayer(config) for _ in range(config.num_hidden_layers)]) + self.layers = nn.ModuleList([PaddleOCRVisionEncoderLayer(config) for _ in range(config.num_hidden_layers)]) self.gradient_checkpointing = False embed_dim = config.hidden_size num_heads = config.num_attention_heads head_dim = embed_dim // num_heads - self.rotary_pos_emb = PaddleOCRRotaryEmbedding(head_dim // 2) + self.rotary_pos_emb = PaddleOCRVisionRotaryEmbedding(head_dim // 2) # Ignore copy + @can_return_tuple @auto_docstring def forward( self, inputs_embeds, + cu_seqlens, attention_mask: Optional[torch.Tensor] = None, image_grid_thw: Optional[list[Union[tuple[int, int, int], list[tuple[int, int, int]]]]] = None, ) -> BaseModelOutput: """ Args: + cu_seqlens (`torch.Tensor` of shape `(num_images + 1,)`): + The cumulative sequence lengths of each image or video feature. attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*): The attention_mask used in forward function shape [batch_size X sequence_length] if not None. - image_grid_thw (`List[Tuple]`, *optional*): + image_grid_thw (`torch.LongTensor` of shape `(num_images, 3)`, *optional*): The temporal, height and width of feature shape of each image in LLM. """ device = inputs_embeds.device hidden_states = inputs_embeds - attention_mask = attention_mask.to(inputs_embeds.dtype) if attention_mask is not None else None - flatten_image_grid_thw = self.flatten_list(image_grid_thw) - assert sum([np.prod(x) for x in flatten_image_grid_thw]) == hidden_states.shape[1], ( - flatten_image_grid_thw, - hidden_states.shape, + attention_mask = create_bidirectional_mask( + config=self.config, + input_embeds=inputs_embeds, + attention_mask=attention_mask, ) - split_hids = [] split_wids = [] - for t, h, w in flatten_image_grid_thw: + for t, h, w in image_grid_thw: image_pids = torch.arange(t * h * w, device=device) % (h * w) sample_hids = image_pids // w sample_wids = image_pids % w @@ -1038,75 +907,61 @@ def forward( pids = torch.stack([height_position_ids, width_position_ids], dim=-1) max_grid_size = pids.max() + 1 - rope_emb_max_grid = self.rotary_pos_emb(max_grid_size) - rope_emb = rope_emb_max_grid[pids].flatten(1) - rope_emb = rope_emb.repeat(1, 2) - rope_emb = (rope_emb.cos(), rope_emb.sin()) + rotary_embeddings_max_grid = self.rotary_pos_emb(max_grid_size) + rotary_embeddings = rotary_embeddings_max_grid[pids].flatten(1) + rotary_embeddings = rotary_embeddings.repeat(1, 2) + position_embeddings = (rotary_embeddings.cos(), rotary_embeddings.sin()) for encoder_layer in self.layers: hidden_states = encoder_layer( hidden_states, - attention_mask, - rope_emb=rope_emb, + cu_seqlens=cu_seqlens, + position_embeddings=position_embeddings, ) return BaseModelOutput( last_hidden_state=hidden_states, ) - @staticmethod - def flatten_list(image_grid_thw): - tmp_image_grid_thw = [] - for image_grid in image_grid_thw: - if isinstance(image_grid, list): - tmp_image_grid_thw.extend(image_grid) - else: - tmp_image_grid_thw.append(image_grid) - return tmp_image_grid_thw - -class PaddleOCRVisionTransformer(PaddleOCRPreTrainedModel): - _can_record_outputs = { - "hidden_states": PaddleOCREncoderLayer, - "attentions": PaddleOCRAttention, - } - - def __init__(self, config: PaddleOCRVisionConfig): +class PaddleOCRVisionTransformer(PaddleOCRVLPreTrainedModel): + def __init__(self, config: PaddleOCRVLVisionConfig): super().__init__(config) self.config = config embed_dim = config.hidden_size self.embeddings = PaddleOCRVisionEmbeddings(config) - self.encoder = PaddleOCREncoder(config) + self.encoder = PaddleOCRVisionEncoder(config) self.post_layernorm = nn.LayerNorm(embed_dim, eps=config.layer_norm_eps) self.use_head = True if not hasattr(config, "vision_use_head") else config.vision_use_head if self.use_head: - self.head = PaddleOCRMultiheadAttentionPoolingHead(config) + self.head = PaddleOCRVisionMultiheadAttentionPoolingHead(config) - @check_model_inputs(tie_last_hidden_states=False) - @auto_docstring def forward( self, pixel_values, + cu_seqlens: torch.Tensor, attention_mask: Optional[torch.Tensor] = None, - position_ids: Optional[torch.Tensor] = None, image_grid_thw: Optional[list[Union[tuple[int, int, int], list[tuple[int, int, int]]]]] = None, ) -> BaseModelOutputWithPooling: """ Args: pixel_values (`torch.FloatTensor` of shape `(batch_size, sequence_length, patch_size * patch_size * image_channels)`): The tensors corresponding to the input images. + cu_seqlens (`torch.Tensor` of shape `(num_images + 1,)`): + The cumulative sequence lengths of each image or video feature. attention_mask (`torch.Tensor`, *optional*): The attention_mask used in forward function shape [batch_size X sequence_length] if not None. position_ids (`torch.LongTensor` of shape `sequence_length`): The position ids of the image. - image_grid_thw (`List[Tuple]`, *optional*): + image_grid_thw (`torch.LongTensor` of shape `(num_images, 3)`, *optional*): The temporal, height and width of feature shape of each image in LLM. """ - hidden_states = self.embeddings(pixel_values, position_ids=position_ids, image_grid_thw=image_grid_thw) + hidden_states = self.embeddings(pixel_values, image_grid_thw=image_grid_thw) encoder_outputs: BaseModelOutput = self.encoder( inputs_embeds=hidden_states, + cu_seqlens=cu_seqlens, attention_mask=attention_mask, image_grid_thw=image_grid_thw, ) @@ -1123,51 +978,93 @@ def forward( @dataclass +@auto_docstring( + custom_intro=""" + Base class for Llava outputs, with hidden states and attentions. + """ +) +class PaddleOCRVLModelOutputWithPast(ModelOutput): + r""" + past_key_values (`Cache`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`): + It is a [`~cache_utils.Cache`] instance. For more details, see our [kv cache guide](https://huggingface.co/docs/transformers/en/kv_cache). + + Contains pre-computed hidden-states (key and values in the self-attention blocks) that can be used (see + `past_key_values` input) to speed up sequential decoding. + rope_deltas (`torch.LongTensor` of shape `(batch_size, )`, *optional*): + The rope index difference between sequence length and multimodal rope. + """ + + last_hidden_state: Optional[torch.FloatTensor] = None + past_key_values: Optional[Cache] = None + hidden_states: Optional[tuple[torch.FloatTensor]] = None + attentions: Optional[tuple[torch.FloatTensor]] = None + rope_deltas: Optional[torch.LongTensor] = None + + +@dataclass +@auto_docstring( + custom_intro=""" + Base class for PaddleOCRVL causal language model (or autoregressive) outputs. + """ +) class PaddleOCRVLCausalLMOutputWithPast(ModelOutput): + r""" + loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided): + Language modeling loss (for next-token prediction). + logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.vocab_size)`): + Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax). + past_key_values (`Cache`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`): + It is a [`~cache_utils.Cache`] instance. For more details, see our [kv cache guide](https://huggingface.co/docs/transformers/en/kv_cache). + + Contains pre-computed hidden-states (key and values in the self-attention blocks) that can be used (see + `past_key_values` input) to speed up sequential decoding. + rope_deltas (`torch.LongTensor` of shape `(batch_size, )`, *optional*): + The rope index difference between sequence length and multimodal rope. + """ + loss: Optional[torch.FloatTensor] = None - logits: torch.FloatTensor = None - past_key_values: Optional[list[torch.FloatTensor]] = None + logits: Optional[torch.FloatTensor] = None + past_key_values: Optional[Cache] = None hidden_states: Optional[tuple[torch.FloatTensor]] = None attentions: Optional[tuple[torch.FloatTensor]] = None rope_deltas: Optional[torch.LongTensor] = None -class PaddleOCRVLForConditionalGeneration(Ernie4_5PreTrainedModel, GenerationMixin): - def __init__(self, config): +@auto_docstring +class PaddleOCRVLModel(PaddleOCRVLPreTrainedModel): + base_model_prefix = "" + _checkpoint_conversion_mapping = {"^model": "language_model"} + # Reference: fix gemma3 grad acc #37208 + accepts_loss_kwargs = False + _keys_to_ignore_on_load_unexpected = ["packing_position_embedding"] + + def __init__(self, config: PaddleOCRVLConfig): super().__init__(config) - self.mlp_AR = Projector(config, config.vision_config) self.visual = PaddleOCRVisionModel(config.vision_config) - self.model = Ernie4_5Model(config) - self.vocab_size = config.vocab_size - self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False) + self.language_model = PaddleOCRVLTextModel(config.text_config) self.rope_deltas = None + self.projector = PaddleOCRProjector(config) + # Initialize weights and apply final processing self.post_init() def get_input_embeddings(self): - return self.model.embed_tokens + return self.language_model.embed_tokens def set_input_embeddings(self, value): - self.model.embed_tokens = value - - def get_output_embeddings(self): - return self.lm_head - - def set_output_embeddings(self, new_embeddings): - self.lm_head = new_embeddings + self.language_model.embed_tokens = value def set_decoder(self, decoder): - self.model = decoder + self.language_model = decoder def get_decoder(self): - return self.model + return self.language_model def get_rope_index( self, input_ids: Optional[torch.LongTensor] = None, image_grid_thw: Optional[torch.LongTensor] = None, video_grid_thw: Optional[torch.LongTensor] = None, - second_per_grid_ts: Optional[torch.Tensor] = None, attention_mask: Optional[torch.Tensor] = None, ) -> tuple[torch.Tensor, torch.Tensor]: """ @@ -1186,21 +1083,14 @@ def get_rope_index( For vision and text embedding sequence, we calculate 3D rotary position embedding for vision part and 1D rotary position embedding for text part. Examples: - Temporal (Time): 3 patches, representing different segments of the video in time. - Height: 2 patches, dividing each frame vertically. - Width: 2 patches, dividing each frame horizontally. - We also have some important parameters: - fps (Frames Per Second): The video's frame rate, set to 1. This means one frame is processed each second. - tokens_per_second: This is a crucial parameter. It dictates how many "time-steps" or "temporal tokens" are conceptually packed into a one-second interval of the video. In this case, we have 25 tokens per second. So each second of the video will be represented with 25 separate time points. It essentially defines the temporal granularity. - temporal_patch_size: The number of frames that compose one temporal patch. Here, it's 2 frames. - interval: The step size for the temporal position IDs, calculated as tokens_per_second * temporal_patch_size / fps. In this case, 25 * 2 / 1 = 50. This means that each temporal patch will be have a difference of 50 in the temporal position IDs. + Assume we have a video input with 3 temporal patches, 2 height patches and 2 width patches. input_ids: [V V V V V V V V V V V V T T T T T], here V is for vision. - vision temporal position_ids: [0, 0, 0, 0, 50, 50, 50, 50, 100, 100, 100, 100] + vision temporal position_ids: [0, 0, 0, 0, 1, 1, 1, 1, 2, 2, 2, 2] vision height position_ids: [0, 0, 1, 1, 0, 0, 1, 1, 0, 0, 1, 1] vision width position_ids: [0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1] - text temporal position_ids: [101, 102, 103, 104, 105] - text height position_ids: [101, 102, 103, 104, 105] - text width position_ids: [101, 102, 103, 104, 105] + text temporal position_ids: [3, 4, 5, 6, 7] + text height position_ids: [3, 4, 5, 6, 7] + text width position_ids: [3, 4, 5, 6, 7] Here we calculate the text start position_ids as the max vision position_ids plus 1. Args: @@ -1211,8 +1101,6 @@ def get_rope_index( The temporal, height and width of feature shape of each image in LLM. video_grid_thw (`torch.LongTensor` of shape `(num_videos, 3)`, *optional*): The temporal, height and width of feature shape of each video in LLM. - second_per_grid_ts (`torch.Tensor` of shape `(num_videos)`, *optional*): - The time interval (in seconds) for each grid along the temporal dimension in the 3D position IDs. attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*): Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: @@ -1233,16 +1121,11 @@ def get_rope_index( if attention_mask is None: attention_mask = torch.ones_like(total_input_ids) position_ids = torch.ones( - 3, - input_ids.shape[0], - input_ids.shape[1], - dtype=input_ids.dtype, - device=input_ids.device, + 3, input_ids.shape[0], input_ids.shape[1], dtype=input_ids.dtype, device=input_ids.device ) image_index, video_index = 0, 0 - attention_mask = attention_mask.to(total_input_ids.device) for i, input_ids in enumerate(total_input_ids): - input_ids = input_ids[attention_mask[i] == 1] + input_ids = input_ids[attention_mask[i].to(input_ids.device) == 1] image_nums, video_nums = 0, 0 vision_start_indices = torch.argwhere(input_ids == vision_start_token_id).squeeze(1) vision_tokens = input_ids[vision_start_indices + 1] @@ -1267,21 +1150,15 @@ def get_rope_index( image_grid_thw[image_index][1], image_grid_thw[image_index][2], ) - second_per_grid_t = 0 image_index += 1 remain_images -= 1 ed = ed_image - else: t, h, w = ( video_grid_thw[video_index][0], video_grid_thw[video_index][1], video_grid_thw[video_index][2], ) - if second_per_grid_ts is not None: - second_per_grid_t = second_per_grid_ts[video_index] - else: - second_per_grid_t = 1.0 video_index += 1 remain_videos -= 1 ed = ed_video @@ -1295,16 +1172,7 @@ def get_rope_index( st_idx = llm_pos_ids_list[-1].max() + 1 if len(llm_pos_ids_list) > 0 else 0 llm_pos_ids_list.append(torch.arange(text_len).view(1, -1).expand(3, -1) + st_idx) - if torch.is_tensor(second_per_grid_t): - second_per_grid_t = second_per_grid_t.detach().item() - range_tensor = torch.arange(llm_grid_t).view(-1, 1) - expanded_range = range_tensor.expand(-1, llm_grid_h * llm_grid_w) - - time_tensor = expanded_range * second_per_grid_t * self.config.vision_config.tokens_per_second - - time_tensor_long = time_tensor.long() - t_index = time_tensor_long.flatten() - + t_index = torch.arange(llm_grid_t).view(-1, 1).expand(-1, llm_grid_h * llm_grid_w).flatten() h_index = torch.arange(llm_grid_h).view(1, -1, 1).expand(llm_grid_t, -1, llm_grid_w).flatten() w_index = torch.arange(llm_grid_w).view(1, 1, -1).expand(llm_grid_t, llm_grid_h, -1).flatten() llm_pos_ids_list.append(torch.stack([t_index, h_index, w_index]) + text_len + st_idx) @@ -1341,6 +1209,75 @@ def get_rope_index( return position_ids, mrope_position_deltas + def get_video_features( + self, pixel_values_videos: torch.FloatTensor, video_grid_thw: Optional[torch.LongTensor] = None + ): + """ + Encodes videos into continuous embeddings that can be forwarded to the language model. + + Args: + pixel_values_videos (`torch.FloatTensor` of shape `(batch_size, num_channels, image_size, image_size)`): + The tensors corresponding to the input videos. + video_grid_thw (`torch.LongTensor` of shape `(num_videos, 3)`, *optional*): + The temporal, height and width of feature shape of each video in LLM. + """ + pixel_values_videos = pixel_values_videos.type(self.visual.dtype) + video_embeds = self.visual(pixel_values_videos, grid_thw=video_grid_thw) + split_sizes = (video_grid_thw.prod(-1) // self.visual.spatial_merge_size**2).tolist() + video_embeds = torch.split(video_embeds, split_sizes) + return video_embeds + + def get_image_features(self, pixel_values: torch.FloatTensor, image_grid_thw: Optional[torch.LongTensor] = None): + """ + Encodes images into continuous embeddings that can be forwarded to the language model. + + Args: + pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, image_size, image_size)`): + The tensors corresponding to the input images. + image_grid_thw (`torch.LongTensor` of shape `(num_images, 3)`, *optional*): + The temporal, height and width of feature shape of each image in LLM. + """ + pixel_values = pixel_values.type(self.visual.dtype).unsqueeze(0) + cu_seqlens = torch.repeat_interleave(image_grid_thw[:, 1] * image_grid_thw[:, 2], image_grid_thw[:, 0]).cumsum( + dim=0, + # Select dtype based on the following factors: + # - FA2 requires that cu_seqlens_q must have dtype int32 + # - torch.onnx.export requires that cu_seqlens_q must have same dtype as grid_thw + # See https://github.com/huggingface/transformers/pull/34852 for more information + dtype=image_grid_thw.dtype if torch.jit.is_tracing() else torch.int32, + ) + cu_seqlens = torch.nn.functional.pad(cu_seqlens, (1, 0), value=0) + vision_outputs = self.visual( + pixel_values=pixel_values, + image_grid_thw=image_grid_thw, + cu_seqlens=cu_seqlens, + ) + image_embeds = vision_outputs.last_hidden_state + return image_embeds + + def get_placeholder_mask( + self, + input_ids: torch.LongTensor, + inputs_embeds: torch.FloatTensor, + image_features: Optional[torch.FloatTensor] = None, + ): + """ + Obtains multimodal placeholder mask from `input_ids` or `inputs_embeds`, and checks that the placeholder token count is + equal to the length of multimodal features. If the lengths are different, an error is raised. + """ + n_image_tokens = (input_ids == self.config.image_token_id).sum().item() + n_image_features = image_features.shape[0] + if n_image_tokens != n_image_features: + raise ValueError( + f"Image features and image tokens do not match: tokens: {n_image_tokens}, features {n_image_features}" + ) + mask = input_ids == self.config.image_token_id + mask_unsqueezed = mask.unsqueeze(-1) + mask_expanded = mask_unsqueezed.expand_as(inputs_embeds) + image_mask = mask_expanded.to(inputs_embeds.device) + return image_mask + + @can_return_tuple def forward( self, input_ids: torch.LongTensor = None, @@ -1348,70 +1285,31 @@ def forward( position_ids: Optional[torch.LongTensor] = None, past_key_values: Optional[list[torch.FloatTensor]] = None, inputs_embeds: Optional[torch.FloatTensor] = None, - labels: Optional[torch.LongTensor] = None, use_cache: Optional[bool] = None, return_dict: Optional[bool] = None, pixel_values: Optional[torch.Tensor] = None, - pixel_values_videos: Optional[torch.FloatTensor] = None, image_grid_thw: Optional[torch.LongTensor] = None, - video_grid_thw: Optional[torch.LongTensor] = None, rope_deltas: Optional[torch.LongTensor] = None, cache_position: Optional[torch.LongTensor] = None, - second_per_grid_ts: Optional[torch.Tensor] = None, **kwargs, - ) -> Union[tuple, PaddleOCRVLCausalLMOutputWithPast]: + ) -> Union[tuple, PaddleOCRVLModelOutputWithPast]: r""" - Returns: + image_grid_thw (`torch.LongTensor` of shape `(num_images, 3)`, *optional*): + The temporal, height and width of feature shape of each image in LLM. + rope_deltas (`torch.LongTensor` of shape `(batch_size, )`, *optional*): + The rope index difference between sequence length and multimodal rope. """ return_dict = return_dict if return_dict is not None else self.config.use_return_dict if inputs_embeds is None: - inputs_embeds = self.model.embed_tokens(input_ids) - if pixel_values is not None: - pixel_values = pixel_values.type(self.visual.dtype) - pixel_values = pixel_values.unsqueeze(0) - ppocr_position_ids = [] - image_grid_hws = [] - - for idx, thw in enumerate(image_grid_thw): - thw_tuple = tuple(thw.detach().cpu().numpy().tolist()) - numel = np.prod(thw_tuple) - image_grid_hws.append(thw_tuple) - image_position_ids = torch.arange(numel) % np.prod(thw_tuple[1:]) - ppocr_position_ids.append(image_position_ids) - - ppocr_position_ids = torch.concat(ppocr_position_ids, dim=0).to(pixel_values.device) - - vision_outputs = self.visual( - pixel_values=pixel_values, - image_grid_thw=image_grid_hws, - position_ids=ppocr_position_ids, - ) - image_embeds = vision_outputs.last_hidden_state - - image_embeds = self.mlp_AR(image_embeds, image_grid_thw) - - n_image_tokens = (input_ids == self.config.image_token_id).sum().item() - # image_embeds is a list of tensor, each tensor is a image feature,I want to concat them all into a tensor - image_embeds = torch.cat(image_embeds, dim=0) - n_image_features = image_embeds.shape[0] - if n_image_tokens != n_image_features: - raise ValueError( - f"Image features and image tokens do not match: tokens: {n_image_tokens}, features {n_image_features}" - ) - - mask = input_ids == self.config.image_token_id - mask_unsqueezed = mask.unsqueeze(-1) - mask_expanded = mask_unsqueezed.expand_as(inputs_embeds) - image_mask = mask_expanded.to(inputs_embeds.device) - - image_embeds = image_embeds.to(inputs_embeds.device, inputs_embeds.dtype) + inputs_embeds = self.language_model.embed_tokens(input_ids) - inputs_embeds = inputs_embeds.masked_scatter(image_mask, image_embeds) + if pixel_values is not None: + image_embeds = self.get_image_features(pixel_values, image_grid_thw) + image_embeds = self.projector(image_embeds, image_grid_thw) + image_mask = self.get_placeholder_mask(input_ids, inputs_embeds=inputs_embeds, image_features=image_embeds) + inputs_embeds = inputs_embeds.masked_scatter(image_mask, image_embeds) - if attention_mask is not None: - attention_mask = attention_mask.to(inputs_embeds.device) - # position_ids = None # if we get 4D attention mask we cannot calculate rope deltas anymore. TODO @raushan fixme if position_ids is None and (attention_mask is None or attention_mask.ndim == 2): # calculate RoPE index once per generation in the pre-fill stage only @@ -1421,11 +1319,9 @@ def forward( or (past_key_values is None or past_key_values.get_seq_length() == 0) ): position_ids, rope_deltas = self.get_rope_index( - input_ids, - image_grid_thw, - video_grid_thw, - second_per_grid_ts, - attention_mask, + input_ids=input_ids, + image_grid_thw=image_grid_thw, + attention_mask=attention_mask, ) self.rope_deltas = rope_deltas # then use the prev pre-calculated rope-deltas to get the correct position ids @@ -1443,7 +1339,7 @@ def forward( position_ids = position_ids.add(delta) position_ids = position_ids.unsqueeze(0).expand(3, -1, -1) - outputs = self.model( + outputs = self.language_model( input_ids=None, position_ids=position_ids, attention_mask=attention_mask, @@ -1454,27 +1350,154 @@ def forward( **kwargs, ) - hidden_states = outputs[0] + output = PaddleOCRVLModelOutputWithPast( + last_hidden_state=outputs.last_hidden_state, + past_key_values=outputs.past_key_values, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + rope_deltas=self.rope_deltas, + ) + + return output if return_dict else output.to_tuple() + + +class PaddleOCRVLForConditionalGeneration(PaddleOCRVLPreTrainedModel, GenerationMixin): + _checkpoint_conversion_mapping = { + "^visual": "model.visual", + "^mlp_AR": "model.projector", + r"^model(?!(\.visual|\.projector))": "model.language_model", + } + _tied_weights_keys = ["lm_head.weight"] + _keys_to_ignore_on_load_unexpected = ["packing_position_embedding"] + + def __init__(self, config): + super().__init__(config) + self.model = PaddleOCRVLModel(config) + self.lm_head = nn.Linear(config.text_config.hidden_size, config.text_config.vocab_size, bias=False) + + self.post_init() + + def get_input_embeddings(self): + return self.model.get_input_embeddings() + + def set_input_embeddings(self, value): + self.model.set_input_embeddings(value) + + def set_decoder(self, decoder): + self.model.set_decoder(decoder) + + def get_decoder(self): + return self.model.get_decoder() + + def get_video_features( + self, pixel_values_videos: torch.FloatTensor, video_grid_thw: Optional[torch.LongTensor] = None + ): + return self.model.get_video_features(pixel_values_videos, video_grid_thw) + + def get_image_features(self, pixel_values: torch.FloatTensor, image_grid_thw: Optional[torch.LongTensor] = None): + return self.model.get_image_features(pixel_values, image_grid_thw) + + # Make modules available through conditional class for BC + @property + def language_model(self): + return self.model.language_model + + @property + def visual(self): + return self.model.visual + + @can_return_tuple + @auto_docstring + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[Cache] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + pixel_values: Optional[torch.Tensor] = None, + image_grid_thw: Optional[torch.LongTensor] = None, + rope_deltas: Optional[torch.LongTensor] = None, + cache_position: Optional[torch.LongTensor] = None, + logits_to_keep: Union[int, torch.Tensor] = 0, + **kwargs: Unpack[TransformersKwargs], + ) -> Union[tuple, PaddleOCRVLCausalLMOutputWithPast]: + r""" + labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Labels for computing the masked language modeling loss. Indices should either be in `[0, ..., + config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored + (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`. + image_grid_thw (`torch.LongTensor` of shape `(num_images, 3)`, *optional*): + The temporal, height and width of feature shape of each image in LLM. + video_grid_thw (`torch.LongTensor` of shape `(num_videos, 3)`, *optional*): + The temporal, height and width of feature shape of each video in LLM. + rope_deltas (`torch.LongTensor` of shape `(batch_size, )`, *optional*): + The rope index difference between sequence length and multimodal rope. + + Example: + + ```python + >>> from transformers import AutoProcessor, PaddleOCRVLForConditionalGeneration + + >>> model = PaddleOCRVLForConditionalGeneration.from_pretrained("PaddlePaddle/PaddleOCR-VL", dtype="bfloat16") + >>> processor = AutoProcessor.from_pretrained("PaddlePaddle/PaddleOCR-VL") + + >>> messages = [ + { + "role": "user", + "content": [ + { + "type": "image", + "image": "https://paddle-model-ecology.bj.bcebos.com/paddlex/imgs/demo_image/ocr_demo.jpg", + }, + {"type": "text", "text": "OCR:"}, + ], + } + ] + + >>> inputs = processor.apply_chat_template( + messages, + tokenize=True, + add_generation_prompt=True, + return_dict=True, + return_tensors="pt" + ).to(model.device) + + >>> # Generate + >>> generated_ids = model.generate(**inputs, max_new_tokens=1024) + >>> generated_ids_trimmed = [out_ids[len(in_ids) :] for in_ids, out_ids in zip(inputs.input_ids, generated_ids)] + >>> output_text = processor.batch_decode(generated_ids_trimmed, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0] + >>> print(output_text) + ``` + """ + outputs: PaddleOCRVLModelOutputWithPast = self.model( + input_ids=input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + image_grid_thw=image_grid_thw, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + return_dict=True, + pixel_values=pixel_values, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + cache_position=cache_position, + **kwargs, + ) + hidden_states = outputs.last_hidden_state + logits = self.lm_head(hidden_states) loss = None if labels is not None: - # Upcast to float if we need to compute the loss to avoid potential precision issues - logits = logits.float() - # Shift so that tokens < n predict n - shift_logits = logits[..., :-1, :].contiguous() - shift_labels = labels[..., 1:].contiguous() - # Flatten the tokens - loss_fct = CrossEntropyLoss() - shift_logits = shift_logits.view(-1, self.config.vocab_size) - shift_labels = shift_labels.view(-1) - # Enable model parallelism - shift_labels = shift_labels.to(shift_logits.device) - loss = loss_fct(shift_logits, shift_labels) - - if not return_dict: - output = (logits,) + outputs[1:] - return (loss,) + output if loss is not None else output + loss = self.loss_function( + logits=logits, labels=labels, vocab_size=self.config.text_config.vocab_size, **kwargs + ) return PaddleOCRVLCausalLMOutputWithPast( loss=loss, @@ -1482,7 +1505,7 @@ def forward( past_key_values=outputs.past_key_values, hidden_states=outputs.hidden_states, attentions=outputs.attentions, - rope_deltas=self.rope_deltas, + rope_deltas=outputs.rope_deltas, ) def prepare_inputs_for_generation( @@ -1523,13 +1546,13 @@ def prepare_inputs_for_generation( if cache_position[0] != 0: model_inputs["pixel_values"] = None - model_inputs["pixel_values_videos"] = None return model_inputs def _get_image_nums_and_video_nums( self, input_ids: Optional[torch.LongTensor], + inputs_embeds: Optional[torch.Tensor] = None, ) -> tuple[torch.Tensor, torch.Tensor]: """ Get the number of images and videos for each sample to calculate the separation length of the sample tensor. @@ -1547,10 +1570,31 @@ def _get_image_nums_and_video_nums( video_token_id = self.config.video_token_id vision_start_token_id = self.config.vision_start_token_id - vision_start_mask = input_ids == vision_start_token_id + if inputs_embeds is not None: + vision_start_mask = ( + inputs_embeds + == self.get_input_embeddings()( + torch.tensor(vision_start_token_id, dtype=torch.long, device=inputs_embeds.device) + ) + )[..., 0] + image_mask = ( + inputs_embeds + == self.get_input_embeddings()( + torch.tensor(image_token_id, dtype=torch.long, device=inputs_embeds.device) + ) + )[..., 0] + video_mask = ( + inputs_embeds + == self.get_input_embeddings()( + torch.tensor(video_token_id, dtype=torch.long, device=inputs_embeds.device) + ) + )[..., 0] + else: + vision_start_mask = input_ids == vision_start_token_id + image_mask = input_ids == image_token_id + video_mask = input_ids == video_token_id + vision_first_mask = torch.roll(vision_start_mask, shifts=1, dims=1) - image_mask = input_ids == image_token_id - video_mask = input_ids == video_token_id image_nums = torch.sum(vision_first_mask & image_mask, dim=1) video_nums = torch.sum(vision_first_mask & video_mask, dim=1) @@ -1571,18 +1615,14 @@ def _expand_inputs_for_generation( if expand_size == 1: return input_ids, model_kwargs - visual_keys = [ - "pixel_values", - "image_grid_thw", - "pixel_values_videos", - "video_grid_thw", - "second_per_grid_ts", - ] + visual_keys = ["pixel_values", "image_grid_thw", "pixel_values_videos", "video_grid_thw", "second_per_grid_ts"] def _expand_dict_for_generation_visual(dict_to_expand): image_grid_thw = model_kwargs.get("image_grid_thw", None) video_grid_thw = model_kwargs.get("video_grid_thw", None) - image_nums, video_nums = self._get_image_nums_and_video_nums(input_ids) + image_nums, video_nums = self._get_image_nums_and_video_nums( + input_ids, inputs_embeds=model_kwargs.get("inputs_embeds", None) + ) def _repeat_interleave_samples(x, lengths, repeat_times): samples = torch.split(x, lengths) @@ -1617,14 +1657,9 @@ def _repeat_interleave_samples(x, lengths, repeat_times): dict_to_expand[key], lengths=lengths, repeat_times=expand_size ) elif key == "second_per_grid_ts": - if not isinstance(dict_to_expand[key], list): - raise TypeError( - f"Expected value for key '{key}' to be a list, but got {type(dict_to_expand[key])} instead." - ) - tensor = torch.tensor(dict_to_expand[key]) - lengths = list(video_nums) - tensor = _repeat_interleave_samples(tensor, lengths=lengths, repeat_times=expand_size) - dict_to_expand[key] = tensor.tolist() + dict_to_expand[key] = _repeat_interleave_samples( + dict_to_expand[key], lengths=list(video_nums), repeat_times=expand_size + ) return dict_to_expand def _expand_dict_for_generation(dict_to_expand): @@ -1638,10 +1673,7 @@ def _expand_dict_for_generation(dict_to_expand): dict_to_expand[key] = dict_to_expand[key].repeat_interleave(expand_size, dim=0) return dict_to_expand - # input_ids is required for expanding visual inputs - # If input_ids is unavailable, visual inputs will not be used; therefore, there is no need to expand visual inputs. - if input_ids is not None and input_ids.numel() != 0: - model_kwargs = _expand_dict_for_generation_visual(model_kwargs) + model_kwargs = _expand_dict_for_generation_visual(model_kwargs) if input_ids is not None: input_ids = input_ids.repeat_interleave(expand_size, dim=0) diff --git a/src/transformers/models/paddleocr_vl/modular_paddleocr_vl.py b/src/transformers/models/paddleocr_vl/modular_paddleocr_vl.py index e0935c4f57f9..e6903e564ace 100644 --- a/src/transformers/models/paddleocr_vl/modular_paddleocr_vl.py +++ b/src/transformers/models/paddleocr_vl/modular_paddleocr_vl.py @@ -18,22 +18,18 @@ # limitations under the License. import math -from collections.abc import Callable -from dataclasses import dataclass -from typing import Any, Optional, Union +from typing import Optional, Union import numpy as np import torch -from einops import rearrange +import torch.nn.functional as F from torch import nn -from torch.nn import CrossEntropyLoss -from torch.nn.init import _calculate_fan_in_and_fan_out from ...activations import GELUActivation from ...cache_utils import Cache from ...configuration_utils import PretrainedConfig -from ...generation import GenerationMixin from ...image_processing_utils import BatchFeature +from ...image_processing_utils_fast import BaseImageProcessorFast, group_images_by_shape, reorder_images from ...image_transforms import convert_to_rgb, resize, to_channel_dimension_format from ...image_utils import ( OPENAI_CLIP_MEAN, @@ -41,15 +37,17 @@ ChannelDimension, ImageInput, PILImageResampling, + SizeDict, get_image_size, infer_channel_dimension_format, is_scaled_image, make_list_of_images, to_numpy_array, ) -from ...modeling_outputs import BaseModelOutput, BaseModelOutputWithPooling, ModelOutput -from ...modeling_rope_utils import dynamic_rope_update, rope_config_validation -from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel +from ...masking_utils import create_bidirectional_mask +from ...modeling_outputs import BaseModelOutput, BaseModelOutputWithPooling +from ...modeling_rope_utils import RopeParameters, rope_config_validation +from ...modeling_utils import PreTrainedModel from ...models.qwen2_vl.image_processing_qwen2_vl import Qwen2VLImageProcessor from ...processing_utils import ( ProcessingKwargs, @@ -57,24 +55,34 @@ Unpack, ) from ...tokenization_utils_base import PreTokenizedInput, TextInput -from ...utils import TransformersKwargs, auto_docstring, logging, torch_int -from ..llama.modeling_llama import ( - LlamaAttention, - LlamaDecoderLayer, - LlamaMLP, - LlamaModel, - LlamaRMSNorm, - LlamaRotaryEmbedding, +from ...utils import TensorType, TransformersKwargs, auto_docstring, can_return_tuple, logging, torch_int +from ..ernie4_5.modeling_ernie4_5 import ( + Ernie4_5DecoderLayer, + Ernie4_5MLP, + Ernie4_5Model, + Ernie4_5RMSNorm, +) +from ..qwen2_5_omni.modeling_qwen2_5_omni import ( + Qwen2_5OmniAttention, +) +from ..qwen2_vl.configuration_qwen2_vl import Qwen2VLConfig +from ..qwen2_vl.modeling_qwen2_vl import ( + Qwen2VLCausalLMOutputWithPast, + Qwen2VLForConditionalGeneration, + Qwen2VLModel, + Qwen2VLModelOutputWithPast, + Qwen2VLRotaryEmbedding, + VisionRotaryEmbedding, ) from ..siglip.modeling_siglip import ( - SiglipAttention, - SiglipEncoder, - SiglipEncoderLayer, SiglipMLP, SiglipMultiheadAttentionPoolingHead, SiglipVisionEmbeddings, - SiglipVisionModel, - SiglipVisionTransformer, +) +from ..video_llama_3.modeling_video_llama_3 import ( + VideoLlama3VisionAttention, + VideoLlama3VisionEncoder, + VideoLlama3VisionEncoderLayer, ) @@ -179,13 +187,17 @@ def __init__( self.do_normalize = do_normalize self.image_mean = image_mean if image_mean is not None else OPENAI_CLIP_MEAN self.image_std = image_std if image_std is not None else OPENAI_CLIP_STD - self.min_pixels = min_pixels - self.max_pixels = max_pixels + if min_pixels is not None: + size["shortest_edge"] = min_pixels + if max_pixels is not None: + size["longest_edge"] = max_pixels + self.min_pixels = size["shortest_edge"] + self.max_pixels = size["longest_edge"] + self.size = size self.patch_size = patch_size self.temporal_patch_size = temporal_patch_size self.merge_size = merge_size self.do_convert_rgb = do_convert_rgb - self.size = size def _preprocess( self, @@ -322,6 +334,138 @@ def _preprocess( return flatten_patches, (grid_t, grid_h, grid_w) +class PaddleOCRVLImageProcessorFast(BaseImageProcessorFast): + def __init__( + self, + do_resize: bool = True, + size: Optional[dict[str, int]] = None, + resample: PILImageResampling = PILImageResampling.BICUBIC, + do_rescale: bool = True, + rescale_factor: Union[int, float] = 1 / 255, + do_normalize: bool = True, + image_mean: Optional[Union[float, list[float]]] = None, + image_std: Optional[Union[float, list[float]]] = None, + do_convert_rgb: bool = True, + min_pixels: int = 28 * 28 * 130, + max_pixels: int = 28 * 28 * 1280, + patch_size: int = 14, + temporal_patch_size: int = 1, + merge_size: int = 2, + **kwargs, + ) -> None: + super().__init__(**kwargs) + if size is not None and ("shortest_edge" not in size or "longest_edge" not in size): + raise ValueError("size must contain 'shortest_edge' and 'longest_edge' keys.") + else: + size = {"shortest_edge": 56 * 56, "longest_edge": 28 * 28 * 1280} + # backward compatibility: override size with min_pixels and max_pixels if they are provided + if min_pixels is not None: + size["shortest_edge"] = min_pixels + if max_pixels is not None: + size["longest_edge"] = max_pixels + self.min_pixels = size["shortest_edge"] + self.max_pixels = size["longest_edge"] + self.size = size + self.do_resize = do_resize + self.resample = resample + self.do_rescale = do_rescale + self.rescale_factor = rescale_factor + self.do_normalize = do_normalize + self.image_mean = image_mean if image_mean is not None else OPENAI_CLIP_MEAN + self.image_std = image_std if image_std is not None else OPENAI_CLIP_STD + self.patch_size = patch_size + self.temporal_patch_size = temporal_patch_size + self.merge_size = merge_size + self.do_convert_rgb = do_convert_rgb + + def _preprocess( + self, + images: list["torch.Tensor"], + do_resize: bool, + size: SizeDict, + interpolation: Optional["F.InterpolationMode"], + do_rescale: bool, + rescale_factor: float, + do_normalize: bool, + image_mean: Optional[Union[float, list[float]]], + image_std: Optional[Union[float, list[float]]], + disable_grouping: Optional[bool], + return_tensors: Optional[Union[str, TensorType]], + **kwargs, + ): + grouped_images, grouped_images_index = group_images_by_shape(images, disable_grouping=disable_grouping) + resized_images_grouped = {} + for shape, stacked_images in grouped_images.items(): + height, width = stacked_images.shape[-2:] + if do_resize: + resized_height, resized_width = smart_resize( + height, + width, + factor=self.patch_size * self.merge_size, + min_pixels=self.min_pixels, + max_pixels=self.max_pixels, + ) + stacked_images = self.resize( + image=stacked_images, + size=SizeDict(height=resized_height, width=resized_width), + interpolation=interpolation, + ) + resized_images_grouped[shape] = stacked_images + resized_images = reorder_images(resized_images_grouped, grouped_images_index) + + # Group images by size for further processing + # Needed in case do_resize is False, or resize returns images with different sizes + grouped_images, grouped_images_index = group_images_by_shape(resized_images, disable_grouping=disable_grouping) + processed_images_grouped = {} + processed_grids = {} + for shape, stacked_images in grouped_images.items(): + resized_height, resized_width = stacked_images.shape[-2:] + # Fused rescale and normalize + patches = self.rescale_and_normalize( + stacked_images, do_rescale, rescale_factor, do_normalize, image_mean, image_std + ) + + if patches.ndim == 4: + # add a temporal dimension if we have images + patches = patches.unsqueeze(1) + if patches.shape[1] % self.temporal_patch_size != 0: + repeats = patches[:, -1:].repeat(1, self.temporal_patch_size - 1, 1, 1, 1) + patches = torch.cat([patches, repeats], dim=1) + + batch_size, grid_t, channel = patches.shape[:3] + grid_t = grid_t // self.temporal_patch_size + grid_h, grid_w = ( + resized_height // self.patch_size, + resized_width // self.patch_size, + ) + patches = patches.view( + batch_size, + grid_t, + self.temporal_patch_size, + channel, + grid_h, + self.patch_size, + grid_w, + self.patch_size, + ) + patches = patches.permute(0, 1, 4, 6, 3, 2, 5, 7) + flatten_patches = patches.reshape( + batch_size, grid_t * grid_h * grid_w, channel, self.patch_size, self.patch_size + ) + + processed_images_grouped[shape] = flatten_patches + processed_grids[shape] = [[grid_t, grid_h, grid_w]] * batch_size + + processed_images = reorder_images(processed_images_grouped, grouped_images_index) + processed_grids = reorder_images(processed_grids, grouped_images_index) + pixel_values = torch.cat(processed_images, dim=0) + image_grid_thw = torch.tensor(processed_grids) + + return BatchFeature( + data={"pixel_values": pixel_values, "image_grid_thw": image_grid_thw}, tensor_type=return_tensors + ) + + class PaddleOCRVLProcessorKwargs(ProcessingKwargs, total=False): _defaults = { "text_kwargs": { @@ -332,30 +476,17 @@ class PaddleOCRVLProcessorKwargs(ProcessingKwargs, total=False): class PaddleOCRVLProcessor(ProcessorMixin): r""" - [`PaddleOCRVLProcessor`] offers all the functionalities of [`PaddleOCRVLImageProcessor`] and [`Qwen2TokenizerFast`]. See the + [`PaddleOCRVLProcessor`] offers all the functionalities of [`PaddleOCRVLImageProcessor`] and [`LLamaTokenizerFast`]. See the [`~PaddleOCRVLProcessor.__call__`] and [`~PaddleOCRVLProcessor.decode`] for more information. Args: image_processor ([`PaddleOCRVLImageProcessor`], *optional*): The image processor is a required input. - tokenizer ([`Qwen2TokenizerFast`], *optional*): + tokenizer ([`LLamaTokenizerFast`], *optional*): The tokenizer is a required input. chat_template (`str`, *optional*): A Jinja template which will be used to convert lists of messages in a chat into a tokenizable string. """ - attributes = ["image_processor", "tokenizer"] - valid_kwargs = [ - "chat_template", - "image_std", - "min_pixels", - "image_mean", - "merge_size", - "image_processor_type", - "temporal_patch_size", - "patch_size", - "max_pixels", - ] - image_processor_class = "AutoImageProcessor" tokenizer_class = "AutoTokenizer" @@ -370,11 +501,6 @@ def __call__( **kwargs: Unpack[PaddleOCRVLProcessorKwargs], ) -> BatchFeature: """ - Main method to prepare for the model one or several sequences(s) and image(s). This method forwards the `text` - and `kwargs` arguments to Qwen2TokenizerFast's [`~Qwen2TokenizerFast.__call__`] if `text` is not `None` to encode - the text. To prepare the vision inputs, this method forwards the `vision_infos` and `kwrags` arguments to - PaddleOCRVLImageProcessor's [`~PaddleOCRVLImageProcessor.__call__`] if `vision_infos` is not `None`. - Args: images (`PIL.Image.Image`, `np.ndarray`, `torch.Tensor`, `List[PIL.Image.Image]`, `List[np.ndarray]`, `List[torch.Tensor]`): The image or batch of images to be prepared. Each image can be a PIL image, NumPy array or PyTorch @@ -407,8 +533,7 @@ def __call__( ) if images is not None: - image_inputs = self.image_processor(images=images, return_tensors="pt") - image_inputs["pixel_values"] = image_inputs["pixel_values"] + image_inputs = self.image_processor(images=images, **output_kwargs["images_kwargs"]) image_grid_thw = image_inputs["image_grid_thw"] else: @@ -418,6 +543,8 @@ def __call__( if not isinstance(text, list): text = [text] + text = text.copy() + if image_grid_thw is not None: index = 0 for i in range(len(text)): @@ -440,8 +567,8 @@ def __call__( return BatchFeature(data={**text_inputs, **image_inputs}) -class PaddleOCRVisionConfig(PretrainedConfig): - model_type = "paddleocr_vl" +class PaddleOCRVLVisionConfig(PretrainedConfig): + model_type = "paddleocr_vl_vision" base_config_key = "vision_config" def __init__( @@ -478,7 +605,7 @@ def __init__( self.tokens_per_second = tokens_per_second -class PaddleOCRVLConfig(PretrainedConfig): +class PaddleOCRVLTextConfig(PretrainedConfig): """ Configuration class. @@ -486,11 +613,8 @@ class PaddleOCRVLConfig(PretrainedConfig): It inherits from PretrainedConfig and can be used to control model outputs. """ - model_type = "paddleocr_vl" - keys_to_ignore_at_inference = ["past_key_values"] - sub_configs = {"vision_config": PaddleOCRVisionConfig} + model_type = "paddleocr_vl_text" - # Default tensor parallel plan for base model `Qwen3` base_model_tp_plan = { "layers.*.self_attn.q_proj": "colwise", "layers.*.self_attn.k_proj": "colwise", @@ -514,9 +638,6 @@ def __init__( max_position_embeddings=32768, num_hidden_layers=2, num_attention_heads=2, - image_token_id=101304, - video_token_id=101305, - vision_start_token_id=101306, rms_norm_eps=1e-6, use_cache=False, use_flash_attention=False, @@ -535,8 +656,7 @@ def __init__( num_key_value_heads=None, max_sequence_length=None, tie_word_embeddings=False, - vision_config=None, - rope_scaling=None, + rope_parameters: Optional[RopeParameters | dict[str, RopeParameters]] = None, **kwargs, ): """ @@ -574,10 +694,6 @@ def __init__( eos_token_id=eos_token_id, **kwargs, ) - if isinstance(vision_config, dict): - self.vision_config = self.sub_configs["vision_config"](**vision_config) - elif vision_config is None: - self.vision_config = self.sub_configs["vision_config"]() self.vocab_size = vocab_size self.hidden_size = hidden_size self.intermediate_size = intermediate_size @@ -590,9 +706,6 @@ def __init__( self.pad_token_id = pad_token_id self.bos_token_id = bos_token_id self.eos_token_id = eos_token_id - self.image_token_id = image_token_id - self.video_token_id = video_token_id - self.vision_start_token_id = vision_start_token_id self.head_dim = head_dim self.hidden_act = hidden_act self.sliding_window = None @@ -606,20 +719,24 @@ def __init__( self.compression_ratio = compression_ratio self.num_key_value_heads = num_key_value_heads self.max_sequence_length = max_sequence_length - self.rope_scaling = rope_scaling - if self.rope_scaling is not None and "type" in self.rope_scaling: - if self.rope_scaling["type"] == "mrope": - self.rope_scaling["type"] = "default" - self.rope_scaling["rope_type"] = self.rope_scaling["type"] + # Try to set `rope_scaling` if available, otherwise use `rope_parameters` + rope_scaling = kwargs.pop("rope_scaling", None) + self.rope_parameters = rope_scaling or rope_parameters + if self.rope_parameters is not None and self.rope_parameters["rope_type"] == "mrope": + self.rope_parameters["rope_type"] = "default" rope_config_validation(self, ignore_keys={"mrope_section"}) super().__init__(tie_word_embeddings=tie_word_embeddings, **kwargs) -class Projector(nn.Module): - def __init__(self, text_config: PaddleOCRVLConfig, vision_config: PaddleOCRVisionConfig): +class PaddleOCRVLConfig(Qwen2VLConfig): + pass + + +class PaddleOCRProjector(nn.Module): + def __init__(self, config: PaddleOCRVLConfig): super().__init__() - self.text_config = text_config - self.vision_config = vision_config + self.text_config = config.text_config + self.vision_config = config.vision_config self.merge_kernel_size = (2, 2) self.hidden_size = self.vision_config.hidden_size * self.merge_kernel_size[0] * self.merge_kernel_size[1] @@ -629,217 +746,68 @@ def __init__(self, text_config: PaddleOCRVLConfig, vision_config: PaddleOCRVisio self.act = GELUActivation() self.linear_2 = nn.Linear(self.hidden_size, self.text_config.hidden_size, bias=True) - def forward(self, image_features: torch.Tensor, image_grid_thw: list[tuple[int, int, int]]) -> torch.Tensor: + def forward(self, image_features: torch.Tensor, image_grid_thw: torch.Tensor) -> torch.Tensor: + image_features_chunks = image_features.split(image_grid_thw.prod(dim=1).tolist(), dim=0) m1, m2 = self.merge_kernel_size processed_features = [] - for image_feature, image_grid in zip(image_features, image_grid_thw): + for image_feature, image_grid in zip(image_features_chunks, image_grid_thw): image_feature = self.pre_norm(image_feature) t, h, w = image_grid + d = image_feature.shape[-1] + h_block = h // m1 + w_block = w // m2 + + image_feature = image_feature.reshape(t, h_block, m1, w_block, m2, d) + image_feature = image_feature.reshape((t * h_block * w_block), (m1 * m2 * d)) - image_feature = rearrange( - image_feature, "(t h p1 w p2) d -> (t h w) (p1 p2 d)", t=t, h=h // m1, p1=m1, w=w // m2, p2=m2 - ) hidden_states = self.linear_1(image_feature) hidden_states = self.act(hidden_states) hidden_states = self.linear_2(hidden_states) processed_features.append(hidden_states) - return processed_features - + return torch.cat(processed_features, dim=0) -def apply_multimodal_rotary_pos_emb(q, k, cos, sin, mrope_section, unsqueeze_dim=1): - """Applies Rotary Position Embedding with Multimodal Sections to the query and key tensors (https://qwenlm.github.io/blog/qwen2-vl/). - Explanation: - Multimodal 3D rotary position embedding is an extension to 1D rotary position embedding. The input embedding - sequence contains vision (images / videos) embedding and text embedding or just contains text embedding. For - vision embedding part, we apply rotary position embedding on temporal, height and width dimension separately. - Here we split the channel dimension to 3 chunks for the temporal, height and width rotary position embedding. - For text embedding part, we just apply 1D rotary position embedding. The three rotary position index (temporal, - height and width) of text embedding is always the same, so the text embedding rotary position embedding has no - difference with modern LLMs. +class PaddleOCRVisionRotaryEmbedding(VisionRotaryEmbedding): + pass - Args: - q (`torch.Tensor`): The query tensor. - k (`torch.Tensor`): The key tensor. - cos (`torch.Tensor`): The cosine part of the rotary embedding. - sin (`torch.Tensor`): The sine part of the rotary embedding. - position_ids (`torch.Tensor`): - The position indices of the tokens corresponding to the query and key tensors. For example, this can be - used to pass offsetted position ids when working with a KV-cache. - mrope_section(`List(int)`): - Multimodal rope section is for channel dimension of temporal, height and width in rope calculation. - unsqueeze_dim (`int`, *optional*, defaults to 1): - The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and - sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note - that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and - k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes - cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have - the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2. - Returns: - `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding. - """ - mrope_section = mrope_section * 2 - cos = torch.cat([m[i % 3] for i, m in enumerate(cos.split(mrope_section, dim=-1))], dim=-1).unsqueeze( - unsqueeze_dim - ) - sin = torch.cat([m[i % 3] for i, m in enumerate(sin.split(mrope_section, dim=-1))], dim=-1).unsqueeze( - unsqueeze_dim - ) - q_embed = (q * cos) + (rotate_half(q) * sin) - k_embed = (k * cos) + (rotate_half(k) * sin) - return q_embed, k_embed +class PaddleOCRRotaryEmbedding(Qwen2VLRotaryEmbedding): + pass -class PaddleOCRRotaryEmbedding(nn.Module): - def __init__(self, dim: int, theta: float = 10000.0) -> None: +class PaddleOCRMLP(Ernie4_5MLP): + def __init__(self, config: PaddleOCRVLTextConfig): super().__init__() - inv_freq = 1.0 / (theta ** (torch.arange(0, dim, 2, dtype=torch.float) / dim)) - self.register_buffer("inv_freq", inv_freq, persistent=False) - - def forward(self, seqlen: int) -> torch.Tensor: - seq = torch.arange(seqlen, device=self.inv_freq.device, dtype=self.inv_freq.dtype) - freqs = torch.outer(seq, self.inv_freq) - return freqs -class Ernie4_5RotaryEmbedding(LlamaRotaryEmbedding): - def __init__(self, config: PaddleOCRVLConfig, device=None): +class PaddleOCRAttention(Qwen2_5OmniAttention): + def __init__(self, config: PaddleOCRVLConfig, layer_idx: Optional[int] = None): super().__init__() - @staticmethod - def compute_default_rope_parameters( - config: Optional[PaddleOCRVLConfig] = None, - device: Optional["torch.device"] = None, - seq_len: Optional[int] = None, - ) -> tuple["torch.Tensor", float]: - base = config.rope_parameters["rope_theta"] - dim = getattr(config, "head_dim", None) or config.hidden_size // config.num_attention_heads - - attention_factor = 1.0 # Unused in this type of RoPE - - # Compute the inverse frequencies - inv_freq = 1.0 / ( - base ** (torch.arange(0, dim, 2, dtype=torch.int64).to(device=device, dtype=torch.float) / dim) - ) - return inv_freq, attention_factor - - @torch.no_grad() - @dynamic_rope_update # power user: used with advanced RoPE types (e.g. dynamic rope) - def forward(self, x, position_ids): - inv_freq_expanded = self.inv_freq[None, None, :, None].float().expand(3, position_ids.shape[1], -1, 1) - position_ids_expanded = position_ids[:, :, None, :].float() # shape (3, bs, 1, positions) - # Force float32 (see https://github.com/huggingface/transformers/pull/29285) - device_type = x.device.type - device_type = device_type if isinstance(device_type, str) and device_type != "mps" else "cpu" - with torch.autocast(device_type=device_type, enabled=False): - freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(2, 3) - emb = torch.cat((freqs, freqs), dim=-1) - cos = emb.cos() - sin = emb.sin() - - # Advanced RoPE types (e.g. yarn) apply a post-processing scaling factor, equivalent to scaling attention - cos = cos * self.attention_scaling - sin = sin * self.attention_scaling - - return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype) - - -class Ernie4_5MLP(LlamaMLP): - def __init__(self, config: PaddleOCRVLConfig): - super().__init__(config) - - self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=config.use_bias) - self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=config.use_bias) - self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=config.use_bias) - - -class Ernie4_5Attention(LlamaAttention): - def __init__(self, config: PaddleOCRVLConfig, layer_idx: int): - super().__init__(config, layer_idx) - self.attention_dropout = 0.0 - self.rope_scaling = config.rope_scaling - - self.q_proj = nn.Linear(config.hidden_size, config.num_attention_heads * self.head_dim, bias=config.use_bias) - self.k_proj = nn.Linear(config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.use_bias) - self.v_proj = nn.Linear(config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.use_bias) - self.o_proj = nn.Linear(config.num_attention_heads * self.head_dim, config.hidden_size, bias=config.use_bias) - - def forward( - self, - hidden_states: torch.Tensor, - position_embeddings: tuple[torch.Tensor, torch.Tensor], - attention_mask: Optional[torch.Tensor], - past_key_value: Optional[Cache] = None, - cache_position: Optional[torch.LongTensor] = None, - **kwargs: Unpack[TransformersKwargs], - ) -> tuple[torch.Tensor, torch.Tensor]: - input_shape = hidden_states.shape[:-1] - hidden_shape = (*input_shape, -1, self.head_dim) - - query_states = self.q_proj(hidden_states).view(hidden_shape).transpose(1, 2) - key_states = self.k_proj(hidden_states).view(hidden_shape).transpose(1, 2) - value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2) - - if "position_ids" in kwargs and kwargs["position_ids"] is not None: - position_ids = kwargs["position_ids"] - if position_ids.dim() == 3 and position_ids.shape[0] > 1: - kwargs["position_ids"] = position_ids[0:1] - - cos, sin = position_embeddings - query_states, key_states = apply_multimodal_rotary_pos_emb( - query_states, key_states, cos, sin, self.rope_scaling["mrope_section"] - ) - - if past_key_value is not None: - # sin and cos are specific to RoPE models; cache_position needed for the static cache - cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} - key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) - - attention_interface: Callable = eager_attention_forward - if self.config._attn_implementation != "eager": - attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] - - attn_output, attn_weights = attention_interface( - self, - query_states, - key_states, - value_states, - attention_mask, - dropout=0.0 if not self.training else self.attention_dropout, - scaling=self.scaling, - **kwargs, - ) - - attn_output = attn_output.reshape(*input_shape, -1).contiguous() - attn_output = self.o_proj(attn_output) - return attn_output, attn_weights + self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=config.use_bias) + self.k_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=config.use_bias) + self.v_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=config.use_bias) + self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=config.use_bias) -class Ernie4_5RMSNorm(LlamaRMSNorm): +class PaddleOCRRMSNorm(Ernie4_5RMSNorm): pass -class Ernie4_5DecoderLayer(LlamaDecoderLayer): - def __init__(self, config: PaddleOCRVLConfig, layer_idx: int): - super().__init__(config) - - self.self_attn = Ernie4_5Attention(config=config, layer_idx=layer_idx) - - self.mlp = Ernie4_5MLP(config) - self.input_layernorm = Ernie4_5RMSNorm(config.hidden_size, eps=config.rms_norm_eps) - self.post_attention_layernorm = Ernie4_5RMSNorm(config.hidden_size, eps=config.rms_norm_eps) +class PaddleOCRDecoderLayer(Ernie4_5DecoderLayer): + def __init__(self, config: PaddleOCRVLTextConfig, layer_idx: int): + super().__init__() @auto_docstring -class Ernie4_5PreTrainedModel(PreTrainedModel): +class PaddleOCRVLPreTrainedModel(PreTrainedModel): config: PaddleOCRVLConfig base_model_prefix = "model" supports_gradient_checkpointing = True - _no_split_modules = ["Ernie4_5DecoderLayer"] + _no_split_modules = ["PaddleOCRDecoderLayer"] _skip_keys_device_placement = ["past_key_values"] _supports_flash_attn = True _supports_sdpa = True @@ -847,312 +815,61 @@ class Ernie4_5PreTrainedModel(PreTrainedModel): _can_compile_fullgraph = True _supports_attention_backend = True - _can_record_outputs = { - "hidden_states": Ernie4_5DecoderLayer, - "attentions": Ernie4_5Attention, - } -class Ernie4_5Model(LlamaModel): - def __init__(self, config: PaddleOCRVLConfig): +class PaddleOCRVLTextModel(PaddleOCRVLPreTrainedModel, Ernie4_5Model): + def __init__(self, config: PaddleOCRVLTextConfig): super().__init__(config) - self.layers = nn.ModuleList( - [Ernie4_5DecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)] - ) - self.norm = Ernie4_5RMSNorm(config.hidden_size, eps=config.rms_norm_eps) - self.rotary_emb = Ernie4_5RotaryEmbedding(config=config) -def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor: - """ - This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch, - num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim) - """ - batch, num_key_value_heads, slen, head_dim = hidden_states.shape - if n_rep == 1: - return hidden_states - hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim) - return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim) - - -def rotate_half(x): - """Rotates half the hidden dims of the input.""" - x1 = x[..., : x.shape[-1] // 2] - x2 = x[..., x.shape[-1] // 2 :] - return torch.cat((-x2, x1), dim=-1) - - -def apply_rotary_pos_emb_vision( - q: torch.Tensor, k: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor -) -> tuple[torch.Tensor, torch.Tensor]: - orig_q_dtype = q.dtype - orig_k_dtype = k.dtype - q, k = q.float(), k.float() - cos, sin = cos.unsqueeze(-2).float(), sin.unsqueeze(-2).float() - q_embed = (q * cos) + (rotate_half(q) * sin) - k_embed = (k * cos) + (rotate_half(k) * sin) - q_embed = q_embed.to(orig_q_dtype) - k_embed = k_embed.to(orig_k_dtype) - return q_embed, k_embed - - -def eager_attention_forward( - module: nn.Module, - query: torch.Tensor, - key: torch.Tensor, - value: torch.Tensor, - attention_mask: Optional[torch.Tensor], - scaling: float, - dropout: float = 0.0, - **kwargs: Unpack[TransformersKwargs], -): - key_states = repeat_kv(key, module.num_key_value_groups) - value_states = repeat_kv(value, module.num_key_value_groups) - - attn_weights = torch.matmul(query, key_states.transpose(2, 3)) * scaling - if attention_mask is not None: - causal_mask = attention_mask[:, :, :, : key_states.shape[-2]] - attn_weights = attn_weights + causal_mask - - attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype) - attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training) - attn_output = torch.matmul(attn_weights, value_states) - attn_output = attn_output.transpose(1, 2).contiguous() - - return attn_output, attn_weights - - -def paddleocr_vl_eager_attention_forward( - module: nn.Module, - query: torch.Tensor, - key: torch.Tensor, - value: torch.Tensor, - attention_mask: Optional[torch.Tensor], - scaling: float, - dropout: float = 0.0, - **kwargs, -): - attn_weights = torch.matmul(query, key.transpose(-1, -2)) * scaling - if attention_mask is not None: - attn_weights = attn_weights + attention_mask - - attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype) - attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training) - - attn_output = torch.matmul(attn_weights, value) - attn_output = attn_output.transpose(1, 2).contiguous() - - return attn_output, attn_weights - - -def _trunc_normal_(tensor, mean, std, a, b): - # Cut & paste from PyTorch official master until it's in a few official releases - RW - # Method based on https://people.sc.fsu.edu/~jburkardt/presentations/truncated_normal.pdf - def norm_cdf(x): - # Computes standard normal cumulative distribution function - return (1.0 + math.erf(x / math.sqrt(2.0))) / 2.0 - - if (mean < a - 2 * std) or (mean > b + 2 * std): - logger.warning_once( - "mean is more than 2 std from [a, b] in nn.init.trunc_normal_. " - "The distribution of values may be incorrect.", - stacklevel=2, - ) - - # Values are generated by using a truncated uniform distribution and - # then using the inverse CDF for the normal distribution. - # Get upper and lower cdf values - l = norm_cdf((a - mean) / std) - u = norm_cdf((b - mean) / std) - - # Uniformly fill tensor with values from [l, u], then translate to - # [2l-1, 2u-1]. - tensor.uniform_(2 * l - 1, 2 * u - 1) - - # Use inverse cdf transform for normal distribution to get truncated - # standard normal - tensor.erfinv_() - - # Transform to proper mean, std - tensor.mul_(std * math.sqrt(2.0)) - tensor.add_(mean) - - # Clamp to ensure it's in the proper range - tensor.clamp_(min=a, max=b) - - -def trunc_normal_tf_( - tensor: torch.Tensor, - mean: float = 0.0, - std: float = 1.0, - a: float = -2.0, - b: float = 2.0, -) -> torch.Tensor: - """Fills the input Tensor with values drawn from a truncated - normal distribution. The values are effectively drawn from the - normal distribution :math:`\\mathcal{N}(\text{mean}, \text{std}^2)` - with values outside :math:`[a, b]` redrawn until they are within - the bounds. The method used for generating the random values works - best when :math:`a \\leq \text{mean} \\leq b`. - - NOTE: this 'tf' variant behaves closer to Tensorflow / JAX impl where the - bounds [a, b] are applied when sampling the normal distribution with mean=0, std=1.0 - and the result is subsequently scaled and shifted by the mean and std args. - - Args: - tensor: an n-dimensional `torch.Tensor` - mean: the mean of the normal distribution - std: the standard deviation of the normal distribution - a: the minimum cutoff value - b: the maximum cutoff value - """ - with torch.no_grad(): - _trunc_normal_(tensor, 0, 1.0, a, b) - tensor.mul_(std).add_(mean) - - -def variance_scaling_(tensor, scale=1.0, mode="fan_in", distribution="normal"): - fan_in, fan_out = _calculate_fan_in_and_fan_out(tensor) - if mode == "fan_in": - denom = fan_in - elif mode == "fan_out": - denom = fan_out - elif mode == "fan_avg": - denom = (fan_in + fan_out) / 2 - - variance = scale / denom - - if distribution == "truncated_normal": - # constant is stddev of standard normal truncated to (-2, 2) - trunc_normal_tf_(tensor, std=math.sqrt(variance) / 0.87962566103423978) - elif distribution == "normal": - with torch.no_grad(): - tensor.normal_(std=math.sqrt(variance)) - elif distribution == "uniform": - bound = math.sqrt(3 * variance) - with torch.no_grad(): - tensor.uniform_(-bound, bound) - else: - raise ValueError(f"invalid distribution {distribution}") - +class PaddleOCRVisionModel(PaddleOCRVLPreTrainedModel): + config: PaddleOCRVLVisionConfig + main_input_name = "pixel_values" + input_modalities = "image" -def lecun_normal_(tensor): - variance_scaling_(tensor, mode="fan_in", distribution="truncated_normal") - - -def default_flax_embed_init(tensor): - variance_scaling_(tensor, mode="fan_in", distribution="normal") - - -@auto_docstring -class PaddleOCRPreTrainedModel(PreTrainedModel): - config_class = PaddleOCRVLConfig - base_model_prefix = "PaddleOCR" - supports_gradient_checkpointing = True + def __init__(self, config: PaddleOCRVLVisionConfig): + super().__init__(config) - _no_split_modules = [ - "PaddleOCRTextEmbeddings", - "PaddleOCREncoderLayer", - "PaddleOCRVisionEmbeddings", - "PaddleOCRMultiheadAttentionPoolingHead", - ] - _supports_flash_attn_2 = True - _supports_sdpa = True + self.vision_model = PaddleOCRVisionTransformer(config) - def _init_weights(self, module): - """Initialize the weights""" - if isinstance(module, PaddleOCRVisionEmbeddings): - width = ( - self.config.vision_config.hidden_size - if isinstance(self.config, PaddleOCRVLConfig) - else self.config.hidden_size - ) - nn.init.normal_(module.position_embedding.weight, std=1 / np.sqrt(width)) - elif isinstance(module, nn.Embedding): - default_flax_embed_init(module.weight) - elif isinstance(module, PaddleOCRAttention): - nn.init.xavier_uniform_(module.q_proj.weight) - nn.init.xavier_uniform_(module.k_proj.weight) - nn.init.xavier_uniform_(module.v_proj.weight) - nn.init.xavier_uniform_(module.out_proj.weight) - nn.init.zeros_(module.q_proj.bias) - nn.init.zeros_(module.k_proj.bias) - nn.init.zeros_(module.v_proj.bias) - nn.init.zeros_(module.out_proj.bias) - elif isinstance(module, PaddleOCRMLP): - nn.init.xavier_uniform_(module.fc1.weight) - nn.init.xavier_uniform_(module.fc2.weight) - nn.init.normal_(module.fc1.bias, std=1e-6) - nn.init.normal_(module.fc2.bias, std=1e-6) - elif isinstance(module, PaddleOCRMultiheadAttentionPoolingHead): - nn.init.xavier_uniform_(module.probe.data) - nn.init.xavier_uniform_(module.attention.in_proj_weight.data) - nn.init.zeros_(module.attention.in_proj_bias.data) - elif isinstance(module, (nn.Linear, nn.Conv2d)): - lecun_normal_(module.weight) - if module.bias is not None: - nn.init.zeros_(module.bias) - elif isinstance(module, nn.LayerNorm): - module.bias.data.zero_() - module.weight.data.fill_(1.0) - - -class PaddleOCRVisionModel(SiglipVisionModel): - def __init__(self, config: PaddleOCRVisionConfig): - super().__init__(config) + # Initialize weights and apply final processing + self.post_init() def forward( self, pixel_values, - position_ids: Optional[torch.Tensor] = None, + cu_seqlens, image_grid_thw: Optional[list[Union[tuple[int, int, int], list[tuple[int, int, int]]]]] = None, ) -> BaseModelOutputWithPooling: """ Args: pixel_values (`torch.FloatTensor` of shape `(batch_size, sequence_length, image_channels, patch_size, patch_size)`): The tensors corresponding to the input images. - position_ids (`torch.LongTensor` of shape `sequence_length`): - The position ids of the image. - image_grid_thw (`List[Tuple]`, *optional*): + cu_seqlens (`torch.Tensor` of shape `(num_images + 1,)`): + The cumulative sequence lengths of each image or video feature. + image_grid_thw (`torch.LongTensor` of shape `(num_images, 3)`, *optional*): The temporal, height and width of feature shape of each image in LLM. """ return self.vision_model( pixel_values=pixel_values, - position_ids=position_ids, + cu_seqlens=cu_seqlens, image_grid_thw=image_grid_thw, ) class PaddleOCRVisionEmbeddings(SiglipVisionEmbeddings): - def __init__(self, config: PaddleOCRVisionConfig): + def __init__(self, config: PaddleOCRVLVisionConfig): super().__init__() - self.packing_position_embedding = nn.Embedding(32768, self.embed_dim) - - @staticmethod - def flatten_list(image_grid_thw): - tmp_image_grid_thw = [] - for image_grid in image_grid_thw: - if isinstance(image_grid, list): - tmp_image_grid_thw.extend(image_grid) - else: - tmp_image_grid_thw.append(image_grid) - return tmp_image_grid_thw - def interpolate_pos_encoding( - self, embeddings: torch.Tensor, height: int, width: int, is_after_patchify: bool = False - ) -> torch.Tensor: + def interpolate_pos_encoding(self, embeddings: torch.Tensor, height: int, width: int) -> torch.Tensor: num_positions = self.position_embedding.weight.shape[0] patch_pos_embed = self.position_embedding.weight.unsqueeze(0) dim = embeddings.shape[-1] - if is_after_patchify: - new_height = height - new_width = width - else: - new_height = height // self.patch_size - new_width = width // self.patch_size + new_height = height + new_width = width sqrt_num_positions = torch_int(num_positions**0.5) patch_pos_embed = patch_pos_embed.reshape(1, sqrt_num_positions, sqrt_num_positions, dim) @@ -1171,7 +888,6 @@ def interpolate_pos_encoding( def forward( self, pixel_values: torch.FloatTensor, - position_ids: Optional[torch.Tensor] = None, image_grid_thw: Optional[list[Union[tuple[int, int, int], list[tuple[int, int, int]]]]] = None, ) -> torch.Tensor: """ @@ -1180,148 +896,91 @@ def forward( The tensors corresponding to the input images. position_ids (`torch.LongTensor` of shape `sequence_length`): The position ids of the image. - image_grid_thw (`List[Tuple]`, *optional*): + image_grid_thw (`torch.LongTensor` of shape `(num_images, 3)`, *optional*): The temporal, height and width of feature shape of each image in LLM. """ - if pixel_values.dim() == 5: - assert position_ids is not None - - batch_size, squence_len, channel, height, width = pixel_values.shape - target_dtype = self.patch_embedding.weight.dtype - pixel_values = rearrange(pixel_values, "b l c h w -> (b l) c h w") - patch_embeds = self.patch_embedding(pixel_values.to(dtype=target_dtype)) # shape = [*, width, grid, grid] - embeddings = patch_embeds.flatten(-2).squeeze(-1) - embeddings = rearrange(embeddings, "(b l) d -> b l d", b=batch_size, l=squence_len) - - if image_grid_thw is not None: - flatten_image_grid_thw = self.flatten_list(image_grid_thw) - assert batch_size == 1 - start = 0 - assert sum([np.prod(x) for x in flatten_image_grid_thw]) == embeddings.shape[1], ( - flatten_image_grid_thw, - embeddings.shape, - ) - embeddings = embeddings.squeeze(0) - tmp_embeddings = [] - for image_grid in image_grid_thw: - t, h, w = image_grid - end = start + t * h * w - image_embeddings = embeddings[start:end, :] - position_embedding = ( - self.interpolate_pos_encoding(image_embeddings, h, w, True).squeeze(0).repeat(t, 1) - ) - image_embeddings = image_embeddings + position_embedding - tmp_embeddings.append(image_embeddings) - start = end - embeddings = torch.concat(tmp_embeddings, dim=0).unsqueeze(0) - else: - embeddings = embeddings + self.packing_position_embedding(position_ids) - return embeddings - else: - raise NotImplementedError(str(pixel_values.shape)) - + batch_size, squence_len, channel, height, width = pixel_values.shape + target_dtype = self.patch_embedding.weight.dtype + pixel_values = pixel_values.reshape(batch_size * squence_len, channel, height, width) + patch_embeds = self.patch_embedding(pixel_values.to(dtype=target_dtype)) # shape = [*, width, grid, grid] + embeddings = patch_embeds.flatten(-2).squeeze(-1) + embeddings = embeddings.reshape(batch_size, squence_len, -1) + + assert batch_size == 1, ( + f"Batch size must be 1, but received {batch_size}. This model only processes one image at a time." + ) + start = 0 + embeddings = embeddings.squeeze(0) + tmp_embeddings = [] + for image_grid in image_grid_thw: + t, h, w = image_grid + end = start + t * h * w + image_embeddings = embeddings[start:end, :] + position_embedding = self.interpolate_pos_encoding(image_embeddings, h, w).squeeze(0).repeat(t, 1) + image_embeddings = image_embeddings + position_embedding + tmp_embeddings.append(image_embeddings) + start = end + embeddings = torch.concat(tmp_embeddings, dim=0) -class PaddleOCRAttention(SiglipAttention): - def __init__(self, config: PaddleOCRVisionConfig): - super().__init__() + return embeddings - def forward( - self, - hidden_states: torch.Tensor, - attention_mask: Optional[torch.Tensor] = None, - rope_emb: Optional[tuple[torch.Tensor, torch.Tensor]] = None, - ) -> tuple[torch.Tensor, Optional[torch.Tensor]]: - batch_size, seq_length, embed_dim = hidden_states.shape - - queries = self.q_proj(hidden_states) - keys = self.k_proj(hidden_states) - values = self.v_proj(hidden_states) - - cos, sin = rope_emb - queries = queries.view(batch_size, seq_length, self.num_heads, self.head_dim) - keys = keys.view(batch_size, seq_length, self.num_heads, self.head_dim) - values = values.view(batch_size, seq_length, self.num_heads, self.head_dim).transpose(1, 2) - queries, keys = apply_rotary_pos_emb_vision(queries, keys, cos, sin) - queries = queries.transpose(1, 2) - keys = keys.transpose(1, 2) - - attention_interface: Callable = eager_attention_forward - if self.config._attn_implementation != "eager": - attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] - - attn_output, attn_weights = attention_interface( - self, - queries, - keys, - values, - attention_mask, - is_causal=self.is_causal, - scaling=self.scale, - dropout=0.0 if not self.training else self.dropout, - ) - attn_output = attn_output.reshape(batch_size, seq_length, embed_dim).contiguous() - attn_output = self.out_proj(attn_output) +class PaddleOCRVisionAttention(VideoLlama3VisionAttention): + def __init__(self, config: PaddleOCRVLVisionConfig): + super().__init__() - return attn_output, attn_weights +class PaddleOCRVisionMLP(SiglipMLP): + def __init__(self, config: PaddleOCRVLVisionConfig): + super().__init__() -class PaddleOCRMLP(SiglipMLP): - pass +class PaddleOCRVisionMultiheadAttentionPoolingHead(SiglipMultiheadAttentionPoolingHead): + def __init__(self, config: PaddleOCRVLVisionConfig): + super().__init__() -class PaddleOCRMultiheadAttentionPoolingHead(SiglipMultiheadAttentionPoolingHead): - pass + self.mlp = PaddleOCRVisionMLP(config) -class PaddleOCREncoderLayer(SiglipEncoderLayer): - def __init__(self, config: PaddleOCRVisionConfig): +class PaddleOCRVisionEncoderLayer(VideoLlama3VisionEncoderLayer): + def __init__(self, config: PaddleOCRVLVisionConfig): super().__init__() -class PaddleOCREncoder(SiglipEncoder): - def __init__(self, config: PaddleOCRVisionConfig): +class PaddleOCRVisionEncoder(VideoLlama3VisionEncoder): + def __init__(self, config: PaddleOCRVLVisionConfig): super().__init__() embed_dim = config.hidden_size num_heads = config.num_attention_heads head_dim = embed_dim // num_heads - self.rotary_pos_emb = PaddleOCRRotaryEmbedding(head_dim // 2) - - @staticmethod - def flatten_list(image_grid_thw): - tmp_image_grid_thw = [] - for image_grid in image_grid_thw: - if isinstance(image_grid, list): - tmp_image_grid_thw.extend(image_grid) - else: - tmp_image_grid_thw.append(image_grid) - return tmp_image_grid_thw + self.rotary_pos_emb = PaddleOCRVisionRotaryEmbedding(head_dim // 2) def forward( self, inputs_embeds, + cu_seqlens, attention_mask: Optional[torch.Tensor] = None, image_grid_thw: Optional[list[Union[tuple[int, int, int], list[tuple[int, int, int]]]]] = None, ) -> BaseModelOutput: """ Args: + cu_seqlens (`torch.Tensor` of shape `(num_images + 1,)`): + The cumulative sequence lengths of each image or video feature. attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*): The attention_mask used in forward function shape [batch_size X sequence_length] if not None. - image_grid_thw (`List[Tuple]`, *optional*): + image_grid_thw (`torch.LongTensor` of shape `(num_images, 3)`, *optional*): The temporal, height and width of feature shape of each image in LLM. """ device = inputs_embeds.device hidden_states = inputs_embeds - attention_mask = attention_mask.to(inputs_embeds.dtype) if attention_mask is not None else None - flatten_image_grid_thw = self.flatten_list(image_grid_thw) - assert sum([np.prod(x) for x in flatten_image_grid_thw]) == hidden_states.shape[1], ( - flatten_image_grid_thw, - hidden_states.shape, + attention_mask = create_bidirectional_mask( + config=self.config, + input_embeds=inputs_embeds, + attention_mask=attention_mask, ) - split_hids = [] split_wids = [] - for t, h, w in flatten_image_grid_thw: + for t, h, w in image_grid_thw: image_pids = torch.arange(t * h * w, device=device) % (h * w) sample_hids = image_pids // w sample_wids = image_pids % w @@ -1332,16 +991,16 @@ def forward( pids = torch.stack([height_position_ids, width_position_ids], dim=-1) max_grid_size = pids.max() + 1 - rope_emb_max_grid = self.rotary_pos_emb(max_grid_size) - rope_emb = rope_emb_max_grid[pids].flatten(1) - rope_emb = rope_emb.repeat(1, 2) - rope_emb = (rope_emb.cos(), rope_emb.sin()) + rotary_embeddings_max_grid = self.rotary_pos_emb(max_grid_size) + rotary_embeddings = rotary_embeddings_max_grid[pids].flatten(1) + rotary_embeddings = rotary_embeddings.repeat(1, 2) + position_embeddings = (rotary_embeddings.cos(), rotary_embeddings.sin()) for encoder_layer in self.layers: hidden_states = encoder_layer( hidden_states, - attention_mask, - rope_emb=rope_emb, + cu_seqlens=cu_seqlens, + position_embeddings=position_embeddings, ) return BaseModelOutput( @@ -1349,32 +1008,44 @@ def forward( ) -class PaddleOCRVisionTransformer(SiglipVisionTransformer): - def __init__(self, config: PaddleOCRVisionConfig): - super().__init__() +class PaddleOCRVisionTransformer(PaddleOCRVLPreTrainedModel): + def __init__(self, config: PaddleOCRVLVisionConfig): + super().__init__(config) + self.config = config + embed_dim = config.hidden_size + + self.embeddings = PaddleOCRVisionEmbeddings(config) + self.encoder = PaddleOCRVisionEncoder(config) + self.post_layernorm = nn.LayerNorm(embed_dim, eps=config.layer_norm_eps) + self.use_head = True if not hasattr(config, "vision_use_head") else config.vision_use_head + if self.use_head: + self.head = PaddleOCRVisionMultiheadAttentionPoolingHead(config) def forward( self, pixel_values, + cu_seqlens: torch.Tensor, attention_mask: Optional[torch.Tensor] = None, - position_ids: Optional[torch.Tensor] = None, image_grid_thw: Optional[list[Union[tuple[int, int, int], list[tuple[int, int, int]]]]] = None, ) -> BaseModelOutputWithPooling: """ Args: pixel_values (`torch.FloatTensor` of shape `(batch_size, sequence_length, patch_size * patch_size * image_channels)`): The tensors corresponding to the input images. + cu_seqlens (`torch.Tensor` of shape `(num_images + 1,)`): + The cumulative sequence lengths of each image or video feature. attention_mask (`torch.Tensor`, *optional*): The attention_mask used in forward function shape [batch_size X sequence_length] if not None. position_ids (`torch.LongTensor` of shape `sequence_length`): The position ids of the image. - image_grid_thw (`List[Tuple]`, *optional*): + image_grid_thw (`torch.LongTensor` of shape `(num_images, 3)`, *optional*): The temporal, height and width of feature shape of each image in LLM. """ - hidden_states = self.embeddings(pixel_values, position_ids=position_ids, image_grid_thw=image_grid_thw) + hidden_states = self.embeddings(pixel_values, image_grid_thw=image_grid_thw) encoder_outputs: BaseModelOutput = self.encoder( inputs_embeds=hidden_states, + cu_seqlens=cu_seqlens, attention_mask=attention_mask, image_grid_thw=image_grid_thw, ) @@ -1390,225 +1061,80 @@ def forward( ) -@dataclass -class PaddleOCRVLCausalLMOutputWithPast(ModelOutput): - loss: Optional[torch.FloatTensor] = None - logits: torch.FloatTensor = None - past_key_values: Optional[list[torch.FloatTensor]] = None - hidden_states: Optional[tuple[torch.FloatTensor]] = None - attentions: Optional[tuple[torch.FloatTensor]] = None - rope_deltas: Optional[torch.LongTensor] = None +class PaddleOCRVLModelOutputWithPast(Qwen2VLModelOutputWithPast): + pass -class PaddleOCRVLForConditionalGeneration(Ernie4_5PreTrainedModel, GenerationMixin): - def __init__(self, config): +class PaddleOCRVLCausalLMOutputWithPast(Qwen2VLCausalLMOutputWithPast): + pass + + +class PaddleOCRVLModel(Qwen2VLModel): + _checkpoint_conversion_mapping = {"^model": "language_model"} + _keys_to_ignore_on_load_unexpected = ["packing_position_embedding"] + + def __init__(self, config: PaddleOCRVLConfig): super().__init__(config) - self.mlp_AR = Projector(config, config.vision_config) self.visual = PaddleOCRVisionModel(config.vision_config) - self.model = Ernie4_5Model(config) - self.vocab_size = config.vocab_size - self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False) + self.projector = PaddleOCRProjector(config) + self.language_model = PaddleOCRVLTextModel(config.text_config) self.rope_deltas = None self.post_init() def get_input_embeddings(self): - return self.model.embed_tokens + return self.language_model.embed_tokens def set_input_embeddings(self, value): - self.model.embed_tokens = value - - def get_output_embeddings(self): - return self.lm_head + self.language_model.embed_tokens = value - def set_output_embeddings(self, new_embeddings): - self.lm_head = new_embeddings - - def set_decoder(self, decoder): - self.model = decoder - - def get_decoder(self): - return self.model - - def get_rope_index( - self, - input_ids: Optional[torch.LongTensor] = None, - image_grid_thw: Optional[torch.LongTensor] = None, - video_grid_thw: Optional[torch.LongTensor] = None, - second_per_grid_ts: Optional[torch.Tensor] = None, - attention_mask: Optional[torch.Tensor] = None, - ) -> tuple[torch.Tensor, torch.Tensor]: + def get_image_features(self, pixel_values: torch.FloatTensor, image_grid_thw: Optional[torch.LongTensor] = None): """ - Calculate the 3D rope index based on image and video's temporal, height and width in LLM. - - Explanation: - Each embedding sequence contains vision embedding and text embedding or just contains text embedding. - - For pure text embedding sequence, the rotary position embedding has no difference with modern LLMs. - Examples: - input_ids: [T T T T T], here T is for text. - temporal position_ids: [0, 1, 2, 3, 4] - height position_ids: [0, 1, 2, 3, 4] - width position_ids: [0, 1, 2, 3, 4] - - For vision and text embedding sequence, we calculate 3D rotary position embedding for vision part - and 1D rotary position embedding for text part. - Examples: - Temporal (Time): 3 patches, representing different segments of the video in time. - Height: 2 patches, dividing each frame vertically. - Width: 2 patches, dividing each frame horizontally. - We also have some important parameters: - fps (Frames Per Second): The video's frame rate, set to 1. This means one frame is processed each second. - tokens_per_second: This is a crucial parameter. It dictates how many "time-steps" or "temporal tokens" are conceptually packed into a one-second interval of the video. In this case, we have 25 tokens per second. So each second of the video will be represented with 25 separate time points. It essentially defines the temporal granularity. - temporal_patch_size: The number of frames that compose one temporal patch. Here, it's 2 frames. - interval: The step size for the temporal position IDs, calculated as tokens_per_second * temporal_patch_size / fps. In this case, 25 * 2 / 1 = 50. This means that each temporal patch will be have a difference of 50 in the temporal position IDs. - input_ids: [V V V V V V V V V V V V T T T T T], here V is for vision. - vision temporal position_ids: [0, 0, 0, 0, 50, 50, 50, 50, 100, 100, 100, 100] - vision height position_ids: [0, 0, 1, 1, 0, 0, 1, 1, 0, 0, 1, 1] - vision width position_ids: [0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1] - text temporal position_ids: [101, 102, 103, 104, 105] - text height position_ids: [101, 102, 103, 104, 105] - text width position_ids: [101, 102, 103, 104, 105] - Here we calculate the text start position_ids as the max vision position_ids plus 1. + Encodes images into continuous embeddings that can be forwarded to the language model. Args: - input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`): - Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide - it. + pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, image_size, image_size)`): + The tensors corresponding to the input images. image_grid_thw (`torch.LongTensor` of shape `(num_images, 3)`, *optional*): The temporal, height and width of feature shape of each image in LLM. - video_grid_thw (`torch.LongTensor` of shape `(num_videos, 3)`, *optional*): - The temporal, height and width of feature shape of each video in LLM. - second_per_grid_ts (`torch.Tensor` of shape `(num_videos)`, *optional*): - The time interval (in seconds) for each grid along the temporal dimension in the 3D position IDs. - attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*): - Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: - - - 1 for tokens that are **not masked**, - - 0 for tokens that are **masked**. - - Returns: - position_ids (`torch.LongTensor` of shape `(3, batch_size, sequence_length)`) - mrope_position_deltas (`torch.Tensor` of shape `(batch_size)`) """ - spatial_merge_size = self.config.vision_config.spatial_merge_size - image_token_id = self.config.image_token_id - video_token_id = self.config.video_token_id - vision_start_token_id = self.config.vision_start_token_id - mrope_position_deltas = [] - if input_ids is not None and (image_grid_thw is not None or video_grid_thw is not None): - total_input_ids = input_ids - if attention_mask is None: - attention_mask = torch.ones_like(total_input_ids) - position_ids = torch.ones( - 3, - input_ids.shape[0], - input_ids.shape[1], - dtype=input_ids.dtype, - device=input_ids.device, - ) - image_index, video_index = 0, 0 - attention_mask = attention_mask.to(total_input_ids.device) - for i, input_ids in enumerate(total_input_ids): - input_ids = input_ids[attention_mask[i] == 1] - image_nums, video_nums = 0, 0 - vision_start_indices = torch.argwhere(input_ids == vision_start_token_id).squeeze(1) - vision_tokens = input_ids[vision_start_indices + 1] - image_nums = (vision_tokens == image_token_id).sum() - video_nums = (vision_tokens == video_token_id).sum() - input_tokens = input_ids.tolist() - llm_pos_ids_list: list = [] - st = 0 - remain_images, remain_videos = image_nums, video_nums - for _ in range(image_nums + video_nums): - if image_token_id in input_tokens and remain_images > 0: - ed_image = input_tokens.index(image_token_id, st) - else: - ed_image = len(input_tokens) + 1 - if video_token_id in input_tokens and remain_videos > 0: - ed_video = input_tokens.index(video_token_id, st) - else: - ed_video = len(input_tokens) + 1 - if ed_image < ed_video: - t, h, w = ( - image_grid_thw[image_index][0], - image_grid_thw[image_index][1], - image_grid_thw[image_index][2], - ) - second_per_grid_t = 0 - image_index += 1 - remain_images -= 1 - ed = ed_image - - else: - t, h, w = ( - video_grid_thw[video_index][0], - video_grid_thw[video_index][1], - video_grid_thw[video_index][2], - ) - if second_per_grid_ts is not None: - second_per_grid_t = second_per_grid_ts[video_index] - else: - second_per_grid_t = 1.0 - video_index += 1 - remain_videos -= 1 - ed = ed_video - llm_grid_t, llm_grid_h, llm_grid_w = ( - t.item(), - h.item() // spatial_merge_size, - w.item() // spatial_merge_size, - ) - text_len = ed - st - - st_idx = llm_pos_ids_list[-1].max() + 1 if len(llm_pos_ids_list) > 0 else 0 - llm_pos_ids_list.append(torch.arange(text_len).view(1, -1).expand(3, -1) + st_idx) - - if torch.is_tensor(second_per_grid_t): - second_per_grid_t = second_per_grid_t.detach().item() - range_tensor = torch.arange(llm_grid_t).view(-1, 1) - expanded_range = range_tensor.expand(-1, llm_grid_h * llm_grid_w) - - time_tensor = expanded_range * second_per_grid_t * self.config.vision_config.tokens_per_second - - time_tensor_long = time_tensor.long() - t_index = time_tensor_long.flatten() - - h_index = torch.arange(llm_grid_h).view(1, -1, 1).expand(llm_grid_t, -1, llm_grid_w).flatten() - w_index = torch.arange(llm_grid_w).view(1, 1, -1).expand(llm_grid_t, llm_grid_h, -1).flatten() - llm_pos_ids_list.append(torch.stack([t_index, h_index, w_index]) + text_len + st_idx) - st = ed + llm_grid_t * llm_grid_h * llm_grid_w - - if st < len(input_tokens): - st_idx = llm_pos_ids_list[-1].max() + 1 if len(llm_pos_ids_list) > 0 else 0 - text_len = len(input_tokens) - st - llm_pos_ids_list.append(torch.arange(text_len).view(1, -1).expand(3, -1) + st_idx) - - llm_positions = torch.cat(llm_pos_ids_list, dim=1).reshape(3, -1) - position_ids[..., i, attention_mask[i] == 1] = llm_positions.to(position_ids.device) - mrope_position_deltas.append(llm_positions.max() + 1 - len(total_input_ids[i])) - mrope_position_deltas = torch.tensor(mrope_position_deltas, device=input_ids.device).unsqueeze(1) - return position_ids, mrope_position_deltas - else: - if attention_mask is not None: - position_ids = attention_mask.long().cumsum(-1) - 1 - position_ids.masked_fill_(attention_mask == 0, 1) - position_ids = position_ids.unsqueeze(0).expand(3, -1, -1).to(attention_mask.device) - max_position_ids = position_ids.max(0, keepdim=False)[0].max(-1, keepdim=True)[0] - mrope_position_deltas = max_position_ids + 1 - attention_mask.shape[-1] - else: - position_ids = ( - torch.arange(input_ids.shape[1], device=input_ids.device) - .view(1, 1, -1) - .expand(3, input_ids.shape[0], -1) - ) - mrope_position_deltas = torch.zeros( - [input_ids.shape[0], 1], - device=input_ids.device, - dtype=input_ids.dtype, - ) + pixel_values = pixel_values.type(self.visual.dtype).unsqueeze(0) + cu_seqlens = torch.repeat_interleave(image_grid_thw[:, 1] * image_grid_thw[:, 2], image_grid_thw[:, 0]).cumsum( + dim=0, + # Select dtype based on the following factors: + # - FA2 requires that cu_seqlens_q must have dtype int32 + # - torch.onnx.export requires that cu_seqlens_q must have same dtype as grid_thw + # See https://github.com/huggingface/transformers/pull/34852 for more information + dtype=image_grid_thw.dtype if torch.jit.is_tracing() else torch.int32, + ) + cu_seqlens = torch.nn.functional.pad(cu_seqlens, (1, 0), value=0) + vision_outputs = self.visual( + pixel_values=pixel_values, + image_grid_thw=image_grid_thw, + cu_seqlens=cu_seqlens, + ) + image_embeds = vision_outputs.last_hidden_state + return image_embeds - return position_ids, mrope_position_deltas + def get_placeholder_mask( + self, + input_ids: torch.LongTensor, + inputs_embeds: torch.FloatTensor, + image_features: Optional[torch.FloatTensor] = None, + ): + n_image_tokens = (input_ids == self.config.image_token_id).sum().item() + n_image_features = image_features.shape[0] + if n_image_tokens != n_image_features: + raise ValueError( + f"Image features and image tokens do not match: tokens: {n_image_tokens}, features {n_image_features}" + ) + mask = input_ids == self.config.image_token_id + mask_unsqueezed = mask.unsqueeze(-1) + mask_expanded = mask_unsqueezed.expand_as(inputs_embeds) + image_mask = mask_expanded.to(inputs_embeds.device) + return image_mask + @can_return_tuple def forward( self, input_ids: torch.LongTensor = None, @@ -1616,70 +1142,31 @@ def forward( position_ids: Optional[torch.LongTensor] = None, past_key_values: Optional[list[torch.FloatTensor]] = None, inputs_embeds: Optional[torch.FloatTensor] = None, - labels: Optional[torch.LongTensor] = None, use_cache: Optional[bool] = None, return_dict: Optional[bool] = None, pixel_values: Optional[torch.Tensor] = None, - pixel_values_videos: Optional[torch.FloatTensor] = None, image_grid_thw: Optional[torch.LongTensor] = None, - video_grid_thw: Optional[torch.LongTensor] = None, rope_deltas: Optional[torch.LongTensor] = None, cache_position: Optional[torch.LongTensor] = None, - second_per_grid_ts: Optional[torch.Tensor] = None, **kwargs, - ) -> Union[tuple, PaddleOCRVLCausalLMOutputWithPast]: + ) -> Union[tuple, PaddleOCRVLModelOutputWithPast]: r""" - Returns: + image_grid_thw (`torch.LongTensor` of shape `(num_images, 3)`, *optional*): + The temporal, height and width of feature shape of each image in LLM. + rope_deltas (`torch.LongTensor` of shape `(batch_size, )`, *optional*): + The rope index difference between sequence length and multimodal rope. """ return_dict = return_dict if return_dict is not None else self.config.use_return_dict if inputs_embeds is None: - inputs_embeds = self.model.embed_tokens(input_ids) - if pixel_values is not None: - pixel_values = pixel_values.type(self.visual.dtype) - pixel_values = pixel_values.unsqueeze(0) - ppocr_position_ids = [] - image_grid_hws = [] - - for idx, thw in enumerate(image_grid_thw): - thw_tuple = tuple(thw.detach().cpu().numpy().tolist()) - numel = np.prod(thw_tuple) - image_grid_hws.append(thw_tuple) - image_position_ids = torch.arange(numel) % np.prod(thw_tuple[1:]) - ppocr_position_ids.append(image_position_ids) - - ppocr_position_ids = torch.concat(ppocr_position_ids, dim=0).to(pixel_values.device) - - vision_outputs = self.visual( - pixel_values=pixel_values, - image_grid_thw=image_grid_hws, - position_ids=ppocr_position_ids, - ) - image_embeds = vision_outputs.last_hidden_state - - image_embeds = self.mlp_AR(image_embeds, image_grid_thw) - - n_image_tokens = (input_ids == self.config.image_token_id).sum().item() - # image_embeds is a list of tensor, each tensor is a image feature,I want to concat them all into a tensor - image_embeds = torch.cat(image_embeds, dim=0) - n_image_features = image_embeds.shape[0] - if n_image_tokens != n_image_features: - raise ValueError( - f"Image features and image tokens do not match: tokens: {n_image_tokens}, features {n_image_features}" - ) + inputs_embeds = self.language_model.embed_tokens(input_ids) - mask = input_ids == self.config.image_token_id - mask_unsqueezed = mask.unsqueeze(-1) - mask_expanded = mask_unsqueezed.expand_as(inputs_embeds) - image_mask = mask_expanded.to(inputs_embeds.device) + if pixel_values is not None: + image_embeds = self.get_image_features(pixel_values, image_grid_thw) + image_embeds = self.projector(image_embeds, image_grid_thw) + image_mask = self.get_placeholder_mask(input_ids, inputs_embeds=inputs_embeds, image_features=image_embeds) + inputs_embeds = inputs_embeds.masked_scatter(image_mask, image_embeds) - image_embeds = image_embeds.to(inputs_embeds.device, inputs_embeds.dtype) - - inputs_embeds = inputs_embeds.masked_scatter(image_mask, image_embeds) - - if attention_mask is not None: - attention_mask = attention_mask.to(inputs_embeds.device) - # position_ids = None # if we get 4D attention mask we cannot calculate rope deltas anymore. TODO @raushan fixme if position_ids is None and (attention_mask is None or attention_mask.ndim == 2): # calculate RoPE index once per generation in the pre-fill stage only @@ -1689,11 +1176,9 @@ def forward( or (past_key_values is None or past_key_values.get_seq_length() == 0) ): position_ids, rope_deltas = self.get_rope_index( - input_ids, - image_grid_thw, - video_grid_thw, - second_per_grid_ts, - attention_mask, + input_ids=input_ids, + image_grid_thw=image_grid_thw, + attention_mask=attention_mask, ) self.rope_deltas = rope_deltas # then use the prev pre-calculated rope-deltas to get the correct position ids @@ -1711,7 +1196,7 @@ def forward( position_ids = position_ids.add(delta) position_ids = position_ids.unsqueeze(0).expand(3, -1, -1) - outputs = self.model( + outputs = self.language_model( input_ids=None, position_ids=position_ids, attention_mask=attention_mask, @@ -1722,27 +1207,117 @@ def forward( **kwargs, ) - hidden_states = outputs[0] + output = PaddleOCRVLModelOutputWithPast( + last_hidden_state=outputs.last_hidden_state, + past_key_values=outputs.past_key_values, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + rope_deltas=self.rope_deltas, + ) + + return output if return_dict else output.to_tuple() + + +class PaddleOCRVLForConditionalGeneration(Qwen2VLForConditionalGeneration): + _checkpoint_conversion_mapping = { + "^visual": "model.visual", + "^mlp_AR": "model.projector", + r"^model(?!(\.visual|\.projector))": "model.language_model", + } + _keys_to_ignore_on_load_unexpected = ["packing_position_embedding"] + + @can_return_tuple + @auto_docstring + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[Cache] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + pixel_values: Optional[torch.Tensor] = None, + image_grid_thw: Optional[torch.LongTensor] = None, + rope_deltas: Optional[torch.LongTensor] = None, + cache_position: Optional[torch.LongTensor] = None, + logits_to_keep: Union[int, torch.Tensor] = 0, + **kwargs: Unpack[TransformersKwargs], + ) -> Union[tuple, PaddleOCRVLCausalLMOutputWithPast]: + r""" + labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Labels for computing the masked language modeling loss. Indices should either be in `[0, ..., + config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored + (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`. + image_grid_thw (`torch.LongTensor` of shape `(num_images, 3)`, *optional*): + The temporal, height and width of feature shape of each image in LLM. + video_grid_thw (`torch.LongTensor` of shape `(num_videos, 3)`, *optional*): + The temporal, height and width of feature shape of each video in LLM. + rope_deltas (`torch.LongTensor` of shape `(batch_size, )`, *optional*): + The rope index difference between sequence length and multimodal rope. + + Example: + + ```python + >>> from transformers import AutoProcessor, PaddleOCRVLForConditionalGeneration + + >>> model = PaddleOCRVLForConditionalGeneration.from_pretrained("PaddlePaddle/PaddleOCR-VL", dtype="bfloat16") + >>> processor = AutoProcessor.from_pretrained("PaddlePaddle/PaddleOCR-VL") + + >>> messages = [ + { + "role": "user", + "content": [ + { + "type": "image", + "image": "https://paddle-model-ecology.bj.bcebos.com/paddlex/imgs/demo_image/ocr_demo.jpg", + }, + {"type": "text", "text": "OCR:"}, + ], + } + ] + + >>> inputs = processor.apply_chat_template( + messages, + tokenize=True, + add_generation_prompt=True, + return_dict=True, + return_tensors="pt" + ).to(model.device) + + >>> # Generate + >>> generated_ids = model.generate(**inputs, max_new_tokens=1024) + >>> generated_ids_trimmed = [out_ids[len(in_ids) :] for in_ids, out_ids in zip(inputs.input_ids, generated_ids)] + >>> output_text = processor.batch_decode(generated_ids_trimmed, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0] + >>> print(output_text) + ``` + """ + outputs: PaddleOCRVLModelOutputWithPast = self.model( + input_ids=input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + image_grid_thw=image_grid_thw, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + return_dict=True, + pixel_values=pixel_values, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + cache_position=cache_position, + **kwargs, + ) + hidden_states = outputs.last_hidden_state + logits = self.lm_head(hidden_states) loss = None if labels is not None: - # Upcast to float if we need to compute the loss to avoid potential precision issues - logits = logits.float() - # Shift so that tokens < n predict n - shift_logits = logits[..., :-1, :].contiguous() - shift_labels = labels[..., 1:].contiguous() - # Flatten the tokens - loss_fct = CrossEntropyLoss() - shift_logits = shift_logits.view(-1, self.config.vocab_size) - shift_labels = shift_labels.view(-1) - # Enable model parallelism - shift_labels = shift_labels.to(shift_logits.device) - loss = loss_fct(shift_logits, shift_labels) - - if not return_dict: - output = (logits,) + outputs[1:] - return (loss,) + output if loss is not None else output + loss = self.loss_function( + logits=logits, labels=labels, vocab_size=self.config.text_config.vocab_size, **kwargs + ) return PaddleOCRVLCausalLMOutputWithPast( loss=loss, @@ -1750,7 +1325,7 @@ def forward( past_key_values=outputs.past_key_values, hidden_states=outputs.hidden_states, attentions=outputs.attentions, - rope_deltas=self.rope_deltas, + rope_deltas=outputs.rope_deltas, ) def prepare_inputs_for_generation( @@ -1791,143 +1366,15 @@ def prepare_inputs_for_generation( if cache_position[0] != 0: model_inputs["pixel_values"] = None - model_inputs["pixel_values_videos"] = None return model_inputs - def _get_image_nums_and_video_nums( - self, - input_ids: Optional[torch.LongTensor], - ) -> tuple[torch.Tensor, torch.Tensor]: - """ - Get the number of images and videos for each sample to calculate the separation length of the sample tensor. - These parameters are not passed through the processor to avoid unpredictable impacts from interface modifications. - - Args: - input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`): - Indices of input sequence tokens in the vocabulary. - - Returns: - image_nums (`torch.LongTensor` of shape `(batch_size, num_images_sample)`) - video_nums (`torch.LongTensor` of shape `(batch_size, num_videos_sample)`) - """ - image_token_id = self.config.image_token_id - video_token_id = self.config.video_token_id - vision_start_token_id = self.config.vision_start_token_id - - vision_start_mask = input_ids == vision_start_token_id - vision_first_mask = torch.roll(vision_start_mask, shifts=1, dims=1) - image_mask = input_ids == image_token_id - video_mask = input_ids == video_token_id - image_nums = torch.sum(vision_first_mask & image_mask, dim=1) - video_nums = torch.sum(vision_first_mask & video_mask, dim=1) - - return image_nums, video_nums - - def _expand_inputs_for_generation( - self, - expand_size: int = 1, - is_encoder_decoder: bool = False, - input_ids: Optional[torch.LongTensor] = None, - **model_kwargs, - ) -> tuple[torch.LongTensor, dict[str, Any]]: - # Overwritten -- Support for expanding tensors without a batch size dimension - # e.g., pixel_values, image_grid_thw, pixel_values_videos, video_grid_thw, second_per_grid_t - # pixel_values.shape[0] is sum(seqlen_images for samples) - # image_grid_thw.shape[0] is sum(num_images for samples) - - if expand_size == 1: - return input_ids, model_kwargs - - visual_keys = [ - "pixel_values", - "image_grid_thw", - "pixel_values_videos", - "video_grid_thw", - "second_per_grid_ts", - ] - - def _expand_dict_for_generation_visual(dict_to_expand): - image_grid_thw = model_kwargs.get("image_grid_thw", None) - video_grid_thw = model_kwargs.get("video_grid_thw", None) - image_nums, video_nums = self._get_image_nums_and_video_nums(input_ids) - - def _repeat_interleave_samples(x, lengths, repeat_times): - samples = torch.split(x, lengths) - repeat_args = [repeat_times] + [1] * (x.dim() - 1) - result = torch.cat([sample.repeat(*repeat_args) for sample in samples], dim=0) - return result - - for key in dict_to_expand: - if key == "pixel_values": - # split images into samples - samples = torch.split(image_grid_thw, list(image_nums)) - # compute the sequence length of images for each sample - lengths = [torch.prod(sample, dim=1).sum() for sample in samples] - dict_to_expand[key] = _repeat_interleave_samples( - dict_to_expand[key], lengths=lengths, repeat_times=expand_size - ) - elif key == "image_grid_thw": - # get the num of images for each sample - lengths = list(image_nums) - dict_to_expand[key] = _repeat_interleave_samples( - dict_to_expand[key], lengths=lengths, repeat_times=expand_size - ) - elif key == "pixel_values_videos": - samples = torch.split(video_grid_thw, list(video_nums)) - lengths = [torch.prod(sample, dim=1).sum() for sample in samples] - dict_to_expand[key] = _repeat_interleave_samples( - dict_to_expand[key], lengths=lengths, repeat_times=expand_size - ) - elif key == "video_grid_thw": - lengths = list(video_nums) - dict_to_expand[key] = _repeat_interleave_samples( - dict_to_expand[key], lengths=lengths, repeat_times=expand_size - ) - elif key == "second_per_grid_ts": - if not isinstance(dict_to_expand[key], list): - raise TypeError( - f"Expected value for key '{key}' to be a list, but got {type(dict_to_expand[key])} instead." - ) - tensor = torch.tensor(dict_to_expand[key]) - lengths = list(video_nums) - tensor = _repeat_interleave_samples(tensor, lengths=lengths, repeat_times=expand_size) - dict_to_expand[key] = tensor.tolist() - return dict_to_expand - - def _expand_dict_for_generation(dict_to_expand): - for key in dict_to_expand: - if ( - key != "cache_position" - and dict_to_expand[key] is not None - and isinstance(dict_to_expand[key], torch.Tensor) - and key not in visual_keys - ): - dict_to_expand[key] = dict_to_expand[key].repeat_interleave(expand_size, dim=0) - return dict_to_expand - - # input_ids is required for expanding visual inputs - # If input_ids is unavailable, visual inputs will not be used; therefore, there is no need to expand visual inputs. - if input_ids is not None and input_ids.numel() != 0: - model_kwargs = _expand_dict_for_generation_visual(model_kwargs) - - if input_ids is not None: - input_ids = input_ids.repeat_interleave(expand_size, dim=0) - - model_kwargs = _expand_dict_for_generation(model_kwargs) - - if is_encoder_decoder: - if model_kwargs.get("encoder_outputs") is None: - raise ValueError("If `is_encoder_decoder` is True, make sure that `encoder_outputs` is defined.") - model_kwargs["encoder_outputs"] = _expand_dict_for_generation(model_kwargs["encoder_outputs"]) - - return input_ids, model_kwargs - __all__ = [ "PaddleOCRVLForConditionalGeneration", "PaddleOCRVLConfig", - "PaddleOCRVisionConfig", + "PaddleOCRVLTextConfig", "PaddleOCRVLImageProcessor", + "PaddleOCRVLImageProcessorFast", "PaddleOCRVLProcessor", ] diff --git a/src/transformers/models/paddleocr_vl/processing_paddleocr_vl.py b/src/transformers/models/paddleocr_vl/processing_paddleocr_vl.py index aa2c58981ea0..4c89f85ff429 100644 --- a/src/transformers/models/paddleocr_vl/processing_paddleocr_vl.py +++ b/src/transformers/models/paddleocr_vl/processing_paddleocr_vl.py @@ -41,30 +41,17 @@ class PaddleOCRVLProcessorKwargs(ProcessingKwargs, total=False): class PaddleOCRVLProcessor(ProcessorMixin): r""" - [`PaddleOCRVLProcessor`] offers all the functionalities of [`PaddleOCRVLImageProcessor`] and [`Qwen2TokenizerFast`]. See the + [`PaddleOCRVLProcessor`] offers all the functionalities of [`PaddleOCRVLImageProcessor`] and [`LLamaTokenizerFast`]. See the [`~PaddleOCRVLProcessor.__call__`] and [`~PaddleOCRVLProcessor.decode`] for more information. Args: image_processor ([`PaddleOCRVLImageProcessor`], *optional*): The image processor is a required input. - tokenizer ([`Qwen2TokenizerFast`], *optional*): + tokenizer ([`LLamaTokenizerFast`], *optional*): The tokenizer is a required input. chat_template (`str`, *optional*): A Jinja template which will be used to convert lists of messages in a chat into a tokenizable string. """ - attributes = ["image_processor", "tokenizer"] - valid_kwargs = [ - "chat_template", - "image_std", - "min_pixels", - "image_mean", - "merge_size", - "image_processor_type", - "temporal_patch_size", - "patch_size", - "max_pixels", - ] - image_processor_class = "AutoImageProcessor" tokenizer_class = "AutoTokenizer" @@ -79,11 +66,6 @@ def __call__( **kwargs: Unpack[PaddleOCRVLProcessorKwargs], ) -> BatchFeature: """ - Main method to prepare for the model one or several sequences(s) and image(s). This method forwards the `text` - and `kwargs` arguments to Qwen2TokenizerFast's [`~Qwen2TokenizerFast.__call__`] if `text` is not `None` to encode - the text. To prepare the vision inputs, this method forwards the `vision_infos` and `kwrags` arguments to - PaddleOCRVLImageProcessor's [`~PaddleOCRVLImageProcessor.__call__`] if `vision_infos` is not `None`. - Args: images (`PIL.Image.Image`, `np.ndarray`, `torch.Tensor`, `List[PIL.Image.Image]`, `List[np.ndarray]`, `List[torch.Tensor]`): The image or batch of images to be prepared. Each image can be a PIL image, NumPy array or PyTorch @@ -116,8 +98,7 @@ def __call__( ) if images is not None: - image_inputs = self.image_processor(images=images, return_tensors="pt") - image_inputs["pixel_values"] = image_inputs["pixel_values"] + image_inputs = self.image_processor(images=images, **output_kwargs["images_kwargs"]) image_grid_thw = image_inputs["image_grid_thw"] else: @@ -127,6 +108,8 @@ def __call__( if not isinstance(text, list): text = [text] + text = text.copy() + if image_grid_thw is not None: index = 0 for i in range(len(text)): From 70c2338f24c62bca6e7537a261c6374c24aef3ab Mon Sep 17 00:00:00 2001 From: zhangyue66 Date: Wed, 26 Nov 2025 11:52:48 +0800 Subject: [PATCH 03/19] update --- .../models/paddleocr_vl/__init__.py | 6 +- .../configuration_paddleocr_vl.py | 264 +++++++++----- .../image_processing_paddleocr_vl.py | 38 +- .../image_processing_paddleocr_vl_fast.py | 13 +- .../paddleocr_vl/modeling_paddleocr_vl.py | 132 +++---- .../paddleocr_vl/modular_paddleocr_vl.py | 344 ++++-------------- 6 files changed, 307 insertions(+), 490 deletions(-) diff --git a/src/transformers/models/paddleocr_vl/__init__.py b/src/transformers/models/paddleocr_vl/__init__.py index 6deb0275de89..d6be277b3f4d 100644 --- a/src/transformers/models/paddleocr_vl/__init__.py +++ b/src/transformers/models/paddleocr_vl/__init__.py @@ -1,4 +1,5 @@ -# Copyright 2025 The PaddlePaddle Team and The HuggingFace Inc. team. All rights reserved. +# coding=utf-8 +# Copyright 2025 the HuggingFace Team. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -11,6 +12,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. + from typing import TYPE_CHECKING from ...utils import _LazyModule @@ -19,6 +21,8 @@ if TYPE_CHECKING: from .configuration_paddleocr_vl import * + from .image_processing_paddleocr_vl import * + from .image_processing_paddleocr_vl_fast import * from .modeling_paddleocr_vl import * from .processing_paddleocr_vl import * else: diff --git a/src/transformers/models/paddleocr_vl/configuration_paddleocr_vl.py b/src/transformers/models/paddleocr_vl/configuration_paddleocr_vl.py index 80ba67984b19..b8edb01825a4 100644 --- a/src/transformers/models/paddleocr_vl/configuration_paddleocr_vl.py +++ b/src/transformers/models/paddleocr_vl/configuration_paddleocr_vl.py @@ -23,13 +23,62 @@ # See the License for the specific language governing permissions and # limitations under the License. +from functools import partial from typing import Optional -from ...configuration_utils import PreTrainedConfig, PretrainedConfig -from ...modeling_rope_utils import RopeParameters, rope_config_validation +from ...configuration_utils import PreTrainedConfig +from ...modeling_rope_utils import RopeParameters, standardize_rope_params +from ...modeling_rope_utils import rope_config_validation as _rope_config_validation -class PaddleOCRVLVisionConfig(PretrainedConfig): +class PaddleOCRVisionConfig(PreTrainedConfig): + r""" + This is the configuration class to store the configuration of a [`PaddleOCRVisionModel`]. It is used to instantiate a + PaddleOCR vision encoder according to the specified arguments, defining the model architecture. Instantiating a + configuration with the defaults will yield a similar configuration to that of the vision encoder of the PaddleOCR + [google/paddle_o_c_r-base-patch16-224](https://huggingface.co/google/paddle_o_c_r-base-patch16-224) architecture. + + Configuration objects inherit from [`PreTrainedConfig`] and can be used to control the model outputs. Read the + documentation from [`PreTrainedConfig`] for more information. + + Args: + hidden_size (`int`, *optional*, defaults to 768): + Dimensionality of the encoder layers and the pooler layer. + intermediate_size (`int`, *optional*, defaults to 3072): + Dimensionality of the "intermediate" (i.e., feed-forward) layer in the Transformer encoder. + num_hidden_layers (`int`, *optional*, defaults to 12): + Number of hidden layers in the Transformer encoder. + num_attention_heads (`int`, *optional*, defaults to 12): + Number of attention heads for each attention layer in the Transformer encoder. + num_channels (`int`, *optional*, defaults to 3): + Number of channels in the input images. + image_size (`int`, *optional*, defaults to 224): + The size (resolution) of each image. + patch_size (`int`, *optional*, defaults to 16): + The size (resolution) of each patch. + hidden_act (`str` or `function`, *optional*, defaults to `"gelu_pytorch_tanh"`): + The non-linear activation function (function or string) in the encoder and pooler. If string, `"gelu"`, + `"relu"`, `"selu"` and `"gelu_new"` `"quick_gelu"` are supported. + layer_norm_eps (`float`, *optional*, defaults to 1e-06): + The epsilon used by the layer normalization layers. + attention_dropout (`float`, *optional*, defaults to 0.0): + The dropout ratio for the attention probabilities. + + Example: + + ```python + >>> from transformers import PaddleOCRVisionConfig, PaddleOCRVisionModel + + >>> # Initializing a PaddleOCRVisionConfig with google/paddle_o_c_r-base-patch16-224 style configuration + >>> configuration = PaddleOCRVisionConfig() + + >>> # Initializing a PaddleOCRVisionModel (with random weights) from the google/paddle_o_c_r-base-patch16-224 style configuration + >>> model = PaddleOCRVisionModel(configuration) + + >>> # Accessing the model configuration + >>> configuration = model.config + ```""" + model_type = "paddleocr_vl_vision" base_config_key = "vision_config" @@ -41,13 +90,12 @@ def __init__( num_attention_heads=12, num_channels=3, image_size=224, - patch_size=14, + patch_size=16, hidden_act="gelu_pytorch_tanh", layer_norm_eps=1e-6, attention_dropout=0.0, spatial_merge_size=2, temporal_patch_size=2, - tokens_per_second=2, **kwargs, ): super().__init__(**kwargs) @@ -64,19 +112,85 @@ def __init__( self.hidden_act = hidden_act self.spatial_merge_size = spatial_merge_size self.temporal_patch_size = temporal_patch_size - self.tokens_per_second = tokens_per_second -class PaddleOCRVLTextConfig(PretrainedConfig): - """ - Configuration class. +rope_config_validation = partial(_rope_config_validation, ignore_keys={"mrope_section"}) - This class stores the configuration of an Ernie model, defining the model architecture. - It inherits from PretrainedConfig and can be used to control model outputs. - """ - model_type = "paddleocr_vl_text" +class PaddleOCRTextConfig(PreTrainedConfig): + r""" + This is the configuration class to store the configuration of a [`PaddleOCRTextModel`]. It is used to instantiate an Ernie 4.5 + model according to the specified arguments, defining the model architecture. Instantiating a configuration with the + defaults will yield a similar configuration to that of the Ernie 4.5 0.3B. + e.g. [baidu/ERNIE-4.5-0.3B-PT](https://huggingface.co/baidu/ERNIE-4.5-0.3B-PT) + + Configuration objects inherit from [`PreTrainedConfig`] and can be used to control the model outputs. Read the + documentation from [`PreTrainedConfig`] for more information. + + + Args: + vocab_size (`int`, *optional*, defaults to 103424): + Vocabulary size of the Ernie 4.5 model. Defines the number of different tokens that can be represented by the + `inputs_ids` passed when calling [`PaddleOCRTextModel`] + hidden_size (`int`, *optional*, defaults to 1024): + Dimension of the hidden representations. + intermediate_size (`int`, *optional*, defaults to 3072): + Dimension of the MLP representations. + num_hidden_layers (`int`, *optional*, defaults to 18): + Number of hidden layers in the Transformer decoder. + num_attention_heads (`int`, *optional*, defaults to 16): + Number of attention heads for each attention layer in the Transformer decoder. + num_key_value_heads (`int`, *optional*, defaults to 2): + This is the number of key_value heads that should be used to implement Grouped Query Attention. If + `num_key_value_heads=num_attention_heads`, the model will use Multi Head Attention (MHA), if + `num_key_value_heads=1` the model will use Multi Query Attention (MQA) otherwise GQA is used. When + converting a multi-head checkpoint to a GQA checkpoint, each group key and value head should be constructed + by meanpooling all the original heads within that group. For more details, check out [this + paper](https://huggingface.co/papers/2305.13245). If it is not specified, will default to + `num_attention_heads`. + hidden_act (`str` or `function`, *optional*, defaults to `"silu"`): + The non-linear activation function (function or string) in the decoder. + max_position_embeddings (`int`, *optional*, defaults to 131072): + The maximum sequence length that this model might ever be used with. + initializer_range (`float`, *optional*, defaults to 0.02): + The standard deviation of the truncated_normal_initializer for initializing all weight matrices. + rms_norm_eps (`float`, *optional*, defaults to 1e-05): + The epsilon used by the rms normalization layers. + use_cache (`bool`, *optional*, defaults to `True`): + Whether or not the model should return the last key/values attentions. + pad_token_id (`int`, *optional*, defaults to 0): + Padding token id. + bos_token_id (`int`, *optional*, defaults to 1): + Beginning of stream token id. + eos_token_id (`int`, *optional*, defaults to 2): + End of stream token id. + tie_word_embeddings (`bool`, *optional*, defaults to `True`): + Whether to tie weight embeddings + rope_parameters (`RopeParameters`, *optional*): + Dictionary containing the configuration parameters for the RoPE embeddings. The dictionaty should contain + a value for `rope_theta` and optionally parameters used for scaling in case you want to use RoPE + with longer `max_position_embeddings`. + use_bias (`bool`, *optional*, defaults to `False`): + Whether to use a bias in any of the projections including mlp and attention for example. + head_dim (`int`, *optional*, defaults to 128): + The attention head dimension. If None, it will default to hidden_size // num_attention_heads + + ```python + >>> from transformers import PaddleOCRTextModel, PaddleOCRTextConfig + + >>> # Initializing a PaddleOCRText 0.3B style configuration + >>> configuration = PaddleOCRTextConfig() + >>> # Initializing a model from the 0.3B style configuration + >>> model = PaddleOCRTextModel(configuration) + + >>> # Accessing the model configuration + >>> configuration = model.config + ```""" + + model_type = "paddleocr_vl_text" + keys_to_ignore_at_inference = ["past_key_values"] + # Default tensor parallel plan for base model `PaddleOCRTextModel` base_model_tp_plan = { "layers.*.self_attn.q_proj": "colwise", "layers.*.self_attn.k_proj": "colwise", @@ -94,100 +208,60 @@ class PaddleOCRVLTextConfig(PretrainedConfig): def __init__( self, - vocab_size=32000, - hidden_size=768, - intermediate_size=11008, - max_position_embeddings=32768, - num_hidden_layers=2, - num_attention_heads=2, - rms_norm_eps=1e-6, - use_cache=False, - use_flash_attention=False, - pad_token_id=0, - bos_token_id=1, - eos_token_id=2, - head_dim=128, - hidden_act="silu", - use_bias=False, - rope_theta=10000, - weight_share_add_bias=True, - ignored_index=-100, - attention_probs_dropout_prob=0.0, - hidden_dropout_prob=0.0, - compression_ratio: float = 1.0, - num_key_value_heads=None, - max_sequence_length=None, - tie_word_embeddings=False, + vocab_size: Optional[int] = 103424, + hidden_size: Optional[int] = 1024, + intermediate_size: Optional[int] = 3072, + num_hidden_layers: Optional[int] = 18, + num_attention_heads: Optional[int] = 16, + num_key_value_heads: Optional[int] = 2, + hidden_act: Optional[str] = "silu", + max_position_embeddings: Optional[int] = 131072, + initializer_range: Optional[float] = 0.02, + rms_norm_eps: Optional[int] = 1e-05, + use_cache: Optional[int] = True, + pad_token_id: Optional[int] = 0, + bos_token_id: Optional[int] = 1, + eos_token_id: Optional[int] = 2, + tie_word_embeddings: Optional[bool] = True, rope_parameters: Optional[RopeParameters | dict[str, RopeParameters]] = None, + use_bias: Optional[bool] = False, + head_dim: Optional[int] = 128, **kwargs, ): - """ - Initialize configuration with default or specified parameters. - - Args: - vocab_size (int): Size of the vocabulary (number of unique tokens) - hidden_size (int): Dimensionality of the encoder layers and the pooler layer - intermediate_size (int): Dimensionality of the "intermediate" (feed-forward) layer - max_position_embeddings (int): Maximum sequence length the model can handle - num_hidden_layers (int): Number of hidden layers in the Transformer encoder - num_attention_heads (int): Number of attention heads for each attention layer - rms_norm_eps (float): The epsilon used by the RMS normalization layers - use_cache (bool): Whether to use caching for faster generation (decoding) - use_flash_attention (bool): Whether to use FlashAttention for optimized attention computation - pad_token_id (int): Token ID used for padding sequences - bos_token_id (int): Token ID used for beginning-of-sequence - eos_token_id (int): Token ID used for end-of-sequence - use_bias (bool): Whether to use bias terms in linear layers - rope_theta (float): The base period of the RoPE embeddings - weight_share_add_bias (bool): Whether to share bias weights in certain layers - ignored_index (int): Target value that is ignored during loss computation - attention_probs_dropout_prob (float): Dropout probability for attention weights - hidden_dropout_prob (float): Dropout probability for hidden layers - compression_ratio (float): Ratio for KV cache compression (1.0 = no compression) - num_key_value_heads (int): Number of key/value heads (for Grouped Query Attention) - max_sequence_length (int): Maximum sequence length for positional embeddings - **kwargs: Additional keyword arguments passed to parent class - """ - - # Set default for tied embeddings if not specified. - super().__init__( - pad_token_id=pad_token_id, - bos_token_id=bos_token_id, - eos_token_id=eos_token_id, - **kwargs, - ) self.vocab_size = vocab_size + self.max_position_embeddings = max_position_embeddings self.hidden_size = hidden_size self.intermediate_size = intermediate_size - self.max_position_embeddings = max_position_embeddings self.num_hidden_layers = num_hidden_layers self.num_attention_heads = num_attention_heads + + # for backward compatibility + if num_key_value_heads is None: + num_key_value_heads = num_attention_heads + + self.num_key_value_heads = num_key_value_heads + self.hidden_act = hidden_act + self.initializer_range = initializer_range self.rms_norm_eps = rms_norm_eps self.use_cache = use_cache - self.use_flash_attention = use_flash_attention - self.pad_token_id = pad_token_id - self.bos_token_id = bos_token_id - self.eos_token_id = eos_token_id - self.head_dim = head_dim - self.hidden_act = hidden_act - self.sliding_window = None - self.hidden_size = hidden_size self.use_bias = use_bias - self.weight_share_add_bias = weight_share_add_bias - self.rope_theta = rope_theta - self.ignored_index = ignored_index - self.attention_probs_dropout_prob = attention_probs_dropout_prob - self.hidden_dropout_prob = hidden_dropout_prob - self.compression_ratio = compression_ratio - self.num_key_value_heads = num_key_value_heads - self.max_sequence_length = max_sequence_length + self.head_dim = head_dim if head_dim is not None else self.hidden_size // self.num_attention_heads # Try to set `rope_scaling` if available, otherwise use `rope_parameters` rope_scaling = kwargs.pop("rope_scaling", None) self.rope_parameters = rope_scaling or rope_parameters - if self.rope_parameters is not None and self.rope_parameters["rope_type"] == "mrope": - self.rope_parameters["rope_type"] = "default" - rope_config_validation(self, ignore_keys={"mrope_section"}) - super().__init__(tie_word_embeddings=tie_word_embeddings, **kwargs) + + # Validate the correctness of rotary position embeddings parameters + rope_theta = kwargs.get("rope_theta", 500000.0) + standardize_rope_params(self, rope_theta=rope_theta) + rope_config_validation(self) + + super().__init__( + pad_token_id=pad_token_id, + bos_token_id=bos_token_id, + eos_token_id=eos_token_id, + tie_word_embeddings=tie_word_embeddings, + **kwargs, + ) class PaddleOCRVLConfig(PreTrainedConfig): @@ -229,7 +303,7 @@ class PaddleOCRVLConfig(PreTrainedConfig): ```""" model_type = "paddleocr_vl" - sub_configs = {"vision_config": PaddleOCRVLVisionConfig, "text_config": PaddleOCRVLTextConfig} + sub_configs = {"vision_config": PaddleOCRVisionConfig, "text_config": PaddleOCRTextConfig} keys_to_ignore_at_inference = ["past_key_values"] def __init__( @@ -290,4 +364,4 @@ def __getattribute__(self, key): return super().__getattribute__(key) -__all__ = ["PaddleOCRVLConfig", "PaddleOCRVLTextConfig"] +__all__ = ["PaddleOCRVLConfig", "PaddleOCRVisionConfig", "PaddleOCRTextConfig"] diff --git a/src/transformers/models/paddleocr_vl/image_processing_paddleocr_vl.py b/src/transformers/models/paddleocr_vl/image_processing_paddleocr_vl.py index 8cf26f0479e0..2f7c00a5a9df 100644 --- a/src/transformers/models/paddleocr_vl/image_processing_paddleocr_vl.py +++ b/src/transformers/models/paddleocr_vl/image_processing_paddleocr_vl.py @@ -77,8 +77,8 @@ def smart_resize( height: int, width: int, factor: int = 28, - min_pixels: int = 28 * 28 * 130, - max_pixels: int = 28 * 28 * 1280, + min_pixels: int = 384 * 384, + max_pixels: int = 1536 * 1536, ): if height < factor: width = round((width * factor) / height) @@ -110,34 +110,12 @@ class PaddleOCRVLImageProcessor(BaseImageProcessor): Constructs a PaddleOCRVL image processor that dynamically resizes images based on the original images. Args: - do_resize (`bool`, *optional*, defaults to `True`): - Whether to resize the image's (height, width) dimensions. - size (`Dict[str, int]`, *optional*, defaults to `None`): - Size of the image after resizing. `shortest_edge` and `longest_edge` keys must be present. - resample (`PILImageResampling`, *optional*, defaults to `Resampling.BICUBIC`): - Resampling filter to use when resizing the image. - do_rescale (`bool`, *optional*, defaults to `True`): - Whether to rescale the image by the specified scale `rescale_factor`. - rescale_factor (`int` or `float`, *optional*, defaults to `1/255`): - Scale factor to use if rescaling the image. - do_normalize (`bool`, *optional*, defaults to `True`): - Whether to normalize the image. - image_mean (`float` or `List[float]`, *optional*, defaults to `[0.48145466, 0.4578275, 0.40821073]`): - Mean to use if normalizing the image. This is a float or list of floats for each channel in the image. - image_std (`float` or `List[float]`, *optional*, defaults to `[0.26862954, 0.26130258, 0.27577711]`): - Standard deviation to use if normalizing the image. This is a float or list of floats for each channel in the image. - do_convert_rgb (`bool`, *optional*, defaults to `True`): - Whether to convert the image to RGB. - min_pixels (`int`, *optional*, defaults to `28 * 28 * 130`): + min_pixels (`int`, *optional*, defaults to `384 * 384`): The min pixels of the image to resize the image. - max_pixels (`int`, *optional*, defaults to `28 * 28 * 1670`): + max_pixels (`int`, *optional*, defaults to `1536 * 1536`): The max pixels of the image to resize the image. - patch_size (`int`, *optional*, defaults to 14): - The spacial patch size of the vision encoder. - temporal_patch_size (`int`, *optional*, defaults to 2): + temporal_patch_size (`int`, *optional*, defaults to 1): The temporal patch size of the vision encoder. - merge_size (`int`, *optional*, defaults to 2): - The merge size of the vision encoder to llm encoder. """ model_input_names = [ @@ -157,8 +135,8 @@ def __init__( image_mean: Optional[Union[float, list[float]]] = None, image_std: Optional[Union[float, list[float]]] = None, do_convert_rgb: bool = True, - min_pixels: int = 28 * 28 * 130, - max_pixels: int = 28 * 28 * 1280, + min_pixels: int = 384 * 384, + max_pixels: int = 1536 * 1536, patch_size: int = 14, temporal_patch_size: int = 1, merge_size: int = 2, @@ -177,6 +155,7 @@ def __init__( self.min_pixels = size["shortest_edge"] self.max_pixels = size["longest_edge"] self.size = size + self.do_resize = do_resize self.resample = resample self.do_rescale = do_rescale @@ -184,6 +163,7 @@ def __init__( self.do_normalize = do_normalize self.image_mean = image_mean if image_mean is not None else OPENAI_CLIP_MEAN self.image_std = image_std if image_std is not None else OPENAI_CLIP_STD + self.patch_size = patch_size self.temporal_patch_size = temporal_patch_size self.merge_size = merge_size diff --git a/src/transformers/models/paddleocr_vl/image_processing_paddleocr_vl_fast.py b/src/transformers/models/paddleocr_vl/image_processing_paddleocr_vl_fast.py index 1a7cb510eea5..48d47d77c2d4 100644 --- a/src/transformers/models/paddleocr_vl/image_processing_paddleocr_vl_fast.py +++ b/src/transformers/models/paddleocr_vl/image_processing_paddleocr_vl_fast.py @@ -39,8 +39,8 @@ def smart_resize( height: int, width: int, factor: int = 28, - min_pixels: int = 28 * 28 * 130, - max_pixels: int = 28 * 28 * 1280, + min_pixels: int = 384 * 384, + max_pixels: int = 1536 * 1536, ): if height < factor: width = round((width * factor) / height) @@ -79,8 +79,6 @@ def __init__( image_mean: Optional[Union[float, list[float]]] = None, image_std: Optional[Union[float, list[float]]] = None, do_convert_rgb: bool = True, - min_pixels: int = 28 * 28 * 130, - max_pixels: int = 28 * 28 * 1280, patch_size: int = 14, temporal_patch_size: int = 1, merge_size: int = 2, @@ -90,12 +88,7 @@ def __init__( if size is not None and ("shortest_edge" not in size or "longest_edge" not in size): raise ValueError("size must contain 'shortest_edge' and 'longest_edge' keys.") else: - size = {"shortest_edge": 56 * 56, "longest_edge": 28 * 28 * 1280} - # backward compatibility: override size with min_pixels and max_pixels if they are provided - if min_pixels is not None: - size["shortest_edge"] = min_pixels - if max_pixels is not None: - size["longest_edge"] = max_pixels + size = {"shortest_edge": 384 * 384, "longest_edge": 1536 * 1536} self.min_pixels = size["shortest_edge"] self.max_pixels = size["longest_edge"] self.size = size diff --git a/src/transformers/models/paddleocr_vl/modeling_paddleocr_vl.py b/src/transformers/models/paddleocr_vl/modeling_paddleocr_vl.py index 88a80d678a86..f6c214158a12 100644 --- a/src/transformers/models/paddleocr_vl/modeling_paddleocr_vl.py +++ b/src/transformers/models/paddleocr_vl/modeling_paddleocr_vl.py @@ -43,7 +43,7 @@ from ...processing_utils import Unpack from ...utils import TransformersKwargs, auto_docstring, can_return_tuple, logging, torch_int from ...utils.generic import check_model_inputs -from .configuration_paddleocr_vl import PaddleOCRVLConfig, PaddleOCRVLTextConfig, PaddleOCRVLVisionConfig +from .configuration_paddleocr_vl import PaddleOCRTextConfig, PaddleOCRVisionConfig, PaddleOCRVLConfig logger = logging.get_logger(__name__) @@ -54,7 +54,7 @@ def __init__(self, config: PaddleOCRVLConfig): super().__init__() self.text_config = config.text_config self.vision_config = config.vision_config - self.merge_kernel_size = (2, 2) + self.merge_kernel_size = (self.vision_config.spatial_merge_size, self.vision_config.spatial_merge_size) self.hidden_size = self.vision_config.hidden_size * self.merge_kernel_size[0] * self.merge_kernel_size[1] @@ -167,7 +167,7 @@ def forward(self, x, position_ids): class PaddleOCRMLP(nn.Module): - def __init__(self, config: PaddleOCRVLTextConfig): + def __init__(self, config: PaddleOCRTextConfig): super().__init__() self.config = config self.hidden_size = config.hidden_size @@ -382,7 +382,7 @@ def extra_repr(self): class PaddleOCRDecoderLayer(GradientCheckpointingLayer): - def __init__(self, config: PaddleOCRVLTextConfig, layer_idx: int): + def __init__(self, config: PaddleOCRTextConfig, layer_idx: int): super().__init__() self.hidden_size = config.hidden_size @@ -442,8 +442,8 @@ class PaddleOCRVLPreTrainedModel(PreTrainedModel): @auto_docstring -class PaddleOCRVLTextModel(PaddleOCRVLPreTrainedModel): - def __init__(self, config: PaddleOCRVLTextConfig): +class PaddleOCRTextModel(PaddleOCRVLPreTrainedModel): + def __init__(self, config: PaddleOCRTextConfig): super().__init__(config) self.padding_idx = config.pad_token_id self.vocab_size = config.vocab_size @@ -521,11 +521,11 @@ def forward( class PaddleOCRVisionModel(PaddleOCRVLPreTrainedModel): - config: PaddleOCRVLVisionConfig + config: PaddleOCRVisionConfig main_input_name = "pixel_values" input_modalities = "image" - def __init__(self, config: PaddleOCRVLVisionConfig): + def __init__(self, config: PaddleOCRVisionConfig): super().__init__(config) self.vision_model = PaddleOCRVisionTransformer(config) @@ -535,8 +535,8 @@ def __init__(self, config: PaddleOCRVLVisionConfig): def forward( self, - pixel_values, - cu_seqlens, + pixel_values: torch.FloatTensor, + cu_seqlens: torch.Tensor, image_grid_thw: Optional[list[Union[tuple[int, int, int], list[tuple[int, int, int]]]]] = None, ) -> BaseModelOutputWithPooling: """ @@ -556,7 +556,7 @@ def forward( class PaddleOCRVisionEmbeddings(nn.Module): - def __init__(self, config: PaddleOCRVLVisionConfig): + def __init__(self, config: PaddleOCRVisionConfig): super().__init__() self.config = config self.embed_dim = config.hidden_size @@ -591,16 +591,13 @@ def interpolate_pos_encoding(self, embeddings: torch.Tensor, height: int, width: dim = embeddings.shape[-1] - new_height = height - new_width = width - sqrt_num_positions = torch_int(num_positions**0.5) patch_pos_embed = patch_pos_embed.reshape(1, sqrt_num_positions, sqrt_num_positions, dim) patch_pos_embed = patch_pos_embed.permute(0, 3, 1, 2) patch_pos_embed = nn.functional.interpolate( patch_pos_embed, - size=(new_height, new_width), + size=(height, width), mode="bilinear", align_corners=False, ) @@ -617,8 +614,6 @@ def forward( Args: pixel_values (`torch.FloatTensor` of shape `(batch_size, sequence_length, image_channels, patch_size, patch_size)`): The tensors corresponding to the input images. - position_ids (`torch.LongTensor` of shape `sequence_length`): - The position ids of the image. image_grid_thw (`torch.LongTensor` of shape `(num_images, 3)`, *optional*): The temporal, height and width of feature shape of each image in LLM. """ @@ -629,9 +624,6 @@ def forward( embeddings = patch_embeds.flatten(-2).squeeze(-1) embeddings = embeddings.reshape(batch_size, squence_len, -1) - assert batch_size == 1, ( - f"Batch size must be 1, but received {batch_size}. This model only processes one image at a time." - ) start = 0 embeddings = embeddings.squeeze(0) tmp_embeddings = [] @@ -665,7 +657,7 @@ def apply_rotary_pos_emb_vision( class PaddleOCRVisionAttention(nn.Module): """Multi-headed attention from 'Attention Is All You Need' paper""" - def __init__(self, config: PaddleOCRVLVisionConfig): + def __init__(self, config: PaddleOCRVisionConfig): super().__init__() self.config = config self.embed_dim = config.hidden_size @@ -768,7 +760,7 @@ def forward( class PaddleOCRVisionMLP(nn.Module): - def __init__(self, config: PaddleOCRVLVisionConfig): + def __init__(self, config: PaddleOCRVisionConfig): super().__init__() self.config = config self.activation_fn = ACT2FN[config.hidden_act] @@ -782,33 +774,8 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: return hidden_states -class PaddleOCRVisionMultiheadAttentionPoolingHead(nn.Module): - """Multihead Attention Pooling.""" - - def __init__(self, config: PaddleOCRVLVisionConfig): - super().__init__() - - self.probe = nn.Parameter(torch.randn(1, 1, config.hidden_size)) - self.attention = torch.nn.MultiheadAttention(config.hidden_size, config.num_attention_heads, batch_first=True) - self.layernorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) - - self.mlp = PaddleOCRVisionMLP(config) - - def forward(self, hidden_state): - batch_size = hidden_state.shape[0] - probe = self.probe.repeat(batch_size, 1, 1) - - hidden_state = self.attention(probe, hidden_state, hidden_state)[0] - - residual = hidden_state - hidden_state = self.layernorm(hidden_state) - hidden_state = residual + self.mlp(hidden_state) - - return hidden_state[:, 0] - - class PaddleOCRVisionEncoderLayer(GradientCheckpointingLayer): - def __init__(self, config: PaddleOCRVLVisionConfig): + def __init__(self, config: PaddleOCRVisionConfig): super().__init__() self.embed_dim = config.hidden_size self.layer_norm1 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps) @@ -858,7 +825,7 @@ class PaddleOCRVisionEncoder(nn.Module): config: PaddleOCRVisionConfig """ - def __init__(self, config: PaddleOCRVLVisionConfig): + def __init__(self, config: PaddleOCRVisionConfig): super().__init__() self.config = config self.layers = nn.ModuleList([PaddleOCRVisionEncoderLayer(config) for _ in range(config.num_hidden_layers)]) @@ -873,13 +840,17 @@ def __init__(self, config: PaddleOCRVLVisionConfig): @auto_docstring def forward( self, - inputs_embeds, - cu_seqlens, + inputs_embeds: torch.FloatTensor, + cu_seqlens: torch.Tensor, attention_mask: Optional[torch.Tensor] = None, image_grid_thw: Optional[list[Union[tuple[int, int, int], list[tuple[int, int, int]]]]] = None, ) -> BaseModelOutput: """ Args: + inputs_embeds (`torch.FloatTensor` of shape `(sequence_length, hidden_size)`, *optional*): + Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. + This is useful if you want more control over how to convert `input_ids` indices into associated vectors + than the model's internal embedding lookup matrix. cu_seqlens (`torch.Tensor` of shape `(num_images + 1,)`): The cumulative sequence lengths of each image or video feature. attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*): @@ -925,7 +896,7 @@ def forward( class PaddleOCRVisionTransformer(PaddleOCRVLPreTrainedModel): - def __init__(self, config: PaddleOCRVLVisionConfig): + def __init__(self, config: PaddleOCRVisionConfig): super().__init__(config) self.config = config embed_dim = config.hidden_size @@ -933,13 +904,10 @@ def __init__(self, config: PaddleOCRVLVisionConfig): self.embeddings = PaddleOCRVisionEmbeddings(config) self.encoder = PaddleOCRVisionEncoder(config) self.post_layernorm = nn.LayerNorm(embed_dim, eps=config.layer_norm_eps) - self.use_head = True if not hasattr(config, "vision_use_head") else config.vision_use_head - if self.use_head: - self.head = PaddleOCRVisionMultiheadAttentionPoolingHead(config) def forward( self, - pixel_values, + pixel_values: torch.FloatTensor, cu_seqlens: torch.Tensor, attention_mask: Optional[torch.Tensor] = None, image_grid_thw: Optional[list[Union[tuple[int, int, int], list[tuple[int, int, int]]]]] = None, @@ -952,8 +920,6 @@ def forward( The cumulative sequence lengths of each image or video feature. attention_mask (`torch.Tensor`, *optional*): The attention_mask used in forward function shape [batch_size X sequence_length] if not None. - position_ids (`torch.LongTensor` of shape `sequence_length`): - The position ids of the image. image_grid_thw (`torch.LongTensor` of shape `(num_images, 3)`, *optional*): The temporal, height and width of feature shape of each image in LLM. """ @@ -1036,12 +1002,12 @@ class PaddleOCRVLModel(PaddleOCRVLPreTrainedModel): _checkpoint_conversion_mapping = {"^model": "language_model"} # Reference: fix gemma3 grad acc #37208 accepts_loss_kwargs = False - _keys_to_ignore_on_load_unexpected = ["packing_position_embedding"] + _keys_to_ignore_on_load_unexpected = ["packing_position_embedding", "vision_model.head"] def __init__(self, config: PaddleOCRVLConfig): super().__init__(config) - self.visual = PaddleOCRVisionModel(config.vision_config) - self.language_model = PaddleOCRVLTextModel(config.text_config) + self.visual = PaddleOCRVisionModel._from_config(config.vision_config) + self.language_model = PaddleOCRTextModel._from_config(config.text_config) self.rope_deltas = None self.projector = PaddleOCRProjector(config) @@ -1253,29 +1219,32 @@ def get_image_features(self, pixel_values: torch.FloatTensor, image_grid_thw: Op cu_seqlens=cu_seqlens, ) image_embeds = vision_outputs.last_hidden_state + image_embeds = self.projector(image_embeds, image_grid_thw) return image_embeds def get_placeholder_mask( - self, - input_ids: torch.LongTensor, - inputs_embeds: torch.FloatTensor, - image_features: Optional[torch.FloatTensor] = None, + self, input_ids: torch.LongTensor, inputs_embeds: torch.FloatTensor, image_features: torch.FloatTensor ): """ Obtains multimodal placeholder mask from `input_ids` or `inputs_embeds`, and checks that the placeholder token count is equal to the length of multimodal features. If the lengths are different, an error is raised. """ - n_image_tokens = (input_ids == self.config.image_token_id).sum().item() - n_image_features = image_features.shape[0] - if n_image_tokens != n_image_features: + if input_ids is None: + special_image_mask = inputs_embeds == self.get_input_embeddings()( + torch.tensor(self.config.image_token_id, dtype=torch.long, device=inputs_embeds.device) + ) + special_image_mask = special_image_mask.all(-1) + else: + special_image_mask = input_ids == self.config.image_token_id + + n_image_tokens = special_image_mask.sum() + special_image_mask = special_image_mask.unsqueeze(-1).expand_as(inputs_embeds).to(inputs_embeds.device) + n_image_features = image_features.shape[0] * image_features.shape[1] + if inputs_embeds[special_image_mask].numel() != image_features.numel(): raise ValueError( f"Image features and image tokens do not match: tokens: {n_image_tokens}, features {n_image_features}" ) - mask = input_ids == self.config.image_token_id - mask_unsqueezed = mask.unsqueeze(-1) - mask_expanded = mask_unsqueezed.expand_as(inputs_embeds) - image_mask = mask_expanded.to(inputs_embeds.device) - return image_mask + return special_image_mask @can_return_tuple def forward( @@ -1305,8 +1274,9 @@ def forward( inputs_embeds = self.language_model.embed_tokens(input_ids) if pixel_values is not None: - image_embeds = self.get_image_features(pixel_values, image_grid_thw) - image_embeds = self.projector(image_embeds, image_grid_thw) + image_embeds = self.get_image_features(pixel_values, image_grid_thw).to( + inputs_embeds.device, inputs_embeds.dtype + ) image_mask = self.get_placeholder_mask(input_ids, inputs_embeds=inputs_embeds, image_features=image_embeds) inputs_embeds = inputs_embeds.masked_scatter(image_mask, image_embeds) @@ -1358,7 +1328,7 @@ def forward( rope_deltas=self.rope_deltas, ) - return output if return_dict else output.to_tuple() + return output class PaddleOCRVLForConditionalGeneration(PaddleOCRVLPreTrainedModel, GenerationMixin): @@ -1368,7 +1338,7 @@ class PaddleOCRVLForConditionalGeneration(PaddleOCRVLPreTrainedModel, Generation r"^model(?!(\.visual|\.projector))": "model.language_model", } _tied_weights_keys = ["lm_head.weight"] - _keys_to_ignore_on_load_unexpected = ["packing_position_embedding"] + _keys_to_ignore_on_load_unexpected = ["packing_position_embedding", "vision_model.head"] def __init__(self, config): super().__init__(config) @@ -1433,8 +1403,6 @@ def forward( (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`. image_grid_thw (`torch.LongTensor` of shape `(num_images, 3)`, *optional*): The temporal, height and width of feature shape of each image in LLM. - video_grid_thw (`torch.LongTensor` of shape `(num_videos, 3)`, *optional*): - The temporal, height and width of feature shape of each video in LLM. rope_deltas (`torch.LongTensor` of shape `(batch_size, )`, *optional*): The rope index difference between sequence length and multimodal rope. @@ -1520,8 +1488,6 @@ def prepare_inputs_for_generation( pixel_values=None, pixel_values_videos=None, image_grid_thw=None, - video_grid_thw=None, - second_per_grid_ts=None, **kwargs, ): # Overwritten -- in specific circumstances we don't want to forward image inputs to the model @@ -1536,8 +1502,6 @@ def prepare_inputs_for_generation( pixel_values=pixel_values, pixel_values_videos=pixel_values_videos, image_grid_thw=image_grid_thw, - video_grid_thw=video_grid_thw, - second_per_grid_ts=second_per_grid_ts, use_cache=use_cache, **kwargs, ) @@ -1688,4 +1652,4 @@ def _expand_dict_for_generation(dict_to_expand): return input_ids, model_kwargs -__all__ = ["PaddleOCRVLForConditionalGeneration"] +__all__ = ["PaddleOCRVLForConditionalGeneration", "PaddleOCRTextModel", "PaddleOCRVisionModel"] diff --git a/src/transformers/models/paddleocr_vl/modular_paddleocr_vl.py b/src/transformers/models/paddleocr_vl/modular_paddleocr_vl.py index e6903e564ace..8df7bc64a23d 100644 --- a/src/transformers/models/paddleocr_vl/modular_paddleocr_vl.py +++ b/src/transformers/models/paddleocr_vl/modular_paddleocr_vl.py @@ -18,6 +18,7 @@ # limitations under the License. import math +from functools import partial from typing import Optional, Union import numpy as np @@ -27,7 +28,6 @@ from ...activations import GELUActivation from ...cache_utils import Cache -from ...configuration_utils import PretrainedConfig from ...image_processing_utils import BatchFeature from ...image_processing_utils_fast import BaseImageProcessorFast, group_images_by_shape, reorder_images from ...image_transforms import convert_to_rgb, resize, to_channel_dimension_format @@ -46,7 +46,7 @@ ) from ...masking_utils import create_bidirectional_mask from ...modeling_outputs import BaseModelOutput, BaseModelOutputWithPooling -from ...modeling_rope_utils import RopeParameters, rope_config_validation +from ...modeling_rope_utils import rope_config_validation as _rope_config_validation from ...modeling_utils import PreTrainedModel from ...models.qwen2_vl.image_processing_qwen2_vl import Qwen2VLImageProcessor from ...processing_utils import ( @@ -56,6 +56,7 @@ ) from ...tokenization_utils_base import PreTokenizedInput, TextInput from ...utils import TensorType, TransformersKwargs, auto_docstring, can_return_tuple, logging, torch_int +from ..ernie4_5.configuration_ernie4_5 import Ernie4_5Config from ..ernie4_5.modeling_ernie4_5 import ( Ernie4_5DecoderLayer, Ernie4_5MLP, @@ -74,9 +75,9 @@ Qwen2VLRotaryEmbedding, VisionRotaryEmbedding, ) +from ..siglip.configuration_siglip import SiglipVisionConfig from ..siglip.modeling_siglip import ( SiglipMLP, - SiglipMultiheadAttentionPoolingHead, SiglipVisionEmbeddings, ) from ..video_llama_3.modeling_video_llama_3 import ( @@ -87,14 +88,15 @@ logger = logging.get_logger(__name__) +rope_config_validation = partial(_rope_config_validation, ignore_keys={"mrope_section"}) def smart_resize( height: int, width: int, factor: int = 28, - min_pixels: int = 28 * 28 * 130, - max_pixels: int = 28 * 28 * 1280, + min_pixels: int = 384 * 384, + max_pixels: int = 1536 * 1536, ): if height < factor: width = round((width * factor) / height) @@ -126,34 +128,12 @@ class PaddleOCRVLImageProcessor(Qwen2VLImageProcessor): Constructs a PaddleOCRVL image processor that dynamically resizes images based on the original images. Args: - do_resize (`bool`, *optional*, defaults to `True`): - Whether to resize the image's (height, width) dimensions. - size (`Dict[str, int]`, *optional*, defaults to `None`): - Size of the image after resizing. `shortest_edge` and `longest_edge` keys must be present. - resample (`PILImageResampling`, *optional*, defaults to `Resampling.BICUBIC`): - Resampling filter to use when resizing the image. - do_rescale (`bool`, *optional*, defaults to `True`): - Whether to rescale the image by the specified scale `rescale_factor`. - rescale_factor (`int` or `float`, *optional*, defaults to `1/255`): - Scale factor to use if rescaling the image. - do_normalize (`bool`, *optional*, defaults to `True`): - Whether to normalize the image. - image_mean (`float` or `List[float]`, *optional*, defaults to `[0.48145466, 0.4578275, 0.40821073]`): - Mean to use if normalizing the image. This is a float or list of floats for each channel in the image. - image_std (`float` or `List[float]`, *optional*, defaults to `[0.26862954, 0.26130258, 0.27577711]`): - Standard deviation to use if normalizing the image. This is a float or list of floats for each channel in the image. - do_convert_rgb (`bool`, *optional*, defaults to `True`): - Whether to convert the image to RGB. - min_pixels (`int`, *optional*, defaults to `28 * 28 * 130`): + min_pixels (`int`, *optional*, defaults to `384 * 384`): The min pixels of the image to resize the image. - max_pixels (`int`, *optional*, defaults to `28 * 28 * 1670`): + max_pixels (`int`, *optional*, defaults to `1536 * 1536`): The max pixels of the image to resize the image. - patch_size (`int`, *optional*, defaults to 14): - The spacial patch size of the vision encoder. - temporal_patch_size (`int`, *optional*, defaults to 2): + temporal_patch_size (`int`, *optional*, defaults to 1): The temporal patch size of the vision encoder. - merge_size (`int`, *optional*, defaults to 2): - The merge size of the vision encoder to llm encoder. """ model_input_names = [ @@ -172,32 +152,14 @@ def __init__( image_mean: Optional[Union[float, list[float]]] = None, image_std: Optional[Union[float, list[float]]] = None, do_convert_rgb: bool = True, - min_pixels: int = 28 * 28 * 130, - max_pixels: int = 28 * 28 * 1280, + min_pixels: int = 384 * 384, + max_pixels: int = 1536 * 1536, patch_size: int = 14, temporal_patch_size: int = 1, merge_size: int = 2, **kwargs, ) -> None: - super().__init__(**kwargs) - self.do_resize = do_resize - self.resample = resample - self.do_rescale = do_rescale - self.rescale_factor = rescale_factor - self.do_normalize = do_normalize - self.image_mean = image_mean if image_mean is not None else OPENAI_CLIP_MEAN - self.image_std = image_std if image_std is not None else OPENAI_CLIP_STD - if min_pixels is not None: - size["shortest_edge"] = min_pixels - if max_pixels is not None: - size["longest_edge"] = max_pixels - self.min_pixels = size["shortest_edge"] - self.max_pixels = size["longest_edge"] - self.size = size - self.patch_size = patch_size - self.temporal_patch_size = temporal_patch_size - self.merge_size = merge_size - self.do_convert_rgb = do_convert_rgb + super().__init__() def _preprocess( self, @@ -346,8 +308,6 @@ def __init__( image_mean: Optional[Union[float, list[float]]] = None, image_std: Optional[Union[float, list[float]]] = None, do_convert_rgb: bool = True, - min_pixels: int = 28 * 28 * 130, - max_pixels: int = 28 * 28 * 1280, patch_size: int = 14, temporal_patch_size: int = 1, merge_size: int = 2, @@ -357,12 +317,7 @@ def __init__( if size is not None and ("shortest_edge" not in size or "longest_edge" not in size): raise ValueError("size must contain 'shortest_edge' and 'longest_edge' keys.") else: - size = {"shortest_edge": 56 * 56, "longest_edge": 28 * 28 * 1280} - # backward compatibility: override size with min_pixels and max_pixels if they are provided - if min_pixels is not None: - size["shortest_edge"] = min_pixels - if max_pixels is not None: - size["longest_edge"] = max_pixels + size = {"shortest_edge": 384 * 384, "longest_edge": 1536 * 1536} self.min_pixels = size["shortest_edge"] self.max_pixels = size["longest_edge"] self.size = size @@ -567,169 +522,27 @@ def __call__( return BatchFeature(data={**text_inputs, **image_inputs}) -class PaddleOCRVLVisionConfig(PretrainedConfig): +class PaddleOCRVisionConfig(SiglipVisionConfig): model_type = "paddleocr_vl_vision" base_config_key = "vision_config" def __init__( self, - hidden_size=768, - intermediate_size=3072, - num_hidden_layers=12, - num_attention_heads=12, - num_channels=3, - image_size=224, - patch_size=14, - hidden_act="gelu_pytorch_tanh", - layer_norm_eps=1e-6, - attention_dropout=0.0, spatial_merge_size=2, temporal_patch_size=2, - tokens_per_second=2, - **kwargs, + **super_kwargs, ): - super().__init__(**kwargs) - - self.hidden_size = hidden_size - self.intermediate_size = intermediate_size - self.num_hidden_layers = num_hidden_layers - self.num_attention_heads = num_attention_heads - self.num_channels = num_channels - self.patch_size = patch_size - self.image_size = image_size - self.attention_dropout = attention_dropout - self.layer_norm_eps = layer_norm_eps - self.hidden_act = hidden_act + super().__init__() self.spatial_merge_size = spatial_merge_size self.temporal_patch_size = temporal_patch_size - self.tokens_per_second = tokens_per_second -class PaddleOCRVLTextConfig(PretrainedConfig): - """ - Configuration class. - - This class stores the configuration of an Ernie model, defining the model architecture. - It inherits from PretrainedConfig and can be used to control model outputs. - """ - +class PaddleOCRTextConfig(Ernie4_5Config): model_type = "paddleocr_vl_text" - base_model_tp_plan = { - "layers.*.self_attn.q_proj": "colwise", - "layers.*.self_attn.k_proj": "colwise", - "layers.*.self_attn.v_proj": "colwise", - "layers.*.self_attn.o_proj": "rowwise", - "layers.*.mlp.gate_proj": "colwise", - "layers.*.mlp.up_proj": "colwise", - "layers.*.mlp.down_proj": "rowwise", - } - base_model_pp_plan = { - "embed_tokens": (["input_ids"], ["inputs_embeds"]), - "layers": (["hidden_states", "attention_mask"], ["hidden_states"]), - "norm": (["hidden_states"], ["hidden_states"]), - } - - def __init__( - self, - vocab_size=32000, - hidden_size=768, - intermediate_size=11008, - max_position_embeddings=32768, - num_hidden_layers=2, - num_attention_heads=2, - rms_norm_eps=1e-6, - use_cache=False, - use_flash_attention=False, - pad_token_id=0, - bos_token_id=1, - eos_token_id=2, - head_dim=128, - hidden_act="silu", - use_bias=False, - rope_theta=10000, - weight_share_add_bias=True, - ignored_index=-100, - attention_probs_dropout_prob=0.0, - hidden_dropout_prob=0.0, - compression_ratio: float = 1.0, - num_key_value_heads=None, - max_sequence_length=None, - tie_word_embeddings=False, - rope_parameters: Optional[RopeParameters | dict[str, RopeParameters]] = None, - **kwargs, - ): - """ - Initialize configuration with default or specified parameters. - - Args: - vocab_size (int): Size of the vocabulary (number of unique tokens) - hidden_size (int): Dimensionality of the encoder layers and the pooler layer - intermediate_size (int): Dimensionality of the "intermediate" (feed-forward) layer - max_position_embeddings (int): Maximum sequence length the model can handle - num_hidden_layers (int): Number of hidden layers in the Transformer encoder - num_attention_heads (int): Number of attention heads for each attention layer - rms_norm_eps (float): The epsilon used by the RMS normalization layers - use_cache (bool): Whether to use caching for faster generation (decoding) - use_flash_attention (bool): Whether to use FlashAttention for optimized attention computation - pad_token_id (int): Token ID used for padding sequences - bos_token_id (int): Token ID used for beginning-of-sequence - eos_token_id (int): Token ID used for end-of-sequence - use_bias (bool): Whether to use bias terms in linear layers - rope_theta (float): The base period of the RoPE embeddings - weight_share_add_bias (bool): Whether to share bias weights in certain layers - ignored_index (int): Target value that is ignored during loss computation - attention_probs_dropout_prob (float): Dropout probability for attention weights - hidden_dropout_prob (float): Dropout probability for hidden layers - compression_ratio (float): Ratio for KV cache compression (1.0 = no compression) - num_key_value_heads (int): Number of key/value heads (for Grouped Query Attention) - max_sequence_length (int): Maximum sequence length for positional embeddings - **kwargs: Additional keyword arguments passed to parent class - """ - - # Set default for tied embeddings if not specified. - super().__init__( - pad_token_id=pad_token_id, - bos_token_id=bos_token_id, - eos_token_id=eos_token_id, - **kwargs, - ) - self.vocab_size = vocab_size - self.hidden_size = hidden_size - self.intermediate_size = intermediate_size - self.max_position_embeddings = max_position_embeddings - self.num_hidden_layers = num_hidden_layers - self.num_attention_heads = num_attention_heads - self.rms_norm_eps = rms_norm_eps - self.use_cache = use_cache - self.use_flash_attention = use_flash_attention - self.pad_token_id = pad_token_id - self.bos_token_id = bos_token_id - self.eos_token_id = eos_token_id - self.head_dim = head_dim - self.hidden_act = hidden_act - self.sliding_window = None - self.hidden_size = hidden_size - self.use_bias = use_bias - self.weight_share_add_bias = weight_share_add_bias - self.rope_theta = rope_theta - self.ignored_index = ignored_index - self.attention_probs_dropout_prob = attention_probs_dropout_prob - self.hidden_dropout_prob = hidden_dropout_prob - self.compression_ratio = compression_ratio - self.num_key_value_heads = num_key_value_heads - self.max_sequence_length = max_sequence_length - # Try to set `rope_scaling` if available, otherwise use `rope_parameters` - rope_scaling = kwargs.pop("rope_scaling", None) - self.rope_parameters = rope_scaling or rope_parameters - if self.rope_parameters is not None and self.rope_parameters["rope_type"] == "mrope": - self.rope_parameters["rope_type"] = "default" - rope_config_validation(self, ignore_keys={"mrope_section"}) - super().__init__(tie_word_embeddings=tie_word_embeddings, **kwargs) - class PaddleOCRVLConfig(Qwen2VLConfig): - pass + sub_configs = {"vision_config": PaddleOCRVisionConfig, "text_config": PaddleOCRTextConfig} class PaddleOCRProjector(nn.Module): @@ -737,7 +550,7 @@ def __init__(self, config: PaddleOCRVLConfig): super().__init__() self.text_config = config.text_config self.vision_config = config.vision_config - self.merge_kernel_size = (2, 2) + self.merge_kernel_size = (self.vision_config.spatial_merge_size, self.vision_config.spatial_merge_size) self.hidden_size = self.vision_config.hidden_size * self.merge_kernel_size[0] * self.merge_kernel_size[1] @@ -778,7 +591,7 @@ class PaddleOCRRotaryEmbedding(Qwen2VLRotaryEmbedding): class PaddleOCRMLP(Ernie4_5MLP): - def __init__(self, config: PaddleOCRVLTextConfig): + def __init__(self, config: PaddleOCRTextConfig): super().__init__() @@ -798,7 +611,7 @@ class PaddleOCRRMSNorm(Ernie4_5RMSNorm): class PaddleOCRDecoderLayer(Ernie4_5DecoderLayer): - def __init__(self, config: PaddleOCRVLTextConfig, layer_idx: int): + def __init__(self, config: PaddleOCRTextConfig, layer_idx: int): super().__init__() @@ -817,17 +630,17 @@ class PaddleOCRVLPreTrainedModel(PreTrainedModel): _supports_attention_backend = True -class PaddleOCRVLTextModel(PaddleOCRVLPreTrainedModel, Ernie4_5Model): - def __init__(self, config: PaddleOCRVLTextConfig): +class PaddleOCRTextModel(PaddleOCRVLPreTrainedModel, Ernie4_5Model): + def __init__(self, config: PaddleOCRTextConfig): super().__init__(config) class PaddleOCRVisionModel(PaddleOCRVLPreTrainedModel): - config: PaddleOCRVLVisionConfig + config: PaddleOCRVisionConfig main_input_name = "pixel_values" input_modalities = "image" - def __init__(self, config: PaddleOCRVLVisionConfig): + def __init__(self, config: PaddleOCRVisionConfig): super().__init__(config) self.vision_model = PaddleOCRVisionTransformer(config) @@ -837,8 +650,8 @@ def __init__(self, config: PaddleOCRVLVisionConfig): def forward( self, - pixel_values, - cu_seqlens, + pixel_values: torch.FloatTensor, + cu_seqlens: torch.Tensor, image_grid_thw: Optional[list[Union[tuple[int, int, int], list[tuple[int, int, int]]]]] = None, ) -> BaseModelOutputWithPooling: """ @@ -858,7 +671,7 @@ def forward( class PaddleOCRVisionEmbeddings(SiglipVisionEmbeddings): - def __init__(self, config: PaddleOCRVLVisionConfig): + def __init__(self, config: PaddleOCRVisionConfig): super().__init__() def interpolate_pos_encoding(self, embeddings: torch.Tensor, height: int, width: int) -> torch.Tensor: @@ -868,16 +681,13 @@ def interpolate_pos_encoding(self, embeddings: torch.Tensor, height: int, width: dim = embeddings.shape[-1] - new_height = height - new_width = width - sqrt_num_positions = torch_int(num_positions**0.5) patch_pos_embed = patch_pos_embed.reshape(1, sqrt_num_positions, sqrt_num_positions, dim) patch_pos_embed = patch_pos_embed.permute(0, 3, 1, 2) patch_pos_embed = nn.functional.interpolate( patch_pos_embed, - size=(new_height, new_width), + size=(height, width), mode="bilinear", align_corners=False, ) @@ -894,8 +704,6 @@ def forward( Args: pixel_values (`torch.FloatTensor` of shape `(batch_size, sequence_length, image_channels, patch_size, patch_size)`): The tensors corresponding to the input images. - position_ids (`torch.LongTensor` of shape `sequence_length`): - The position ids of the image. image_grid_thw (`torch.LongTensor` of shape `(num_images, 3)`, *optional*): The temporal, height and width of feature shape of each image in LLM. """ @@ -906,9 +714,6 @@ def forward( embeddings = patch_embeds.flatten(-2).squeeze(-1) embeddings = embeddings.reshape(batch_size, squence_len, -1) - assert batch_size == 1, ( - f"Batch size must be 1, but received {batch_size}. This model only processes one image at a time." - ) start = 0 embeddings = embeddings.squeeze(0) tmp_embeddings = [] @@ -926,29 +731,22 @@ def forward( class PaddleOCRVisionAttention(VideoLlama3VisionAttention): - def __init__(self, config: PaddleOCRVLVisionConfig): + def __init__(self, config: PaddleOCRVisionConfig): super().__init__() class PaddleOCRVisionMLP(SiglipMLP): - def __init__(self, config: PaddleOCRVLVisionConfig): + def __init__(self, config: PaddleOCRVisionConfig): super().__init__() -class PaddleOCRVisionMultiheadAttentionPoolingHead(SiglipMultiheadAttentionPoolingHead): - def __init__(self, config: PaddleOCRVLVisionConfig): - super().__init__() - - self.mlp = PaddleOCRVisionMLP(config) - - class PaddleOCRVisionEncoderLayer(VideoLlama3VisionEncoderLayer): - def __init__(self, config: PaddleOCRVLVisionConfig): + def __init__(self, config: PaddleOCRVisionConfig): super().__init__() class PaddleOCRVisionEncoder(VideoLlama3VisionEncoder): - def __init__(self, config: PaddleOCRVLVisionConfig): + def __init__(self, config: PaddleOCRVisionConfig): super().__init__() embed_dim = config.hidden_size num_heads = config.num_attention_heads @@ -957,13 +755,17 @@ def __init__(self, config: PaddleOCRVLVisionConfig): def forward( self, - inputs_embeds, - cu_seqlens, + inputs_embeds: torch.FloatTensor, + cu_seqlens: torch.Tensor, attention_mask: Optional[torch.Tensor] = None, image_grid_thw: Optional[list[Union[tuple[int, int, int], list[tuple[int, int, int]]]]] = None, ) -> BaseModelOutput: """ Args: + inputs_embeds (`torch.FloatTensor` of shape `(sequence_length, hidden_size)`, *optional*): + Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. + This is useful if you want more control over how to convert `input_ids` indices into associated vectors + than the model's internal embedding lookup matrix. cu_seqlens (`torch.Tensor` of shape `(num_images + 1,)`): The cumulative sequence lengths of each image or video feature. attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*): @@ -1009,7 +811,7 @@ def forward( class PaddleOCRVisionTransformer(PaddleOCRVLPreTrainedModel): - def __init__(self, config: PaddleOCRVLVisionConfig): + def __init__(self, config: PaddleOCRVisionConfig): super().__init__(config) self.config = config embed_dim = config.hidden_size @@ -1017,13 +819,10 @@ def __init__(self, config: PaddleOCRVLVisionConfig): self.embeddings = PaddleOCRVisionEmbeddings(config) self.encoder = PaddleOCRVisionEncoder(config) self.post_layernorm = nn.LayerNorm(embed_dim, eps=config.layer_norm_eps) - self.use_head = True if not hasattr(config, "vision_use_head") else config.vision_use_head - if self.use_head: - self.head = PaddleOCRVisionMultiheadAttentionPoolingHead(config) def forward( self, - pixel_values, + pixel_values: torch.FloatTensor, cu_seqlens: torch.Tensor, attention_mask: Optional[torch.Tensor] = None, image_grid_thw: Optional[list[Union[tuple[int, int, int], list[tuple[int, int, int]]]]] = None, @@ -1036,8 +835,6 @@ def forward( The cumulative sequence lengths of each image or video feature. attention_mask (`torch.Tensor`, *optional*): The attention_mask used in forward function shape [batch_size X sequence_length] if not None. - position_ids (`torch.LongTensor` of shape `sequence_length`): - The position ids of the image. image_grid_thw (`torch.LongTensor` of shape `(num_images, 3)`, *optional*): The temporal, height and width of feature shape of each image in LLM. """ @@ -1071,13 +868,13 @@ class PaddleOCRVLCausalLMOutputWithPast(Qwen2VLCausalLMOutputWithPast): class PaddleOCRVLModel(Qwen2VLModel): _checkpoint_conversion_mapping = {"^model": "language_model"} - _keys_to_ignore_on_load_unexpected = ["packing_position_embedding"] + _keys_to_ignore_on_load_unexpected = ["packing_position_embedding", "vision_model.head"] def __init__(self, config: PaddleOCRVLConfig): super().__init__(config) - self.visual = PaddleOCRVisionModel(config.vision_config) + self.visual = PaddleOCRVisionModel._from_config(config.vision_config) self.projector = PaddleOCRProjector(config) - self.language_model = PaddleOCRVLTextModel(config.text_config) + self.language_model = PaddleOCRTextModel._from_config(config.text_config) self.rope_deltas = None self.post_init() @@ -1114,25 +911,32 @@ def get_image_features(self, pixel_values: torch.FloatTensor, image_grid_thw: Op cu_seqlens=cu_seqlens, ) image_embeds = vision_outputs.last_hidden_state + image_embeds = self.projector(image_embeds, image_grid_thw) return image_embeds def get_placeholder_mask( - self, - input_ids: torch.LongTensor, - inputs_embeds: torch.FloatTensor, - image_features: Optional[torch.FloatTensor] = None, + self, input_ids: torch.LongTensor, inputs_embeds: torch.FloatTensor, image_features: torch.FloatTensor ): - n_image_tokens = (input_ids == self.config.image_token_id).sum().item() - n_image_features = image_features.shape[0] - if n_image_tokens != n_image_features: + """ + Obtains multimodal placeholder mask from `input_ids` or `inputs_embeds`, and checks that the placeholder token count is + equal to the length of multimodal features. If the lengths are different, an error is raised. + """ + if input_ids is None: + special_image_mask = inputs_embeds == self.get_input_embeddings()( + torch.tensor(self.config.image_token_id, dtype=torch.long, device=inputs_embeds.device) + ) + special_image_mask = special_image_mask.all(-1) + else: + special_image_mask = input_ids == self.config.image_token_id + + n_image_tokens = special_image_mask.sum() + special_image_mask = special_image_mask.unsqueeze(-1).expand_as(inputs_embeds).to(inputs_embeds.device) + n_image_features = image_features.shape[0] * image_features.shape[1] + if inputs_embeds[special_image_mask].numel() != image_features.numel(): raise ValueError( f"Image features and image tokens do not match: tokens: {n_image_tokens}, features {n_image_features}" ) - mask = input_ids == self.config.image_token_id - mask_unsqueezed = mask.unsqueeze(-1) - mask_expanded = mask_unsqueezed.expand_as(inputs_embeds) - image_mask = mask_expanded.to(inputs_embeds.device) - return image_mask + return special_image_mask @can_return_tuple def forward( @@ -1162,8 +966,9 @@ def forward( inputs_embeds = self.language_model.embed_tokens(input_ids) if pixel_values is not None: - image_embeds = self.get_image_features(pixel_values, image_grid_thw) - image_embeds = self.projector(image_embeds, image_grid_thw) + image_embeds = self.get_image_features(pixel_values, image_grid_thw).to( + inputs_embeds.device, inputs_embeds.dtype + ) image_mask = self.get_placeholder_mask(input_ids, inputs_embeds=inputs_embeds, image_features=image_embeds) inputs_embeds = inputs_embeds.masked_scatter(image_mask, image_embeds) @@ -1215,7 +1020,7 @@ def forward( rope_deltas=self.rope_deltas, ) - return output if return_dict else output.to_tuple() + return output class PaddleOCRVLForConditionalGeneration(Qwen2VLForConditionalGeneration): @@ -1224,7 +1029,7 @@ class PaddleOCRVLForConditionalGeneration(Qwen2VLForConditionalGeneration): "^mlp_AR": "model.projector", r"^model(?!(\.visual|\.projector))": "model.language_model", } - _keys_to_ignore_on_load_unexpected = ["packing_position_embedding"] + _keys_to_ignore_on_load_unexpected = ["packing_position_embedding", "vision_model.head"] @can_return_tuple @auto_docstring @@ -1253,8 +1058,6 @@ def forward( (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`. image_grid_thw (`torch.LongTensor` of shape `(num_images, 3)`, *optional*): The temporal, height and width of feature shape of each image in LLM. - video_grid_thw (`torch.LongTensor` of shape `(num_videos, 3)`, *optional*): - The temporal, height and width of feature shape of each video in LLM. rope_deltas (`torch.LongTensor` of shape `(batch_size, )`, *optional*): The rope index difference between sequence length and multimodal rope. @@ -1340,8 +1143,6 @@ def prepare_inputs_for_generation( pixel_values=None, pixel_values_videos=None, image_grid_thw=None, - video_grid_thw=None, - second_per_grid_ts=None, **kwargs, ): # Overwritten -- in specific circumstances we don't want to forward image inputs to the model @@ -1356,8 +1157,6 @@ def prepare_inputs_for_generation( pixel_values=pixel_values, pixel_values_videos=pixel_values_videos, image_grid_thw=image_grid_thw, - video_grid_thw=video_grid_thw, - second_per_grid_ts=second_per_grid_ts, use_cache=use_cache, **kwargs, ) @@ -1373,7 +1172,10 @@ def prepare_inputs_for_generation( __all__ = [ "PaddleOCRVLForConditionalGeneration", "PaddleOCRVLConfig", - "PaddleOCRVLTextConfig", + "PaddleOCRTextModel", + "PaddleOCRVisionModel", + "PaddleOCRVisionConfig", + "PaddleOCRTextConfig", "PaddleOCRVLImageProcessor", "PaddleOCRVLImageProcessorFast", "PaddleOCRVLProcessor", From f1cd9acce1e61a79e73b5f61cf5210db42f93069 Mon Sep 17 00:00:00 2001 From: zhangyue66 Date: Thu, 4 Dec 2025 18:16:13 +0800 Subject: [PATCH 04/19] update --- .../configuration_paddleocr_vl.py | 62 +++---------- .../image_processing_paddleocr_vl.py | 22 +++-- .../image_processing_paddleocr_vl_fast.py | 29 +++--- .../paddleocr_vl/modeling_paddleocr_vl.py | 69 ++++---------- .../paddleocr_vl/modular_paddleocr_vl.py | 93 +++++++++---------- 5 files changed, 107 insertions(+), 168 deletions(-) diff --git a/src/transformers/models/paddleocr_vl/configuration_paddleocr_vl.py b/src/transformers/models/paddleocr_vl/configuration_paddleocr_vl.py index b8edb01825a4..35e9e9eb935f 100644 --- a/src/transformers/models/paddleocr_vl/configuration_paddleocr_vl.py +++ b/src/transformers/models/paddleocr_vl/configuration_paddleocr_vl.py @@ -23,12 +23,11 @@ # See the License for the specific language governing permissions and # limitations under the License. -from functools import partial +import inspect from typing import Optional from ...configuration_utils import PreTrainedConfig -from ...modeling_rope_utils import RopeParameters, standardize_rope_params -from ...modeling_rope_utils import rope_config_validation as _rope_config_validation +from ...modeling_rope_utils import RopeParameters class PaddleOCRVisionConfig(PreTrainedConfig): @@ -114,9 +113,6 @@ def __init__( self.temporal_patch_size = temporal_patch_size -rope_config_validation = partial(_rope_config_validation, ignore_keys={"mrope_section"}) - - class PaddleOCRTextConfig(PreTrainedConfig): r""" This is the configuration class to store the configuration of a [`PaddleOCRTextModel`]. It is used to instantiate an Ernie 4.5 @@ -167,7 +163,7 @@ class PaddleOCRTextConfig(PreTrainedConfig): tie_word_embeddings (`bool`, *optional*, defaults to `True`): Whether to tie weight embeddings rope_parameters (`RopeParameters`, *optional*): - Dictionary containing the configuration parameters for the RoPE embeddings. The dictionaty should contain + Dictionary containing the configuration parameters for the RoPE embeddings. The dictionary should contain a value for `rope_theta` and optionally parameters used for scaling in case you want to use RoPE with longer `max_position_embeddings`. use_bias (`bool`, *optional*, defaults to `False`): @@ -190,6 +186,7 @@ class PaddleOCRTextConfig(PreTrainedConfig): model_type = "paddleocr_vl_text" keys_to_ignore_at_inference = ["past_key_values"] + default_theta = 500000.0 # Default tensor parallel plan for base model `PaddleOCRTextModel` base_model_tp_plan = { "layers.*.self_attn.q_proj": "colwise", @@ -228,6 +225,7 @@ def __init__( head_dim: Optional[int] = 128, **kwargs, ): + kwargs["ignore_keys_at_rope_validation"] = {"mrope_section"} self.vocab_size = vocab_size self.max_position_embeddings = max_position_embeddings self.hidden_size = hidden_size @@ -246,14 +244,7 @@ def __init__( self.use_cache = use_cache self.use_bias = use_bias self.head_dim = head_dim if head_dim is not None else self.hidden_size // self.num_attention_heads - # Try to set `rope_scaling` if available, otherwise use `rope_parameters` - rope_scaling = kwargs.pop("rope_scaling", None) - self.rope_parameters = rope_scaling or rope_parameters - - # Validate the correctness of rotary position embeddings parameters - rope_theta = kwargs.get("rope_theta", 500000.0) - standardize_rope_params(self, rope_theta=rope_theta) - rope_config_validation(self) + self.rope_parameters = rope_parameters super().__init__( pad_token_id=pad_token_id, @@ -316,11 +307,6 @@ def __init__( vision_end_token_id=151653, **kwargs, ): - # We need to init super() here so that it does not reset values - # that are in text config to the BaseClass defaults. The Base - # config has many text related defaults and not all defaults are same as for `PaddleOCRVLTextConfig` - super().__init__(**kwargs) - if isinstance(vision_config, dict): self.vision_config = self.sub_configs["vision_config"](**vision_config) elif vision_config is None: @@ -329,39 +315,21 @@ def __init__( if isinstance(text_config, dict): self.text_config = self.sub_configs["text_config"](**text_config) elif text_config is None: - # For BC use all kwargs to init `TextConfig` - self.text_config = self.sub_configs["text_config"](**kwargs) + # Hub configs are saved as flat dicts so we pop some of kwargs to init `TextConfig` + text_params = inspect.signature(self.sub_configs["text_config"].__init__).parameters.keys() + text_params = list(text_params) + ["rope_scaling", "rope_theta"] + text_config = {key: kwargs.pop(key) for key in text_params if key in kwargs} + text_config["dtype"] = kwargs.get("torch_dtype", kwargs.get("dtype")) # don't pop the dtype + self.text_config = self.sub_configs["text_config"](**text_config) self.image_token_id = image_token_id self.video_token_id = video_token_id self.vision_start_token_id = vision_start_token_id self.vision_end_token_id = vision_end_token_id - # Attention implementation to use. It sets it recursively on sub-configs so we call it again in the end - self._attn_implementation = kwargs.pop("attn_implementation", None) - - def __setattr__(self, key, value): - if ( - (text_config := super().__getattribute__("__dict__").get("text_config")) is not None - and key not in ["_name_or_path", "model_type", "dtype", "_attn_implementation_internal"] - and key in text_config.__dict__ - ): - setattr(text_config, key, value) - else: - super().__setattr__(key, value) - - def __getattribute__(self, key): - if "text_config" in super().__getattribute__("__dict__") and key not in [ - "_name_or_path", - "model_type", - "dtype", - "_attn_implementation_internal", - ]: - text_config = super().__getattribute__("text_config") - if key in text_config.__dict__: - return getattr(text_config, key) - - return super().__getattribute__(key) + # FIXME: arthur/cyril - tying has to be used from the text config + kwargs["tie_word_embeddings"] = self.text_config.tie_word_embeddings + super().__init__(**kwargs) __all__ = ["PaddleOCRVLConfig", "PaddleOCRVisionConfig", "PaddleOCRTextConfig"] diff --git a/src/transformers/models/paddleocr_vl/image_processing_paddleocr_vl.py b/src/transformers/models/paddleocr_vl/image_processing_paddleocr_vl.py index 2f7c00a5a9df..4bda9d9d5faf 100644 --- a/src/transformers/models/paddleocr_vl/image_processing_paddleocr_vl.py +++ b/src/transformers/models/paddleocr_vl/image_processing_paddleocr_vl.py @@ -228,6 +228,7 @@ def _preprocess( - `"none"` or `ChannelDimension.NONE`: image in (height, width) format. - `"none"` or `ChannelDimension.NONE`: image in (height, width) format. """ images = make_list_of_images(images) + images = self.fetch_images(images) if do_convert_rgb: images = [convert_to_rgb(image) for image in images] @@ -253,7 +254,7 @@ def _preprocess( resized_height, resized_width = smart_resize( height, width, - factor=self.patch_size * self.merge_size, + factor=patch_size * merge_size, min_pixels=self.min_pixels, max_pixels=self.max_pixels, ) @@ -281,26 +282,27 @@ def _preprocess( if data_format == ChannelDimension.LAST: patches = patches.transpose(0, 3, 1, 2) if patches.shape[0] == 1: - patches = np.tile(patches, (self.temporal_patch_size, 1, 1, 1)) + patches = np.tile(patches, (temporal_patch_size, 1, 1, 1)) channel = patches.shape[1] - grid_t = patches.shape[0] // self.temporal_patch_size + grid_t = patches.shape[0] // temporal_patch_size grid_h, grid_w = ( - resized_height // self.patch_size, - resized_width // self.patch_size, + resized_height // patch_size, + resized_width // patch_size, ) patches = patches.reshape( grid_t, - self.temporal_patch_size, + temporal_patch_size, channel, grid_h, - self.patch_size, + patch_size, grid_w, - self.patch_size, + patch_size, ) patches = patches.transpose(0, 3, 5, 2, 1, 4, 6) - assert self.temporal_patch_size == 1 - flatten_patches = patches.reshape(grid_t * grid_h * grid_w, channel, self.patch_size, self.patch_size) + if temporal_patch_size != 1: + raise ValueError("temporal_patch_size must be 1!") + flatten_patches = patches.reshape(grid_t * grid_h * grid_w, channel, patch_size, patch_size) return flatten_patches, (grid_t, grid_h, grid_w) def preprocess( diff --git a/src/transformers/models/paddleocr_vl/image_processing_paddleocr_vl_fast.py b/src/transformers/models/paddleocr_vl/image_processing_paddleocr_vl_fast.py index 48d47d77c2d4..f7b0aaeaba67 100644 --- a/src/transformers/models/paddleocr_vl/image_processing_paddleocr_vl_fast.py +++ b/src/transformers/models/paddleocr_vl/image_processing_paddleocr_vl_fast.py @@ -117,8 +117,15 @@ def _preprocess( image_std: Optional[Union[float, list[float]]], disable_grouping: Optional[bool], return_tensors: Optional[Union[str, TensorType]], + patch_size: Optional[int] = None, + temporal_patch_size: Optional[int] = None, + merge_size: Optional[int] = None, **kwargs, ): + patch_size = patch_size if patch_size is not None else self.patch_size + temporal_patch_size = temporal_patch_size if temporal_patch_size is not None else self.temporal_patch_size + merge_size = merge_size if merge_size is not None else self.merge_size + grouped_images, grouped_images_index = group_images_by_shape(images, disable_grouping=disable_grouping) resized_images_grouped = {} for shape, stacked_images in grouped_images.items(): @@ -127,7 +134,7 @@ def _preprocess( resized_height, resized_width = smart_resize( height, width, - factor=self.patch_size * self.merge_size, + factor=patch_size * merge_size, min_pixels=self.min_pixels, max_pixels=self.max_pixels, ) @@ -154,30 +161,28 @@ def _preprocess( if patches.ndim == 4: # add a temporal dimension if we have images patches = patches.unsqueeze(1) - if patches.shape[1] % self.temporal_patch_size != 0: - repeats = patches[:, -1:].repeat(1, self.temporal_patch_size - 1, 1, 1, 1) + if patches.shape[1] % temporal_patch_size != 0: + repeats = patches[:, -1:].repeat(1, temporal_patch_size - 1, 1, 1, 1) patches = torch.cat([patches, repeats], dim=1) batch_size, grid_t, channel = patches.shape[:3] - grid_t = grid_t // self.temporal_patch_size + grid_t = grid_t // temporal_patch_size grid_h, grid_w = ( - resized_height // self.patch_size, - resized_width // self.patch_size, + resized_height // patch_size, + resized_width // patch_size, ) patches = patches.view( batch_size, grid_t, - self.temporal_patch_size, + temporal_patch_size, channel, grid_h, - self.patch_size, + patch_size, grid_w, - self.patch_size, + patch_size, ) patches = patches.permute(0, 1, 4, 6, 3, 2, 5, 7) - flatten_patches = patches.reshape( - batch_size, grid_t * grid_h * grid_w, channel, self.patch_size, self.patch_size - ) + flatten_patches = patches.reshape(batch_size, grid_t * grid_h * grid_w, channel, patch_size, patch_size) processed_images_grouped[shape] = flatten_patches processed_grids[shape] = [[grid_t, grid_h, grid_w]] * batch_size diff --git a/src/transformers/models/paddleocr_vl/modeling_paddleocr_vl.py b/src/transformers/models/paddleocr_vl/modeling_paddleocr_vl.py index f6c214158a12..7764ad7da737 100644 --- a/src/transformers/models/paddleocr_vl/modeling_paddleocr_vl.py +++ b/src/transformers/models/paddleocr_vl/modeling_paddleocr_vl.py @@ -52,16 +52,14 @@ class PaddleOCRProjector(nn.Module): def __init__(self, config: PaddleOCRVLConfig): super().__init__() - self.text_config = config.text_config - self.vision_config = config.vision_config - self.merge_kernel_size = (self.vision_config.spatial_merge_size, self.vision_config.spatial_merge_size) + self.merge_kernel_size = (config.vision_config.spatial_merge_size, config.vision_config.spatial_merge_size) - self.hidden_size = self.vision_config.hidden_size * self.merge_kernel_size[0] * self.merge_kernel_size[1] + hidden_size = config.vision_config.hidden_size * self.merge_kernel_size[0] * self.merge_kernel_size[1] - self.pre_norm = torch.nn.LayerNorm(self.vision_config.hidden_size, eps=1e-05) - self.linear_1 = nn.Linear(self.hidden_size, self.hidden_size, bias=True) + self.pre_norm = torch.nn.LayerNorm(config.vision_config.hidden_size, eps=1e-05) + self.linear_1 = nn.Linear(hidden_size, hidden_size, bias=True) self.act = GELUActivation() - self.linear_2 = nn.Linear(self.hidden_size, self.text_config.hidden_size, bias=True) + self.linear_2 = nn.Linear(hidden_size, config.text_config.hidden_size, bias=True) def forward(self, image_features: torch.Tensor, image_grid_thw: torch.Tensor) -> torch.Tensor: image_features_chunks = image_features.split(image_grid_thw.prod(dim=1).tolist(), dim=0) @@ -459,7 +457,7 @@ def __init__(self, config: PaddleOCRTextConfig): # Initialize weights and apply final processing self.post_init() - @check_model_inputs() + @check_model_inputs @auto_docstring def forward( self, @@ -483,8 +481,8 @@ def forward( if cache_position is None: past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 - cache_position: torch.Tensor = torch.arange( - past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device + cache_position: torch.Tensor = ( + torch.arange(inputs_embeds.shape[1], device=inputs_embeds.device) + past_seen_tokens ) if position_ids is None: @@ -509,6 +507,7 @@ def forward( position_embeddings=position_embeddings, position_ids=position_ids, past_key_values=past_key_values, + use_cache=use_cache, cache_position=cache_position, **kwargs, ) @@ -998,7 +997,7 @@ class PaddleOCRVLCausalLMOutputWithPast(ModelOutput): @auto_docstring class PaddleOCRVLModel(PaddleOCRVLPreTrainedModel): - base_model_prefix = "" + base_model_prefix = "model" _checkpoint_conversion_mapping = {"^model": "language_model"} # Reference: fix gemma3 grad acc #37208 accepts_loss_kwargs = False @@ -1020,12 +1019,6 @@ def get_input_embeddings(self): def set_input_embeddings(self, value): self.language_model.embed_tokens = value - def set_decoder(self, decoder): - self.language_model = decoder - - def get_decoder(self): - return self.language_model - def get_rope_index( self, input_ids: Optional[torch.LongTensor] = None, @@ -1280,14 +1273,9 @@ def forward( image_mask = self.get_placeholder_mask(input_ids, inputs_embeds=inputs_embeds, image_features=image_embeds) inputs_embeds = inputs_embeds.masked_scatter(image_mask, image_embeds) - # if we get 4D attention mask we cannot calculate rope deltas anymore. TODO @raushan fixme - if position_ids is None and (attention_mask is None or attention_mask.ndim == 2): - # calculate RoPE index once per generation in the pre-fill stage only - if ( - (cache_position is not None and cache_position[0] == 0) - or self.rope_deltas is None - or (past_key_values is None or past_key_values.get_seq_length() == 0) - ): + if position_ids is None: + past_key_values_length = 0 if past_key_values is None else past_key_values.get_seq_length() + if self.rope_deltas is None or past_key_values_length == 0: position_ids, rope_deltas = self.get_rope_index( input_ids=input_ids, image_grid_thw=image_grid_thw, @@ -1297,17 +1285,11 @@ def forward( # then use the prev pre-calculated rope-deltas to get the correct position ids else: batch_size, seq_length, _ = inputs_embeds.shape - delta = ( - (cache_position[0] + self.rope_deltas).to(inputs_embeds.device) - if cache_position is not None - else 0 - ) position_ids = torch.arange(seq_length, device=inputs_embeds.device) - position_ids = position_ids.view(1, -1).expand(batch_size, -1) - if cache_position is not None: # otherwise `deltas` is an int `0` - delta = delta.repeat_interleave(batch_size // delta.shape[0], dim=0) - position_ids = position_ids.add(delta) - position_ids = position_ids.unsqueeze(0).expand(3, -1, -1) + position_ids = position_ids.view(1, 1, -1).expand(3, batch_size, -1) + delta = (past_key_values_length + self.rope_deltas).to(inputs_embeds.device) + delta = delta.repeat_interleave(batch_size // delta.shape[0], dim=0) + position_ids = position_ids + delta.to(position_ids.device) outputs = self.language_model( input_ids=None, @@ -1337,7 +1319,7 @@ class PaddleOCRVLForConditionalGeneration(PaddleOCRVLPreTrainedModel, Generation "^mlp_AR": "model.projector", r"^model(?!(\.visual|\.projector))": "model.language_model", } - _tied_weights_keys = ["lm_head.weight"] + _tied_weights_keys = {"lm_head.weight": "model.language_model.embed_tokens.weight"} _keys_to_ignore_on_load_unexpected = ["packing_position_embedding", "vision_model.head"] def __init__(self, config): @@ -1353,12 +1335,6 @@ def get_input_embeddings(self): def set_input_embeddings(self, value): self.model.set_input_embeddings(value) - def set_decoder(self, decoder): - self.model.set_decoder(decoder) - - def get_decoder(self): - return self.model.get_decoder() - def get_video_features( self, pixel_values_videos: torch.FloatTensor, video_grid_thw: Optional[torch.LongTensor] = None ): @@ -1367,15 +1343,6 @@ def get_video_features( def get_image_features(self, pixel_values: torch.FloatTensor, image_grid_thw: Optional[torch.LongTensor] = None): return self.model.get_image_features(pixel_values, image_grid_thw) - # Make modules available through conditional class for BC - @property - def language_model(self): - return self.model.language_model - - @property - def visual(self): - return self.model.visual - @can_return_tuple @auto_docstring def forward( diff --git a/src/transformers/models/paddleocr_vl/modular_paddleocr_vl.py b/src/transformers/models/paddleocr_vl/modular_paddleocr_vl.py index 8df7bc64a23d..11763da198fc 100644 --- a/src/transformers/models/paddleocr_vl/modular_paddleocr_vl.py +++ b/src/transformers/models/paddleocr_vl/modular_paddleocr_vl.py @@ -18,7 +18,6 @@ # limitations under the License. import math -from functools import partial from typing import Optional, Union import numpy as np @@ -46,7 +45,6 @@ ) from ...masking_utils import create_bidirectional_mask from ...modeling_outputs import BaseModelOutput, BaseModelOutputWithPooling -from ...modeling_rope_utils import rope_config_validation as _rope_config_validation from ...modeling_utils import PreTrainedModel from ...models.qwen2_vl.image_processing_qwen2_vl import Qwen2VLImageProcessor from ...processing_utils import ( @@ -88,7 +86,6 @@ logger = logging.get_logger(__name__) -rope_config_validation = partial(_rope_config_validation, ignore_keys={"mrope_section"}) def smart_resize( @@ -220,6 +217,7 @@ def _preprocess( - `"none"` or `ChannelDimension.NONE`: image in (height, width) format. - `"none"` or `ChannelDimension.NONE`: image in (height, width) format. """ images = make_list_of_images(images) + images = self.fetch_images(images) if do_convert_rgb: images = [convert_to_rgb(image) for image in images] @@ -245,7 +243,7 @@ def _preprocess( resized_height, resized_width = smart_resize( height, width, - factor=self.patch_size * self.merge_size, + factor=patch_size * merge_size, min_pixels=self.min_pixels, max_pixels=self.max_pixels, ) @@ -273,26 +271,27 @@ def _preprocess( if data_format == ChannelDimension.LAST: patches = patches.transpose(0, 3, 1, 2) if patches.shape[0] == 1: - patches = np.tile(patches, (self.temporal_patch_size, 1, 1, 1)) + patches = np.tile(patches, (temporal_patch_size, 1, 1, 1)) channel = patches.shape[1] - grid_t = patches.shape[0] // self.temporal_patch_size + grid_t = patches.shape[0] // temporal_patch_size grid_h, grid_w = ( - resized_height // self.patch_size, - resized_width // self.patch_size, + resized_height // patch_size, + resized_width // patch_size, ) patches = patches.reshape( grid_t, - self.temporal_patch_size, + temporal_patch_size, channel, grid_h, - self.patch_size, + patch_size, grid_w, - self.patch_size, + patch_size, ) patches = patches.transpose(0, 3, 5, 2, 1, 4, 6) - assert self.temporal_patch_size == 1 - flatten_patches = patches.reshape(grid_t * grid_h * grid_w, channel, self.patch_size, self.patch_size) + if temporal_patch_size != 1: + raise ValueError("temporal_patch_size must be 1!") + flatten_patches = patches.reshape(grid_t * grid_h * grid_w, channel, patch_size, patch_size) return flatten_patches, (grid_t, grid_h, grid_w) @@ -346,8 +345,15 @@ def _preprocess( image_std: Optional[Union[float, list[float]]], disable_grouping: Optional[bool], return_tensors: Optional[Union[str, TensorType]], + patch_size: Optional[int] = None, + temporal_patch_size: Optional[int] = None, + merge_size: Optional[int] = None, **kwargs, ): + patch_size = patch_size if patch_size is not None else self.patch_size + temporal_patch_size = temporal_patch_size if temporal_patch_size is not None else self.temporal_patch_size + merge_size = merge_size if merge_size is not None else self.merge_size + grouped_images, grouped_images_index = group_images_by_shape(images, disable_grouping=disable_grouping) resized_images_grouped = {} for shape, stacked_images in grouped_images.items(): @@ -356,7 +362,7 @@ def _preprocess( resized_height, resized_width = smart_resize( height, width, - factor=self.patch_size * self.merge_size, + factor=patch_size * merge_size, min_pixels=self.min_pixels, max_pixels=self.max_pixels, ) @@ -383,29 +389,29 @@ def _preprocess( if patches.ndim == 4: # add a temporal dimension if we have images patches = patches.unsqueeze(1) - if patches.shape[1] % self.temporal_patch_size != 0: - repeats = patches[:, -1:].repeat(1, self.temporal_patch_size - 1, 1, 1, 1) + if patches.shape[1] % temporal_patch_size != 0: + repeats = patches[:, -1:].repeat(1, temporal_patch_size - 1, 1, 1, 1) patches = torch.cat([patches, repeats], dim=1) batch_size, grid_t, channel = patches.shape[:3] - grid_t = grid_t // self.temporal_patch_size + grid_t = grid_t // temporal_patch_size grid_h, grid_w = ( - resized_height // self.patch_size, - resized_width // self.patch_size, + resized_height // patch_size, + resized_width // patch_size, ) patches = patches.view( batch_size, grid_t, - self.temporal_patch_size, + temporal_patch_size, channel, grid_h, - self.patch_size, + patch_size, grid_w, - self.patch_size, + patch_size, ) patches = patches.permute(0, 1, 4, 6, 3, 2, 5, 7) flatten_patches = patches.reshape( - batch_size, grid_t * grid_h * grid_w, channel, self.patch_size, self.patch_size + batch_size, grid_t * grid_h * grid_w, channel, patch_size, patch_size ) processed_images_grouped[shape] = flatten_patches @@ -540,6 +546,10 @@ def __init__( class PaddleOCRTextConfig(Ernie4_5Config): model_type = "paddleocr_vl_text" + def __init__(self, **super_kwargs): + kwargs["ignore_keys_at_rope_validation"] = {"mrope_section"} + super().__init__() + class PaddleOCRVLConfig(Qwen2VLConfig): sub_configs = {"vision_config": PaddleOCRVisionConfig, "text_config": PaddleOCRTextConfig} @@ -548,16 +558,14 @@ class PaddleOCRVLConfig(Qwen2VLConfig): class PaddleOCRProjector(nn.Module): def __init__(self, config: PaddleOCRVLConfig): super().__init__() - self.text_config = config.text_config - self.vision_config = config.vision_config - self.merge_kernel_size = (self.vision_config.spatial_merge_size, self.vision_config.spatial_merge_size) + self.merge_kernel_size = (config.vision_config.spatial_merge_size, config.vision_config.spatial_merge_size) - self.hidden_size = self.vision_config.hidden_size * self.merge_kernel_size[0] * self.merge_kernel_size[1] + hidden_size = config.vision_config.hidden_size * self.merge_kernel_size[0] * self.merge_kernel_size[1] - self.pre_norm = torch.nn.LayerNorm(self.vision_config.hidden_size, eps=1e-05) - self.linear_1 = nn.Linear(self.hidden_size, self.hidden_size, bias=True) + self.pre_norm = torch.nn.LayerNorm(config.vision_config.hidden_size, eps=1e-05) + self.linear_1 = nn.Linear(hidden_size, hidden_size, bias=True) self.act = GELUActivation() - self.linear_2 = nn.Linear(self.hidden_size, self.text_config.hidden_size, bias=True) + self.linear_2 = nn.Linear(hidden_size, config.text_config.hidden_size, bias=True) def forward(self, image_features: torch.Tensor, image_grid_thw: torch.Tensor) -> torch.Tensor: image_features_chunks = image_features.split(image_grid_thw.prod(dim=1).tolist(), dim=0) @@ -972,14 +980,9 @@ def forward( image_mask = self.get_placeholder_mask(input_ids, inputs_embeds=inputs_embeds, image_features=image_embeds) inputs_embeds = inputs_embeds.masked_scatter(image_mask, image_embeds) - # if we get 4D attention mask we cannot calculate rope deltas anymore. TODO @raushan fixme - if position_ids is None and (attention_mask is None or attention_mask.ndim == 2): - # calculate RoPE index once per generation in the pre-fill stage only - if ( - (cache_position is not None and cache_position[0] == 0) - or self.rope_deltas is None - or (past_key_values is None or past_key_values.get_seq_length() == 0) - ): + if position_ids is None: + past_key_values_length = 0 if past_key_values is None else past_key_values.get_seq_length() + if self.rope_deltas is None or past_key_values_length == 0: position_ids, rope_deltas = self.get_rope_index( input_ids=input_ids, image_grid_thw=image_grid_thw, @@ -989,17 +992,11 @@ def forward( # then use the prev pre-calculated rope-deltas to get the correct position ids else: batch_size, seq_length, _ = inputs_embeds.shape - delta = ( - (cache_position[0] + self.rope_deltas).to(inputs_embeds.device) - if cache_position is not None - else 0 - ) position_ids = torch.arange(seq_length, device=inputs_embeds.device) - position_ids = position_ids.view(1, -1).expand(batch_size, -1) - if cache_position is not None: # otherwise `deltas` is an int `0` - delta = delta.repeat_interleave(batch_size // delta.shape[0], dim=0) - position_ids = position_ids.add(delta) - position_ids = position_ids.unsqueeze(0).expand(3, -1, -1) + position_ids = position_ids.view(1, 1, -1).expand(3, batch_size, -1) + delta = (past_key_values_length + self.rope_deltas).to(inputs_embeds.device) + delta = delta.repeat_interleave(batch_size // delta.shape[0], dim=0) + position_ids = position_ids + delta.to(position_ids.device) outputs = self.language_model( input_ids=None, From 1839fab5f9ad071c5270dd5b747e1ae83140dc96 Mon Sep 17 00:00:00 2001 From: zhangyue66 Date: Fri, 5 Dec 2025 16:38:54 +0800 Subject: [PATCH 05/19] fix unresolved problems --- src/transformers/conversion_mapping.py | 1 + .../models/paddleocr_vl/configuration_paddleocr_vl.py | 1 - .../paddleocr_vl/image_processing_paddleocr_vl.py | 2 +- .../models/paddleocr_vl/modeling_paddleocr_vl.py | 3 ++- .../models/paddleocr_vl/modular_paddleocr_vl.py | 11 ++++------- .../models/paddleocr_vl/processing_paddleocr_vl.py | 2 +- 6 files changed, 9 insertions(+), 11 deletions(-) diff --git a/src/transformers/conversion_mapping.py b/src/transformers/conversion_mapping.py index 5968bd08d406..fe8f87d2d972 100644 --- a/src/transformers/conversion_mapping.py +++ b/src/transformers/conversion_mapping.py @@ -213,6 +213,7 @@ def get_checkpoint_conversion_mapping(model_type): "sam3", "sam3_tracker", "sam3_tracker_video", + "paddleocrvl" ] diff --git a/src/transformers/models/paddleocr_vl/configuration_paddleocr_vl.py b/src/transformers/models/paddleocr_vl/configuration_paddleocr_vl.py index 35e9e9eb935f..cc8ff58dc97f 100644 --- a/src/transformers/models/paddleocr_vl/configuration_paddleocr_vl.py +++ b/src/transformers/models/paddleocr_vl/configuration_paddleocr_vl.py @@ -225,7 +225,6 @@ def __init__( head_dim: Optional[int] = 128, **kwargs, ): - kwargs["ignore_keys_at_rope_validation"] = {"mrope_section"} self.vocab_size = vocab_size self.max_position_embeddings = max_position_embeddings self.hidden_size = hidden_size diff --git a/src/transformers/models/paddleocr_vl/image_processing_paddleocr_vl.py b/src/transformers/models/paddleocr_vl/image_processing_paddleocr_vl.py index 4bda9d9d5faf..788fdcee982a 100644 --- a/src/transformers/models/paddleocr_vl/image_processing_paddleocr_vl.py +++ b/src/transformers/models/paddleocr_vl/image_processing_paddleocr_vl.py @@ -301,7 +301,7 @@ def _preprocess( ) patches = patches.transpose(0, 3, 5, 2, 1, 4, 6) if temporal_patch_size != 1: - raise ValueError("temporal_patch_size must be 1!") + raise ValueError(f"temporal_patch_size must be 1!, but got {temporal_patch_size}!") flatten_patches = patches.reshape(grid_t * grid_h * grid_w, channel, patch_size, patch_size) return flatten_patches, (grid_t, grid_h, grid_w) diff --git a/src/transformers/models/paddleocr_vl/modeling_paddleocr_vl.py b/src/transformers/models/paddleocr_vl/modeling_paddleocr_vl.py index 7764ad7da737..a96494770f49 100644 --- a/src/transformers/models/paddleocr_vl/modeling_paddleocr_vl.py +++ b/src/transformers/models/paddleocr_vl/modeling_paddleocr_vl.py @@ -74,7 +74,8 @@ def forward(self, image_features: torch.Tensor, image_grid_thw: torch.Tensor) -> w_block = w // m2 image_feature = image_feature.reshape(t, h_block, m1, w_block, m2, d) - image_feature = image_feature.reshape((t * h_block * w_block), (m1 * m2 * d)) + image_feature = image_feature.transpose(2, 3) + image_feature = image_feature.reshape(t * h_block * w_block, m1 * m2 * d) hidden_states = self.linear_1(image_feature) hidden_states = self.act(hidden_states) diff --git a/src/transformers/models/paddleocr_vl/modular_paddleocr_vl.py b/src/transformers/models/paddleocr_vl/modular_paddleocr_vl.py index 11763da198fc..5f97b7f419f2 100644 --- a/src/transformers/models/paddleocr_vl/modular_paddleocr_vl.py +++ b/src/transformers/models/paddleocr_vl/modular_paddleocr_vl.py @@ -290,7 +290,7 @@ def _preprocess( ) patches = patches.transpose(0, 3, 5, 2, 1, 4, 6) if temporal_patch_size != 1: - raise ValueError("temporal_patch_size must be 1!") + raise ValueError(f"temporal_patch_size must be 1!, but got {temporal_patch_size}!") flatten_patches = patches.reshape(grid_t * grid_h * grid_w, channel, patch_size, patch_size) return flatten_patches, (grid_t, grid_h, grid_w) @@ -452,7 +452,7 @@ class PaddleOCRVLProcessor(ProcessorMixin): tokenizer_class = "AutoTokenizer" def __init__(self, image_processor=None, tokenizer=None, chat_template=None, **kwargs): - self.image_token = "<|IMAGE_PLACEHOLDER|>" if not hasattr(tokenizer, "image_token") else tokenizer.image_token + self.image_token = tokenizer.image_token super().__init__(image_processor, tokenizer, chat_template=chat_template) def __call__( @@ -546,10 +546,6 @@ def __init__( class PaddleOCRTextConfig(Ernie4_5Config): model_type = "paddleocr_vl_text" - def __init__(self, **super_kwargs): - kwargs["ignore_keys_at_rope_validation"] = {"mrope_section"} - super().__init__() - class PaddleOCRVLConfig(Qwen2VLConfig): sub_configs = {"vision_config": PaddleOCRVisionConfig, "text_config": PaddleOCRTextConfig} @@ -580,7 +576,8 @@ def forward(self, image_features: torch.Tensor, image_grid_thw: torch.Tensor) -> w_block = w // m2 image_feature = image_feature.reshape(t, h_block, m1, w_block, m2, d) - image_feature = image_feature.reshape((t * h_block * w_block), (m1 * m2 * d)) + image_feature = image_feature.transpose(2, 3) + image_feature = image_feature.reshape(t * h_block * w_block, m1 * m2 * d) hidden_states = self.linear_1(image_feature) hidden_states = self.act(hidden_states) diff --git a/src/transformers/models/paddleocr_vl/processing_paddleocr_vl.py b/src/transformers/models/paddleocr_vl/processing_paddleocr_vl.py index 4c89f85ff429..e7bd822d3feb 100644 --- a/src/transformers/models/paddleocr_vl/processing_paddleocr_vl.py +++ b/src/transformers/models/paddleocr_vl/processing_paddleocr_vl.py @@ -56,7 +56,7 @@ class PaddleOCRVLProcessor(ProcessorMixin): tokenizer_class = "AutoTokenizer" def __init__(self, image_processor=None, tokenizer=None, chat_template=None, **kwargs): - self.image_token = "<|IMAGE_PLACEHOLDER|>" if not hasattr(tokenizer, "image_token") else tokenizer.image_token + self.image_token = tokenizer.image_token super().__init__(image_processor, tokenizer, chat_template=chat_template) def __call__( From a7468c58be77455ce3979921face379f14374ee4 Mon Sep 17 00:00:00 2001 From: zhangyue66 Date: Mon, 8 Dec 2025 14:22:39 +0800 Subject: [PATCH 06/19] fix how position_ids work with flash_attn_2 --- .../paddleocr_vl/modeling_paddleocr_vl.py | 47 +++++++- .../paddleocr_vl/modular_paddleocr_vl.py | 113 +++++++++++------- 2 files changed, 115 insertions(+), 45 deletions(-) diff --git a/src/transformers/models/paddleocr_vl/modeling_paddleocr_vl.py b/src/transformers/models/paddleocr_vl/modeling_paddleocr_vl.py index a96494770f49..80e691255425 100644 --- a/src/transformers/models/paddleocr_vl/modeling_paddleocr_vl.py +++ b/src/transformers/models/paddleocr_vl/modeling_paddleocr_vl.py @@ -487,7 +487,15 @@ def forward( ) if position_ids is None: - position_ids = cache_position.unsqueeze(0) + position_ids = cache_position.view(1, 1, -1).expand(3, inputs_embeds.shape[0], -1) + elif position_ids.ndim == 2: + position_ids = position_ids[None, ...].expand(3, position_ids.shape[0], -1) + + if position_ids.ndim == 3 and position_ids.shape[0] == 4: + text_position_ids = position_ids[0] + position_ids = position_ids[1:] + else: + text_position_ids = None causal_mask = create_causal_mask( config=self.config, @@ -495,7 +503,7 @@ def forward( attention_mask=attention_mask, cache_position=cache_position, past_key_values=past_key_values, - position_ids=position_ids, + position_ids=text_position_ids, ) hidden_states = inputs_embeds @@ -506,7 +514,7 @@ def forward( hidden_states, attention_mask=causal_mask, position_embeddings=position_embeddings, - position_ids=position_ids, + position_ids=text_position_ids, past_key_values=past_key_values, use_cache=use_cache, cache_position=cache_position, @@ -1456,6 +1464,7 @@ def prepare_inputs_for_generation( pixel_values=None, pixel_values_videos=None, image_grid_thw=None, + video_grid_thw=None, **kwargs, ): # Overwritten -- in specific circumstances we don't want to forward image inputs to the model @@ -1470,14 +1479,42 @@ def prepare_inputs_for_generation( pixel_values=pixel_values, pixel_values_videos=pixel_values_videos, image_grid_thw=image_grid_thw, + video_grid_thw=video_grid_thw, use_cache=use_cache, **kwargs, ) - model_inputs["position_ids"] = None + # Qwen2-VL position_ids are prepareed with rope_deltas in forward + if position_ids is None: + # Calculate RoPE index once per generation in the pre-fill stage only. + # When compiling, we can't check tensor values thus we check only input length + # It is safe to assume that `length!=1` means we're in pre-fill because compiled + # models currently cannot do asssisted decoding + if model_inputs["cache_position"][0] == 0 or self.model.rope_deltas is None: + vision_positions, rope_deltas = self.model.get_rope_index( + model_inputs.get("input_ids", None), + image_grid_thw=image_grid_thw, + video_grid_thw=video_grid_thw, + attention_mask=attention_mask, + ) + self.model.rope_deltas = rope_deltas + # then use the prev pre-calculated rope-deltas to get the correct position ids + elif "position_ids" in model_inputs: + batch_size, seq_length = model_inputs["position_ids"].shape + device = model_inputs["position_ids"].device + position_ids = torch.arange(seq_length, device=device) + position_ids = position_ids.view(1, 1, -1).expand(3, batch_size, -1) + delta = cache_position[0] + self.model.rope_deltas + delta = delta.repeat_interleave(batch_size // delta.shape[0], dim=0) + vision_positions = position_ids + delta.expand_as(position_ids) + + # Concatenate "text + vision" positions into [4, bs, seq-len] + text_positions = model_inputs["position_ids"][None, ...] + model_inputs["position_ids"] = torch.cat([text_positions, vision_positions], dim=0) - if cache_position[0] != 0: + if model_inputs["cache_position"][0] != 0: model_inputs["pixel_values"] = None + model_inputs["pixel_values_videos"] = None return model_inputs diff --git a/src/transformers/models/paddleocr_vl/modular_paddleocr_vl.py b/src/transformers/models/paddleocr_vl/modular_paddleocr_vl.py index 5f97b7f419f2..b51b96f26a43 100644 --- a/src/transformers/models/paddleocr_vl/modular_paddleocr_vl.py +++ b/src/transformers/models/paddleocr_vl/modular_paddleocr_vl.py @@ -26,7 +26,7 @@ from torch import nn from ...activations import GELUActivation -from ...cache_utils import Cache +from ...cache_utils import Cache, DynamicCache from ...image_processing_utils import BatchFeature from ...image_processing_utils_fast import BaseImageProcessorFast, group_images_by_shape, reorder_images from ...image_transforms import convert_to_rgb, resize, to_channel_dimension_format @@ -43,8 +43,8 @@ make_list_of_images, to_numpy_array, ) -from ...masking_utils import create_bidirectional_mask -from ...modeling_outputs import BaseModelOutput, BaseModelOutputWithPooling +from ...masking_utils import create_bidirectional_mask, create_causal_mask +from ...modeling_outputs import BaseModelOutput, BaseModelOutputWithPast, BaseModelOutputWithPooling from ...modeling_utils import PreTrainedModel from ...models.qwen2_vl.image_processing_qwen2_vl import Qwen2VLImageProcessor from ...processing_utils import ( @@ -54,6 +54,7 @@ ) from ...tokenization_utils_base import PreTokenizedInput, TextInput from ...utils import TensorType, TransformersKwargs, auto_docstring, can_return_tuple, logging, torch_int +from ...utils.generic import check_model_inputs from ..ernie4_5.configuration_ernie4_5 import Ernie4_5Config from ..ernie4_5.modeling_ernie4_5 import ( Ernie4_5DecoderLayer, @@ -639,6 +640,75 @@ class PaddleOCRTextModel(PaddleOCRVLPreTrainedModel, Ernie4_5Model): def __init__(self, config: PaddleOCRTextConfig): super().__init__(config) + @check_model_inputs + @auto_docstring + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[Cache] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + cache_position: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + **kwargs: Unpack[TransformersKwargs], + ) -> BaseModelOutputWithPast: + if (input_ids is None) ^ (inputs_embeds is not None): + raise ValueError("You must specify exactly one of input_ids or inputs_embeds") + + if inputs_embeds is None: + inputs_embeds: torch.Tensor = self.embed_tokens(input_ids) + + if use_cache and past_key_values is None: + past_key_values = DynamicCache(config=self.config) + + if cache_position is None: + past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 + cache_position: torch.Tensor = ( + torch.arange(inputs_embeds.shape[1], device=inputs_embeds.device) + past_seen_tokens + ) + + if position_ids is None: + position_ids = cache_position.view(1, 1, -1).expand(3, inputs_embeds.shape[0], -1) + elif position_ids.ndim == 2: + position_ids = position_ids[None, ...].expand(3, position_ids.shape[0], -1) + + if position_ids.ndim == 3 and position_ids.shape[0] == 4: + text_position_ids = position_ids[0] + position_ids = position_ids[1:] + else: + text_position_ids = None + + causal_mask = create_causal_mask( + config=self.config, + input_embeds=inputs_embeds, + attention_mask=attention_mask, + cache_position=cache_position, + past_key_values=past_key_values, + position_ids=text_position_ids, + ) + + hidden_states = inputs_embeds + position_embeddings = self.rotary_emb(hidden_states, position_ids=position_ids) + + for decoder_layer in self.layers[: self.config.num_hidden_layers]: + hidden_states = decoder_layer( + hidden_states, + attention_mask=causal_mask, + position_embeddings=position_embeddings, + position_ids=text_position_ids, + past_key_values=past_key_values, + use_cache=use_cache, + cache_position=cache_position, + **kwargs, + ) + + hidden_states = self.norm(hidden_states) + return BaseModelOutputWithPast( + last_hidden_state=hidden_states, + past_key_values=past_key_values, + ) + class PaddleOCRVisionModel(PaddleOCRVLPreTrainedModel): config: PaddleOCRVisionConfig @@ -1125,43 +1195,6 @@ def forward( rope_deltas=outputs.rope_deltas, ) - def prepare_inputs_for_generation( - self, - input_ids, - past_key_values=None, - attention_mask=None, - inputs_embeds=None, - cache_position=None, - position_ids=None, - use_cache=True, - pixel_values=None, - pixel_values_videos=None, - image_grid_thw=None, - **kwargs, - ): - # Overwritten -- in specific circumstances we don't want to forward image inputs to the model - - model_inputs = super().prepare_inputs_for_generation( - input_ids, - past_key_values=past_key_values, - attention_mask=attention_mask, - inputs_embeds=inputs_embeds, - cache_position=cache_position, - position_ids=position_ids, - pixel_values=pixel_values, - pixel_values_videos=pixel_values_videos, - image_grid_thw=image_grid_thw, - use_cache=use_cache, - **kwargs, - ) - - model_inputs["position_ids"] = None - - if cache_position[0] != 0: - model_inputs["pixel_values"] = None - - return model_inputs - __all__ = [ "PaddleOCRVLForConditionalGeneration", From af6d108ea95daeb16b3abf5215793b5678cc5372 Mon Sep 17 00:00:00 2001 From: zhangyue66 Date: Wed, 10 Dec 2025 16:08:00 +0800 Subject: [PATCH 07/19] add tests and fix code --- .../image_processing_paddleocr_vl.py | 4 ++-- .../image_processing_paddleocr_vl_fast.py | 4 ++-- .../paddleocr_vl/modeling_paddleocr_vl.py | 14 +++++++----- .../paddleocr_vl/modular_paddleocr_vl.py | 22 +++++++++++-------- 4 files changed, 26 insertions(+), 18 deletions(-) diff --git a/src/transformers/models/paddleocr_vl/image_processing_paddleocr_vl.py b/src/transformers/models/paddleocr_vl/image_processing_paddleocr_vl.py index 788fdcee982a..9f9bc1417959 100644 --- a/src/transformers/models/paddleocr_vl/image_processing_paddleocr_vl.py +++ b/src/transformers/models/paddleocr_vl/image_processing_paddleocr_vl.py @@ -255,8 +255,8 @@ def _preprocess( height, width, factor=patch_size * merge_size, - min_pixels=self.min_pixels, - max_pixels=self.max_pixels, + min_pixels=size["shortest_edge"], + max_pixels=size["longest_edge"], ) image = resize( image, diff --git a/src/transformers/models/paddleocr_vl/image_processing_paddleocr_vl_fast.py b/src/transformers/models/paddleocr_vl/image_processing_paddleocr_vl_fast.py index f7b0aaeaba67..dc846c7165ea 100644 --- a/src/transformers/models/paddleocr_vl/image_processing_paddleocr_vl_fast.py +++ b/src/transformers/models/paddleocr_vl/image_processing_paddleocr_vl_fast.py @@ -135,8 +135,8 @@ def _preprocess( height, width, factor=patch_size * merge_size, - min_pixels=self.min_pixels, - max_pixels=self.max_pixels, + min_pixels=size["shortest_edge"], + max_pixels=size["longest_edge"], ) stacked_images = self.resize( image=stacked_images, diff --git a/src/transformers/models/paddleocr_vl/modeling_paddleocr_vl.py b/src/transformers/models/paddleocr_vl/modeling_paddleocr_vl.py index 80e691255425..ea6f2fb91a3e 100644 --- a/src/transformers/models/paddleocr_vl/modeling_paddleocr_vl.py +++ b/src/transformers/models/paddleocr_vl/modeling_paddleocr_vl.py @@ -439,6 +439,11 @@ class PaddleOCRVLPreTrainedModel(PreTrainedModel): _can_compile_fullgraph = True _supports_attention_backend = True + _can_record_outputs = { + "hidden_states": PaddleOCRDecoderLayer, + "attentions": PaddleOCRAttention, + } + @auto_docstring class PaddleOCRTextModel(PaddleOCRVLPreTrainedModel): @@ -1308,6 +1313,7 @@ def forward( inputs_embeds=inputs_embeds, use_cache=use_cache, return_dict=return_dict, + cache_position=cache_position, **kwargs, ) @@ -1363,8 +1369,6 @@ def forward( inputs_embeds: Optional[torch.FloatTensor] = None, labels: Optional[torch.LongTensor] = None, use_cache: Optional[bool] = None, - output_attentions: Optional[bool] = None, - output_hidden_states: Optional[bool] = None, pixel_values: Optional[torch.Tensor] = None, image_grid_thw: Optional[torch.LongTensor] = None, rope_deltas: Optional[torch.LongTensor] = None, @@ -1428,14 +1432,14 @@ def forward( use_cache=use_cache, return_dict=True, pixel_values=pixel_values, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, + rope_deltas=rope_deltas, cache_position=cache_position, **kwargs, ) hidden_states = outputs.last_hidden_state - logits = self.lm_head(hidden_states) + slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep + logits = self.lm_head(hidden_states[:, slice_indices, :]) loss = None if labels is not None: diff --git a/src/transformers/models/paddleocr_vl/modular_paddleocr_vl.py b/src/transformers/models/paddleocr_vl/modular_paddleocr_vl.py index b51b96f26a43..01db11c4e2b9 100644 --- a/src/transformers/models/paddleocr_vl/modular_paddleocr_vl.py +++ b/src/transformers/models/paddleocr_vl/modular_paddleocr_vl.py @@ -245,8 +245,8 @@ def _preprocess( height, width, factor=patch_size * merge_size, - min_pixels=self.min_pixels, - max_pixels=self.max_pixels, + min_pixels=size["shortest_edge"], + max_pixels=size["longest_edge"], ) image = resize( image, @@ -364,8 +364,8 @@ def _preprocess( height, width, factor=patch_size * merge_size, - min_pixels=self.min_pixels, - max_pixels=self.max_pixels, + min_pixels=size["shortest_edge"], + max_pixels=size["longest_edge"], ) stacked_images = self.resize( image=stacked_images, @@ -635,6 +635,11 @@ class PaddleOCRVLPreTrainedModel(PreTrainedModel): _can_compile_fullgraph = True _supports_attention_backend = True + _can_record_outputs = { + "hidden_states": PaddleOCRDecoderLayer, + "attentions": PaddleOCRAttention, + } + class PaddleOCRTextModel(PaddleOCRVLPreTrainedModel, Ernie4_5Model): def __init__(self, config: PaddleOCRTextConfig): @@ -1073,6 +1078,7 @@ def forward( inputs_embeds=inputs_embeds, use_cache=use_cache, return_dict=return_dict, + cache_position=cache_position, **kwargs, ) @@ -1106,8 +1112,6 @@ def forward( inputs_embeds: Optional[torch.FloatTensor] = None, labels: Optional[torch.LongTensor] = None, use_cache: Optional[bool] = None, - output_attentions: Optional[bool] = None, - output_hidden_states: Optional[bool] = None, pixel_values: Optional[torch.Tensor] = None, image_grid_thw: Optional[torch.LongTensor] = None, rope_deltas: Optional[torch.LongTensor] = None, @@ -1171,14 +1175,14 @@ def forward( use_cache=use_cache, return_dict=True, pixel_values=pixel_values, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, + rope_deltas=rope_deltas, cache_position=cache_position, **kwargs, ) hidden_states = outputs.last_hidden_state - logits = self.lm_head(hidden_states) + slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep + logits = self.lm_head(hidden_states[:, slice_indices, :]) loss = None if labels is not None: From daf3dfca4fd858fc62f0147683db99174230787a Mon Sep 17 00:00:00 2001 From: zhangyue66 Date: Wed, 10 Dec 2025 16:16:52 +0800 Subject: [PATCH 08/19] add model_doc --- docs/source/en/model_doc/paddleocr_vl.md | 180 +++++++++++++++++++++++ 1 file changed, 180 insertions(+) create mode 100644 docs/source/en/model_doc/paddleocr_vl.md diff --git a/docs/source/en/model_doc/paddleocr_vl.md b/docs/source/en/model_doc/paddleocr_vl.md new file mode 100644 index 000000000000..5d315c5407fe --- /dev/null +++ b/docs/source/en/model_doc/paddleocr_vl.md @@ -0,0 +1,180 @@ +*This model was released on 2025 and added to Hugging Face Transformers on 2025.12.10* + +# PaddleOCRVL + +
+PyTorch +FlashAttention +SDPA +
+ +## Overview + +**🔥 Official Website**: [Baidu AI Studio](https://aistudio.baidu.com/paddleocr) | **📝 arXiv**: [Technical Report](https://arxiv.org/pdf/2510.14528) + +**PaddleOCR-VL** is a SOTA and resource-efficient model tailored for document parsing. Its core component is PaddleOCR-VL-0.9B, a compact yet powerful vision-language model (VLM) that integrates a NaViT-style dynamic resolution visual encoder with the ERNIE-4.5-0.3B language model to enable accurate element recognition. This innovative model efficiently supports 109 languages and excels in recognizing complex elements (e.g., text, tables, formulas, and charts), while maintaining minimal resource consumption. Through comprehensive evaluations on widely used public benchmarks and in-house benchmarks, PaddleOCR-VL achieves SOTA performance in both page-level document parsing and element-level recognition. It significantly outperforms existing solutions, exhibits strong competitiveness against top-tier VLMs, and delivers fast inference speeds. These strengths make it highly suitable for practical deployment in real-world scenarios. + +
+ +
+ +### **Core Features** + +1. **Compact yet Powerful VLM Architecture:** We present a novel vision-language model that is specifically designed for resource-efficient inference, achieving outstanding performance in element recognition. By integrating a NaViT-style dynamic high-resolution visual encoder with the lightweight ERNIE-4.5-0.3B language model, we significantly enhance the model’s recognition capabilities and decoding efficiency. This integration maintains high accuracy while reducing computational demands, making it well-suited for efficient and practical document processing applications. + +2. **SOTA Performance on Document Parsing:** PaddleOCR-VL achieves state-of-the-art performance in both page-level document parsing and element-level recognition. It significantly outperforms existing pipeline-based solutions and exhibiting strong competitiveness against leading vision-language models (VLMs) in document parsing. Moreover, it excels in recognizing complex document elements, such as text, tables, formulas, and charts, making it suitable for a wide range of challenging content types, including handwritten text and historical documents. This makes it highly versatile and suitable for a wide range of document types and scenarios. + +3. **Multilingual Support:** PaddleOCR-VL Supports 109 languages, covering major global languages, including but not limited to Chinese, English, Japanese, Latin, and Korean, as well as languages with different scripts and structures, such as Russian (Cyrillic script), Arabic, Hindi (Devanagari script), and Thai. This broad language coverage substantially enhances the applicability of our system to multilingual and globalized document processing scenarios. + +### **Model Architecture** + +
+ +
+ +## Usage examples + +### Single input inference + +The example below demonstrates how to generate text with PaddleOCRVL using [`Pipeline`] or the [`AutoModel`]. + + + + +```py +from transformers import pipeline + +pipe = pipeline("image-text-to-text", model="PaddlePaddle/PaddleOCR-VL", dtype="bfloat16") +messages = [ + { + "role": "user", + "content": [ + {"type": "image", "url": "https://paddle-model-ecology.bj.bcebos.com/paddlex/imgs/demo_image/ocr_demo2.jpg"}, + {"type": "text", "text": "OCR:"}, + ] + } +] +result = pipe(text=messages) +print(result[0]["generated_text"]) +``` + + + + + +```py +from transformers import AutoProcessor, AutoModelForImageTextToText + +model = AutoModelForImageTextToText.from_pretrained("PaddlePaddle/PaddleOCR-VL", dtype="bfloat16") +processor = AutoProcessor.from_pretrained("PaddlePaddle/PaddleOCR-VL") +messages = [ + { + "role": "user", + "content": [ + {"type": "image", "url": "https://paddle-model-ecology.bj.bcebos.com/paddlex/imgs/demo_image/ocr_demo2.jpg"}, + {"type": "text", "text": "OCR:"}, + ] + } +] +inputs = processor.apply_chat_template( + messages, + add_generation_prompt=True, + tokenize=True, + return_dict=True, + return_tensors="pt", +).to(model.device) + +outputs = model.generate(**inputs, max_new_tokens=100) +result = processor.decode(outputs[0][inputs["input_ids"].shape[-1]:-1]) +print(result) +``` + + + + +### Batched inference + +PaddleOCRVL also supports batched inference. We advise users to use `padding_side="left"` when computing batched generation as it leads to more accurate results. Here is how you can do it with PaddleOCRVL using [`Pipeline`] or the [`AutoModel`]: + + + + +```py +from transformers import pipeline + +pipe = pipeline("image-text-to-text", model="PaddlePaddle/PaddleOCR-VL", dtype="bfloat16") +messages = [ + { + "role": "user", + "content": [ + {"type": "image", "url": "https://paddle-model-ecology.bj.bcebos.com/paddlex/imgs/demo_image/ocr_demo2.jpg"}, + {"type": "text", "text": "OCR:"}, + ] + } +] +result = pipe(text=[messages, messages]) +print(result[0][0]["generated_text"]) +print(result[1][0]["generated_text"]) +``` + + + + + +```py +from transformers import AutoProcessor, AutoModelForImageTextToText + +model = AutoModelForImageTextToText.from_pretrained("PaddlePaddle/PaddleOCR-VL", dtype="bfloat16") +processor = AutoProcessor.from_pretrained("PaddlePaddle/PaddleOCR-VL") +messages = [ + { + "role": "user", + "content": [ + {"type": "image", "url": "https://paddle-model-ecology.bj.bcebos.com/paddlex/imgs/demo_image/ocr_demo2.jpg"}, + {"type": "text", "text": "OCR:"}, + ] + } +] +batch_messages = [messages, messages] +inputs = processor.apply_chat_template( + batch_messages, + add_generation_prompt=True, + tokenize=True, + return_dict=True, + return_tensors="pt", + padding=True, + padding_side='left', +).to(model.device) + +generated_ids = model.generate(**inputs, max_new_tokens=100) +generated_ids_trimmed = [out_ids[len(in_ids) :] for in_ids, out_ids in zip(inputs.input_ids, generated_ids)] +result = processor.batch_decode(generated_ids_trimmed, skip_special_tokens=True, clean_up_tokenization_spaces=False) +print(result) +``` + + + + +### Using Flash Attention 2 + +Flash Attention 2 is an even faster, optimized version of the previous optimization, please refer to the [FlashAttention](https://huggingface.co/docs/transformers/perf_infer_gpu_one#flashattention). + +For example: + +```shell +pip install flash-attn --no-build-isolation +``` + +```python +from transformers import AutoModelForImageTextToText +model = AutoModelForImageTextToText.from_pretrained("PaddlePaddle/PaddleOCR-VL", dtype="bfloat16", attn_implementation="flash_attention_2") +``` + +## PaddleOCRVLConfig + +[[autodoc]] PaddleOCRVLConfig + +## PaddleOCRVLForConditionalGeneration + +[[autodoc]] PaddleOCRVLForConditionalGeneration + - forward From 4a3773416cfad995f8132c59b2c4a5f64dafd722 Mon Sep 17 00:00:00 2001 From: zhangyue66 Date: Wed, 10 Dec 2025 18:02:38 +0800 Subject: [PATCH 09/19] update model_doc --- docs/source/en/_toctree.yml | 2 + docs/source/en/model_doc/paddleocr_vl.md | 25 +++++++-- src/transformers/conversion_mapping.py | 2 +- .../configuration_paddleocr_vl.py | 29 +++++------ .../paddleocr_vl/modular_paddleocr_vl.py | 51 +++++++++++++++++-- 5 files changed, 85 insertions(+), 24 deletions(-) diff --git a/docs/source/en/_toctree.yml b/docs/source/en/_toctree.yml index 5d95fc368285..e6b623935880 100644 --- a/docs/source/en/_toctree.yml +++ b/docs/source/en/_toctree.yml @@ -1119,6 +1119,8 @@ title: OWL-ViT - local: model_doc/owlv2 title: OWLv2 + - local: model_doc/paddleocr_vl + title: PaddleOCRVL - local: model_doc/paligemma title: PaliGemma - local: model_doc/perceiver diff --git a/docs/source/en/model_doc/paddleocr_vl.md b/docs/source/en/model_doc/paddleocr_vl.md index 5d315c5407fe..d8be52f2d82d 100644 --- a/docs/source/en/model_doc/paddleocr_vl.md +++ b/docs/source/en/model_doc/paddleocr_vl.md @@ -1,6 +1,6 @@ -*This model was released on 2025 and added to Hugging Face Transformers on 2025.12.10* +*This model was released on 2025.10.16 and added to Hugging Face Transformers on 2025.12.10* -# PaddleOCRVL +# PaddleOCR-VL
PyTorch @@ -10,7 +10,9 @@ ## Overview -**🔥 Official Website**: [Baidu AI Studio](https://aistudio.baidu.com/paddleocr) | **📝 arXiv**: [Technical Report](https://arxiv.org/pdf/2510.14528) +**Huggingface Hub**: [PaddleOCR-VL](https://huggingface.co/collections/PaddlePaddle/paddleocr-vl) | **Github Repo**: [PaddleOCR](https://github.com/PaddlePaddle/PaddleOCR) + +**Official Website**: [Baidu AI Studio](https://aistudio.baidu.com/paddleocr) | **arXiv**: [Technical Report](https://arxiv.org/pdf/2510.14528) **PaddleOCR-VL** is a SOTA and resource-efficient model tailored for document parsing. Its core component is PaddleOCR-VL-0.9B, a compact yet powerful vision-language model (VLM) that integrates a NaViT-style dynamic resolution visual encoder with the ERNIE-4.5-0.3B language model to enable accurate element recognition. This innovative model efficiently supports 109 languages and excels in recognizing complex elements (e.g., text, tables, formulas, and charts), while maintaining minimal resource consumption. Through comprehensive evaluations on widely used public benchmarks and in-house benchmarks, PaddleOCR-VL achieves SOTA performance in both page-level document parsing and element-level recognition. It significantly outperforms existing solutions, exhibits strong competitiveness against top-tier VLMs, and delivers fast inference speeds. These strengths make it highly suitable for practical deployment in real-world scenarios. @@ -32,7 +34,22 @@
-## Usage examples +## Usage + +### Usage tips + +> [!IMPORTANT] +> We currently recommend using the [PaddleOCR official method for inference](https://www.paddleocr.ai/latest/en/version3.x/pipeline_usage/PaddleOCR-VL.html), as it is faster and supports page-level document parsing. +> The example code below only supports element-level recognition. + +We have four types of element-level recognition: + +- Text recognition, indicated by the prompt `OCR:`. +- Formula recognition, indicated by the prompt `Formula Recognition:`. +- Table recognition, indicated by the prompt `Table Recognition:`. +- Chart recognition, indicated by the prompt `Chart Recognition:`. + +The following examples are all based on text recognition, with the prompt `OCR:`. ### Single input inference diff --git a/src/transformers/conversion_mapping.py b/src/transformers/conversion_mapping.py index 23366d9dce9d..acb99b99f83e 100644 --- a/src/transformers/conversion_mapping.py +++ b/src/transformers/conversion_mapping.py @@ -213,7 +213,7 @@ def get_checkpoint_conversion_mapping(model_type): "sam3", "sam3_tracker", "sam3_tracker_video", - "paddleocrvl" + "paddleocrvl", ] diff --git a/src/transformers/models/paddleocr_vl/configuration_paddleocr_vl.py b/src/transformers/models/paddleocr_vl/configuration_paddleocr_vl.py index cc8ff58dc97f..ac06f99c8284 100644 --- a/src/transformers/models/paddleocr_vl/configuration_paddleocr_vl.py +++ b/src/transformers/models/paddleocr_vl/configuration_paddleocr_vl.py @@ -256,27 +256,25 @@ def __init__( class PaddleOCRVLConfig(PreTrainedConfig): r""" - This is the configuration class to store the configuration of a [`PaddleOCRVLModel`]. It is used to instantiate a - Qwen2-VL model according to the specified arguments, defining the model architecture. Instantiating a configuration + This is the configuration class to store the configuration of a [`PaddleOCRVLForConditionalGeneration`]. It is used to instantiate a + PaddleOCRVL model according to the specified arguments, defining the model architecture. Instantiating a configuration with the defaults will yield a similar configuration to that of - Qwen2-VL-7B-Instruct [Qwen/Qwen2-VL-7B-Instruct](https://huggingface.co/Qwen/Qwen2-VL-7B-Instruct). + PaddleOCRVL [PaddlePaddle/PaddleOCR-VL](https://huggingface.co/PaddlePaddle/PaddleOCR-VL). Configuration objects inherit from [`PreTrainedConfig`] and can be used to control the model outputs. Read the documentation from [`PreTrainedConfig`] for more information. Args: - text_config (`Union[PreTrainedConfig, dict]`, *optional*, defaults to `PaddleOCRVLTextConfig`): + text_config (`Union[PreTrainedConfig, dict]`, *optional*, defaults to `PaddleOCRTextConfig`): The config object or dictionary of the text backbone. - vision_config (`Union[PreTrainedConfig, dict]`, *optional*, defaults to `PaddleOCRVLVisionConfig`): + vision_config (`Union[PreTrainedConfig, dict]`, *optional*, defaults to `PaddleOCRVisionConfig`): The config object or dictionary of the vision backbone. - image_token_id (`int`, *optional*, defaults to 151655): + image_token_id (`int`, *optional*, defaults to 100295): The image token index to encode the image prompt. - video_token_id (`int`, *optional*, defaults to 151656): - The video token index to encode the image prompt. - vision_start_token_id (`int`, *optional*, defaults to 151652): + vision_start_token_id (`int`, *optional*, defaults to 101305): The token index to denote start of vision input. - vision_end_token_id (`int`, *optional*, defaults to 151653): + vision_end_token_id (`int`, *optional*, defaults to 101306): The token index to denote end of vision input. ```python @@ -285,7 +283,7 @@ class PaddleOCRVLConfig(PreTrainedConfig): >>> # Initializing a PaddleOCRVL style configuration >>> configuration = PaddleOCRVLConfig() - >>> # Initializing a model from the Qwen2-VL-7B style configuration + >>> # Initializing a model from the PaddleOCRVL style configuration >>> model = PaddleOCRVLForConditionalGeneration(configuration) >>> # Accessing the model configuration @@ -293,6 +291,7 @@ class PaddleOCRVLConfig(PreTrainedConfig): ```""" model_type = "paddleocr_vl" + sub_configs = {"vision_config": PaddleOCRVisionConfig, "text_config": PaddleOCRTextConfig} keys_to_ignore_at_inference = ["past_key_values"] @@ -300,10 +299,9 @@ def __init__( self, text_config=None, vision_config=None, - image_token_id=151655, - video_token_id=151656, - vision_start_token_id=151652, - vision_end_token_id=151653, + image_token_id=100295, + vision_start_token_id=101305, + vision_end_token_id=101306, **kwargs, ): if isinstance(vision_config, dict): @@ -322,7 +320,6 @@ def __init__( self.text_config = self.sub_configs["text_config"](**text_config) self.image_token_id = image_token_id - self.video_token_id = video_token_id self.vision_start_token_id = vision_start_token_id self.vision_end_token_id = vision_end_token_id diff --git a/src/transformers/models/paddleocr_vl/modular_paddleocr_vl.py b/src/transformers/models/paddleocr_vl/modular_paddleocr_vl.py index 01db11c4e2b9..ec68805a4b4a 100644 --- a/src/transformers/models/paddleocr_vl/modular_paddleocr_vl.py +++ b/src/transformers/models/paddleocr_vl/modular_paddleocr_vl.py @@ -411,9 +411,7 @@ def _preprocess( patch_size, ) patches = patches.permute(0, 1, 4, 6, 3, 2, 5, 7) - flatten_patches = patches.reshape( - batch_size, grid_t * grid_h * grid_w, channel, patch_size, patch_size - ) + flatten_patches = patches.reshape(batch_size, grid_t * grid_h * grid_w, channel, patch_size, patch_size) processed_images_grouped[shape] = flatten_patches processed_grids[shape] = [[grid_t, grid_h, grid_w]] * batch_size @@ -549,8 +547,55 @@ class PaddleOCRTextConfig(Ernie4_5Config): class PaddleOCRVLConfig(Qwen2VLConfig): + r""" + This is the configuration class to store the configuration of a [`PaddleOCRVLForConditionalGeneration`]. It is used to instantiate a + PaddleOCRVL model according to the specified arguments, defining the model architecture. Instantiating a configuration + with the defaults will yield a similar configuration to that of + PaddleOCRVL [PaddlePaddle/PaddleOCR-VL](https://huggingface.co/PaddlePaddle/PaddleOCR-VL). + + Configuration objects inherit from [`PreTrainedConfig`] and can be used to control the model outputs. Read the + documentation from [`PreTrainedConfig`] for more information. + + + Args: + text_config (`Union[PreTrainedConfig, dict]`, *optional*, defaults to `PaddleOCRTextConfig`): + The config object or dictionary of the text backbone. + vision_config (`Union[PreTrainedConfig, dict]`, *optional*, defaults to `PaddleOCRVisionConfig`): + The config object or dictionary of the vision backbone. + image_token_id (`int`, *optional*, defaults to 100295): + The image token index to encode the image prompt. + vision_start_token_id (`int`, *optional*, defaults to 101305): + The token index to denote start of vision input. + vision_end_token_id (`int`, *optional*, defaults to 101306): + The token index to denote end of vision input. + + ```python + >>> from transformers import PaddleOCRVLForConditionalGeneration, PaddleOCRVLConfig + + >>> # Initializing a PaddleOCRVL style configuration + >>> configuration = PaddleOCRVLConfig() + + >>> # Initializing a model from the PaddleOCRVL style configuration + >>> model = PaddleOCRVLForConditionalGeneration(configuration) + + >>> # Accessing the model configuration + >>> configuration = model.config + ```""" + sub_configs = {"vision_config": PaddleOCRVisionConfig, "text_config": PaddleOCRTextConfig} + def __init__( + self, + text_config=None, + vision_config=None, + image_token_id=100295, + vision_start_token_id=101305, + vision_end_token_id=101306, + **kwargs, + ): + super().__init__() + del self.video_token_id + class PaddleOCRProjector(nn.Module): def __init__(self, config: PaddleOCRVLConfig): From 0aecf6f5e88708c015a7c1e4cf6b60c6c949a0f7 Mon Sep 17 00:00:00 2001 From: zhangyue66 Date: Wed, 10 Dec 2025 18:56:26 +0800 Subject: [PATCH 10/19] fix ci --- .../configuration_paddleocr_vl.py | 37 ++++++------ .../paddleocr_vl/modeling_paddleocr_vl.py | 4 +- .../paddleocr_vl/modular_paddleocr_vl.py | 60 ++++++++++++++++++- 3 files changed, 80 insertions(+), 21 deletions(-) diff --git a/src/transformers/models/paddleocr_vl/configuration_paddleocr_vl.py b/src/transformers/models/paddleocr_vl/configuration_paddleocr_vl.py index ac06f99c8284..92cef0f173e9 100644 --- a/src/transformers/models/paddleocr_vl/configuration_paddleocr_vl.py +++ b/src/transformers/models/paddleocr_vl/configuration_paddleocr_vl.py @@ -33,27 +33,27 @@ class PaddleOCRVisionConfig(PreTrainedConfig): r""" This is the configuration class to store the configuration of a [`PaddleOCRVisionModel`]. It is used to instantiate a - PaddleOCR vision encoder according to the specified arguments, defining the model architecture. Instantiating a - configuration with the defaults will yield a similar configuration to that of the vision encoder of the PaddleOCR - [google/paddle_o_c_r-base-patch16-224](https://huggingface.co/google/paddle_o_c_r-base-patch16-224) architecture. + PaddleOCRVL vision encoder according to the specified arguments, defining the model architecture. Instantiating a + configuration with the defaults will yield a similar configuration to that of the vision encoder of the PaddleOCRVL + [PaddlePaddle/PaddleOCRVL](https://huggingface.co/PaddlePaddle/PaddleOCR-VL) architecture. Configuration objects inherit from [`PreTrainedConfig`] and can be used to control the model outputs. Read the documentation from [`PreTrainedConfig`] for more information. Args: - hidden_size (`int`, *optional*, defaults to 768): + hidden_size (`int`, *optional*, defaults to 1152): Dimensionality of the encoder layers and the pooler layer. - intermediate_size (`int`, *optional*, defaults to 3072): + intermediate_size (`int`, *optional*, defaults to 4304): Dimensionality of the "intermediate" (i.e., feed-forward) layer in the Transformer encoder. - num_hidden_layers (`int`, *optional*, defaults to 12): + num_hidden_layers (`int`, *optional*, defaults to 27): Number of hidden layers in the Transformer encoder. - num_attention_heads (`int`, *optional*, defaults to 12): + num_attention_heads (`int`, *optional*, defaults to 16): Number of attention heads for each attention layer in the Transformer encoder. num_channels (`int`, *optional*, defaults to 3): Number of channels in the input images. - image_size (`int`, *optional*, defaults to 224): + image_size (`int`, *optional*, defaults to 384): The size (resolution) of each image. - patch_size (`int`, *optional*, defaults to 16): + patch_size (`int`, *optional*, defaults to 14): The size (resolution) of each patch. hidden_act (`str` or `function`, *optional*, defaults to `"gelu_pytorch_tanh"`): The non-linear activation function (function or string) in the encoder and pooler. If string, `"gelu"`, @@ -68,28 +68,29 @@ class PaddleOCRVisionConfig(PreTrainedConfig): ```python >>> from transformers import PaddleOCRVisionConfig, PaddleOCRVisionModel - >>> # Initializing a PaddleOCRVisionConfig with google/paddle_o_c_r-base-patch16-224 style configuration + >>> # Initializing a PaddleOCRVisionConfig with PaddlePaddle/PaddleOCR-VL style configuration >>> configuration = PaddleOCRVisionConfig() - >>> # Initializing a PaddleOCRVisionModel (with random weights) from the google/paddle_o_c_r-base-patch16-224 style configuration + >>> # Initializing a PaddleOCRVisionModel (with random weights) from the PaddlePaddle/PaddleOCR-VL style configuration >>> model = PaddleOCRVisionModel(configuration) >>> # Accessing the model configuration >>> configuration = model.config - ```""" + ``` + """ model_type = "paddleocr_vl_vision" base_config_key = "vision_config" def __init__( self, - hidden_size=768, - intermediate_size=3072, - num_hidden_layers=12, - num_attention_heads=12, + hidden_size=1152, + intermediate_size=4304, + num_hidden_layers=27, + num_attention_heads=16, num_channels=3, - image_size=224, - patch_size=16, + image_size=384, + patch_size=14, hidden_act="gelu_pytorch_tanh", layer_norm_eps=1e-6, attention_dropout=0.0, diff --git a/src/transformers/models/paddleocr_vl/modeling_paddleocr_vl.py b/src/transformers/models/paddleocr_vl/modeling_paddleocr_vl.py index ea6f2fb91a3e..0c95216abbc5 100644 --- a/src/transformers/models/paddleocr_vl/modeling_paddleocr_vl.py +++ b/src/transformers/models/paddleocr_vl/modeling_paddleocr_vl.py @@ -42,7 +42,7 @@ from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel from ...processing_utils import Unpack from ...utils import TransformersKwargs, auto_docstring, can_return_tuple, logging, torch_int -from ...utils.generic import check_model_inputs +from ...utils.generic import check_model_inputs, maybe_autocast from .configuration_paddleocr_vl import PaddleOCRTextConfig, PaddleOCRVisionConfig, PaddleOCRVLConfig @@ -156,7 +156,7 @@ def forward(self, x, position_ids): position_ids_expanded = position_ids[:, :, None, :].float() # shape (3, bs, 1, positions) device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu" - with torch.autocast(device_type=device_type, enabled=False): # Force float32 + with maybe_autocast(device_type=device_type, enabled=False): # Force float32 freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(2, 3) emb = torch.cat((freqs, freqs), dim=-1) cos = emb.cos() * self.attention_scaling diff --git a/src/transformers/models/paddleocr_vl/modular_paddleocr_vl.py b/src/transformers/models/paddleocr_vl/modular_paddleocr_vl.py index ec68805a4b4a..419bae7da2d2 100644 --- a/src/transformers/models/paddleocr_vl/modular_paddleocr_vl.py +++ b/src/transformers/models/paddleocr_vl/modular_paddleocr_vl.py @@ -528,14 +528,72 @@ def __call__( class PaddleOCRVisionConfig(SiglipVisionConfig): + r""" + This is the configuration class to store the configuration of a [`PaddleOCRVisionModel`]. It is used to instantiate a + PaddleOCRVL vision encoder according to the specified arguments, defining the model architecture. Instantiating a + configuration with the defaults will yield a similar configuration to that of the vision encoder of the PaddleOCRVL + [PaddlePaddle/PaddleOCRVL](https://huggingface.co/PaddlePaddle/PaddleOCR-VL) architecture. + + Configuration objects inherit from [`PreTrainedConfig`] and can be used to control the model outputs. Read the + documentation from [`PreTrainedConfig`] for more information. + + Args: + hidden_size (`int`, *optional*, defaults to 1152): + Dimensionality of the encoder layers and the pooler layer. + intermediate_size (`int`, *optional*, defaults to 4304): + Dimensionality of the "intermediate" (i.e., feed-forward) layer in the Transformer encoder. + num_hidden_layers (`int`, *optional*, defaults to 27): + Number of hidden layers in the Transformer encoder. + num_attention_heads (`int`, *optional*, defaults to 16): + Number of attention heads for each attention layer in the Transformer encoder. + num_channels (`int`, *optional*, defaults to 3): + Number of channels in the input images. + image_size (`int`, *optional*, defaults to 384): + The size (resolution) of each image. + patch_size (`int`, *optional*, defaults to 14): + The size (resolution) of each patch. + hidden_act (`str` or `function`, *optional*, defaults to `"gelu_pytorch_tanh"`): + The non-linear activation function (function or string) in the encoder and pooler. If string, `"gelu"`, + `"relu"`, `"selu"` and `"gelu_new"` `"quick_gelu"` are supported. + layer_norm_eps (`float`, *optional*, defaults to 1e-06): + The epsilon used by the layer normalization layers. + attention_dropout (`float`, *optional*, defaults to 0.0): + The dropout ratio for the attention probabilities. + + Example: + + ```python + >>> from transformers import PaddleOCRVisionConfig, PaddleOCRVisionModel + + >>> # Initializing a PaddleOCRVisionConfig with PaddlePaddle/PaddleOCR-VL style configuration + >>> configuration = PaddleOCRVisionConfig() + + >>> # Initializing a PaddleOCRVisionModel (with random weights) from the PaddlePaddle/PaddleOCR-VL style configuration + >>> model = PaddleOCRVisionModel(configuration) + + >>> # Accessing the model configuration + >>> configuration = model.config + ``` + """ + model_type = "paddleocr_vl_vision" base_config_key = "vision_config" def __init__( self, + hidden_size=1152, + intermediate_size=4304, + num_hidden_layers=27, + num_attention_heads=16, + num_channels=3, + image_size=384, + patch_size=14, + hidden_act="gelu_pytorch_tanh", + layer_norm_eps=1e-6, + attention_dropout=0.0, spatial_merge_size=2, temporal_patch_size=2, - **super_kwargs, + **kwargs, ): super().__init__() self.spatial_merge_size = spatial_merge_size From 2665f9c416c780c9b992f92c54b1ce7c6a3a5726 Mon Sep 17 00:00:00 2001 From: zhangyue66 Date: Wed, 10 Dec 2025 19:28:22 +0800 Subject: [PATCH 11/19] update docstring --- .../configuration_paddleocr_vl.py | 4 +++ .../image_processing_paddleocr_vl.py | 22 ++++++++++++ .../image_processing_paddleocr_vl_fast.py | 9 +++++ .../paddleocr_vl/modular_paddleocr_vl.py | 35 +++++++++++++++++++ 4 files changed, 70 insertions(+) diff --git a/src/transformers/models/paddleocr_vl/configuration_paddleocr_vl.py b/src/transformers/models/paddleocr_vl/configuration_paddleocr_vl.py index 92cef0f173e9..9aceafedb7a3 100644 --- a/src/transformers/models/paddleocr_vl/configuration_paddleocr_vl.py +++ b/src/transformers/models/paddleocr_vl/configuration_paddleocr_vl.py @@ -62,6 +62,10 @@ class PaddleOCRVisionConfig(PreTrainedConfig): The epsilon used by the layer normalization layers. attention_dropout (`float`, *optional*, defaults to 0.0): The dropout ratio for the attention probabilities. + spatial_merge_size (`int`, *optional*, defaults to 2): + The size used for merging spatial dimensions. + temporal_patch_size (`int`, *optional*, defaults to 2): + The size used for patches along the temporal dimension. Example: diff --git a/src/transformers/models/paddleocr_vl/image_processing_paddleocr_vl.py b/src/transformers/models/paddleocr_vl/image_processing_paddleocr_vl.py index 9f9bc1417959..101668ea0335 100644 --- a/src/transformers/models/paddleocr_vl/image_processing_paddleocr_vl.py +++ b/src/transformers/models/paddleocr_vl/image_processing_paddleocr_vl.py @@ -110,12 +110,34 @@ class PaddleOCRVLImageProcessor(BaseImageProcessor): Constructs a PaddleOCRVL image processor that dynamically resizes images based on the original images. Args: + do_resize (`bool`, *optional*, defaults to `True`): + Whether to resize the image's (height, width) dimensions. + size (`dict[str, int]`, *optional*): + Size of the image after resizing. `shortest_edge` and `longest_edge` keys must be present. + resample (`PILImageResampling`, *optional*, defaults to `Resampling.BICUBIC`): + Resampling filter to use when resizing the image. + do_rescale (`bool`, *optional*, defaults to `True`): + Whether to rescale the image by the specified scale `rescale_factor`. + rescale_factor (`int` or `float`, *optional*, defaults to `1/255`): + Scale factor to use if rescaling the image. + do_normalize (`bool`, *optional*, defaults to `True`): + Whether to normalize the image. + image_mean (`float` or `list[float]`, *optional*): + Mean to use if normalizing the image. This is a float or list of floats for each channel in the image. + image_std (`float` or `list[float]`, *optional*): + Standard deviation to use if normalizing the image. This is a float or list of floats for each channel in the image. + do_convert_rgb (`bool`, *optional*, defaults to `True`): + Whether to convert the image to RGB. min_pixels (`int`, *optional*, defaults to `384 * 384`): The min pixels of the image to resize the image. max_pixels (`int`, *optional*, defaults to `1536 * 1536`): The max pixels of the image to resize the image. + patch_size (`int`, *optional*, defaults to 14): + The spatial patch size of the vision encoder. temporal_patch_size (`int`, *optional*, defaults to 1): The temporal patch size of the vision encoder. + merge_size (`int`, *optional*, defaults to 2): + The merge size of the vision encoder to llm encoder. """ model_input_names = [ diff --git a/src/transformers/models/paddleocr_vl/image_processing_paddleocr_vl_fast.py b/src/transformers/models/paddleocr_vl/image_processing_paddleocr_vl_fast.py index dc846c7165ea..cf6a65889b97 100644 --- a/src/transformers/models/paddleocr_vl/image_processing_paddleocr_vl_fast.py +++ b/src/transformers/models/paddleocr_vl/image_processing_paddleocr_vl_fast.py @@ -79,6 +79,8 @@ def __init__( image_mean: Optional[Union[float, list[float]]] = None, image_std: Optional[Union[float, list[float]]] = None, do_convert_rgb: bool = True, + min_pixels: int = 384 * 384, + max_pixels: int = 1536 * 1536, patch_size: int = 14, temporal_patch_size: int = 1, merge_size: int = 2, @@ -89,9 +91,15 @@ def __init__( raise ValueError("size must contain 'shortest_edge' and 'longest_edge' keys.") else: size = {"shortest_edge": 384 * 384, "longest_edge": 1536 * 1536} + # backward compatibility: override size with min_pixels and max_pixels if they are provided + if min_pixels is not None: + size["shortest_edge"] = min_pixels + if max_pixels is not None: + size["longest_edge"] = max_pixels self.min_pixels = size["shortest_edge"] self.max_pixels = size["longest_edge"] self.size = size + self.do_resize = do_resize self.resample = resample self.do_rescale = do_rescale @@ -99,6 +107,7 @@ def __init__( self.do_normalize = do_normalize self.image_mean = image_mean if image_mean is not None else OPENAI_CLIP_MEAN self.image_std = image_std if image_std is not None else OPENAI_CLIP_STD + self.patch_size = patch_size self.temporal_patch_size = temporal_patch_size self.merge_size = merge_size diff --git a/src/transformers/models/paddleocr_vl/modular_paddleocr_vl.py b/src/transformers/models/paddleocr_vl/modular_paddleocr_vl.py index 419bae7da2d2..827f082e4100 100644 --- a/src/transformers/models/paddleocr_vl/modular_paddleocr_vl.py +++ b/src/transformers/models/paddleocr_vl/modular_paddleocr_vl.py @@ -126,12 +126,34 @@ class PaddleOCRVLImageProcessor(Qwen2VLImageProcessor): Constructs a PaddleOCRVL image processor that dynamically resizes images based on the original images. Args: + do_resize (`bool`, *optional*, defaults to `True`): + Whether to resize the image's (height, width) dimensions. + size (`dict[str, int]`, *optional*): + Size of the image after resizing. `shortest_edge` and `longest_edge` keys must be present. + resample (`PILImageResampling`, *optional*, defaults to `Resampling.BICUBIC`): + Resampling filter to use when resizing the image. + do_rescale (`bool`, *optional*, defaults to `True`): + Whether to rescale the image by the specified scale `rescale_factor`. + rescale_factor (`int` or `float`, *optional*, defaults to `1/255`): + Scale factor to use if rescaling the image. + do_normalize (`bool`, *optional*, defaults to `True`): + Whether to normalize the image. + image_mean (`float` or `list[float]`, *optional*): + Mean to use if normalizing the image. This is a float or list of floats for each channel in the image. + image_std (`float` or `list[float]`, *optional*): + Standard deviation to use if normalizing the image. This is a float or list of floats for each channel in the image. + do_convert_rgb (`bool`, *optional*, defaults to `True`): + Whether to convert the image to RGB. min_pixels (`int`, *optional*, defaults to `384 * 384`): The min pixels of the image to resize the image. max_pixels (`int`, *optional*, defaults to `1536 * 1536`): The max pixels of the image to resize the image. + patch_size (`int`, *optional*, defaults to 14): + The spatial patch size of the vision encoder. temporal_patch_size (`int`, *optional*, defaults to 1): The temporal patch size of the vision encoder. + merge_size (`int`, *optional*, defaults to 2): + The merge size of the vision encoder to llm encoder. """ model_input_names = [ @@ -308,6 +330,8 @@ def __init__( image_mean: Optional[Union[float, list[float]]] = None, image_std: Optional[Union[float, list[float]]] = None, do_convert_rgb: bool = True, + min_pixels: int = 384 * 384, + max_pixels: int = 1536 * 1536, patch_size: int = 14, temporal_patch_size: int = 1, merge_size: int = 2, @@ -318,9 +342,15 @@ def __init__( raise ValueError("size must contain 'shortest_edge' and 'longest_edge' keys.") else: size = {"shortest_edge": 384 * 384, "longest_edge": 1536 * 1536} + # backward compatibility: override size with min_pixels and max_pixels if they are provided + if min_pixels is not None: + size["shortest_edge"] = min_pixels + if max_pixels is not None: + size["longest_edge"] = max_pixels self.min_pixels = size["shortest_edge"] self.max_pixels = size["longest_edge"] self.size = size + self.do_resize = do_resize self.resample = resample self.do_rescale = do_rescale @@ -328,6 +358,7 @@ def __init__( self.do_normalize = do_normalize self.image_mean = image_mean if image_mean is not None else OPENAI_CLIP_MEAN self.image_std = image_std if image_std is not None else OPENAI_CLIP_STD + self.patch_size = patch_size self.temporal_patch_size = temporal_patch_size self.merge_size = merge_size @@ -559,6 +590,10 @@ class PaddleOCRVisionConfig(SiglipVisionConfig): The epsilon used by the layer normalization layers. attention_dropout (`float`, *optional*, defaults to 0.0): The dropout ratio for the attention probabilities. + spatial_merge_size (`int`, *optional*, defaults to 2): + The size used for merging spatial dimensions. + temporal_patch_size (`int`, *optional*, defaults to 2): + The size used for patches along the temporal dimension. Example: From da318471a24c700878edd1fa0eac955c40235566 Mon Sep 17 00:00:00 2001 From: zhangyue66 Date: Wed, 10 Dec 2025 19:38:27 +0800 Subject: [PATCH 12/19] add tests --- .../paddleocr_vl/modeling_paddleocr_vl.py | 9 +- .../paddleocr_vl/modular_paddleocr_vl.py | 3 + tests/models/paddleocr_vl/__init__.py | 0 .../test_modeling_paddleocr_vl.py | 515 ++++++++++++++++++ 4 files changed, 526 insertions(+), 1 deletion(-) create mode 100644 tests/models/paddleocr_vl/__init__.py create mode 100644 tests/models/paddleocr_vl/test_modeling_paddleocr_vl.py diff --git a/src/transformers/models/paddleocr_vl/modeling_paddleocr_vl.py b/src/transformers/models/paddleocr_vl/modeling_paddleocr_vl.py index 0c95216abbc5..dbdedbf26257 100644 --- a/src/transformers/models/paddleocr_vl/modeling_paddleocr_vl.py +++ b/src/transformers/models/paddleocr_vl/modeling_paddleocr_vl.py @@ -1661,4 +1661,11 @@ def _expand_dict_for_generation(dict_to_expand): return input_ids, model_kwargs -__all__ = ["PaddleOCRVLForConditionalGeneration", "PaddleOCRTextModel", "PaddleOCRVisionModel"] +__all__ = [ + "PaddleOCRVLForConditionalGeneration", + "PaddleOCRVLModel", + "PaddleOCRVLPreTrainedModel", + "PaddleOCRVisionTransformer", + "PaddleOCRTextModel", + "PaddleOCRVisionModel", +] diff --git a/src/transformers/models/paddleocr_vl/modular_paddleocr_vl.py b/src/transformers/models/paddleocr_vl/modular_paddleocr_vl.py index 827f082e4100..e1e30452f35b 100644 --- a/src/transformers/models/paddleocr_vl/modular_paddleocr_vl.py +++ b/src/transformers/models/paddleocr_vl/modular_paddleocr_vl.py @@ -1340,6 +1340,9 @@ def forward( __all__ = [ "PaddleOCRVLForConditionalGeneration", + "PaddleOCRVLModel", + "PaddleOCRVLPreTrainedModel", + "PaddleOCRVisionTransformer", "PaddleOCRVLConfig", "PaddleOCRTextModel", "PaddleOCRVisionModel", diff --git a/tests/models/paddleocr_vl/__init__.py b/tests/models/paddleocr_vl/__init__.py new file mode 100644 index 000000000000..e69de29bb2d1 diff --git a/tests/models/paddleocr_vl/test_modeling_paddleocr_vl.py b/tests/models/paddleocr_vl/test_modeling_paddleocr_vl.py new file mode 100644 index 000000000000..ff10d96d9fb9 --- /dev/null +++ b/tests/models/paddleocr_vl/test_modeling_paddleocr_vl.py @@ -0,0 +1,515 @@ +# Copyright 2025 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Testing suite for the PaddleOCRVL model.""" + +import copy +import gc +import unittest + +import pytest +from parameterized import parameterized + +from transformers import ( + AutoProcessor, + PaddleOCRVLConfig, + PaddleOCRVLForConditionalGeneration, + is_torch_available, +) +from transformers.testing_utils import ( + backend_empty_cache, + require_flash_attn, + require_torch, + require_torch_accelerator, + slow, + torch_device, +) + +from ...generation.test_utils import GenerationTesterMixin +from ...test_configuration_common import ConfigTester +from ...test_modeling_common import ( + ModelTesterMixin, + floats_tensor, + ids_tensor, +) +from ...test_pipeline_mixin import PipelineTesterMixin + + +if is_torch_available(): + import torch + + +class PaddleOCRVLVisionText2TextModelTester: + def __init__( + self, + parent, + batch_size=7, + seq_length=13, + num_channels=3, + image_height=252, + image_width=616, + text_config={ + "pad_token_id": 0, + "bos_token_id": 1, + "eos_token_id": 2, + "vocab_size": 103424, + "head_dim": 128, + "hidden_act": "silu", + "hidden_dropout_prob": 0.0, + "hidden_size": 1024, + "ignored_index": -100, + "image_token_id": 100295, + "intermediate_size": 3072, + "max_position_embeddings": 131072, + "model_type": "paddleocr_vl", + "num_attention_heads": 16, + "num_hidden_layers": 18, + "num_key_value_heads": 2, + "rms_norm_eps": 1e-05, + "rope_scaling": {"mrope_section": [16, 24, 24], "rope_type": "default", "type": "default"}, + "rope_theta": 500000, + "tie_word_embeddings": False, + }, + vision_start_token_id=101305, + vision_end_token_id=101306, + image_token_id=100295, + is_training=True, + vision_config={ + "hidden_act": "gelu_pytorch_tanh", + "hidden_size": 1152, + "intermediate_size": 4304, + "layer_norm_eps": 1e-06, + "model_type": "paddleocr_vl", + "num_attention_heads": 16, + "num_channels": 3, + "num_hidden_layers": 27, + "pad_token_id": 0, + "patch_size": 14, + "spatial_merge_size": 2, + # "torch_dtype": "bfloat16" + }, + ): + self.parent = parent + self.bos_token_id = text_config["bos_token_id"] + self.eos_token_id = text_config["eos_token_id"] + self.pad_token_id = text_config["pad_token_id"] + self.num_hidden_layers = text_config["num_hidden_layers"] + self.num_attention_heads = text_config["num_attention_heads"] + self.hidden_size = text_config["hidden_size"] + self.vision_start_token_id = vision_start_token_id + self.vision_end_token_id = vision_end_token_id + self.image_token_id = image_token_id + self.text_config = text_config + self.vision_config = vision_config + self.batch_size = batch_size + self.num_channels = num_channels + self.image_height = image_height + self.image_width = image_width + self.is_training = is_training + self.vocab_size = text_config["vocab_size"] + self.num_image_tokens = 198 + self.seq_length = seq_length + self.num_image_tokens + + def get_config(self): + return PaddleOCRVLConfig( + text_config=self.text_config, + vision_config=self.vision_config, + vision_start_token_id=self.vision_start_token_id, + image_token_id=self.image_token_id, + ) + + def prepare_config_and_inputs(self): + config = self.get_config() + patch_size = config.vision_config.patch_size + pixel_values = floats_tensor( + [ + self.batch_size * (self.image_height * self.image_width) // (patch_size**2), + config.vision_config.num_channels, + patch_size, + patch_size, + ] + ) + + return config, pixel_values + + def prepare_config_and_inputs_for_common(self): + config_and_inputs = self.prepare_config_and_inputs() + config, pixel_values = config_and_inputs + input_ids = ids_tensor([self.batch_size, self.seq_length], self.vocab_size) + attention_mask = torch.ones(input_ids.shape, dtype=torch.long, device=torch_device) + + input_ids[:, :4] = torch.tensor([100273, 2969, 93963, 93919], dtype=input_ids.dtype, device=input_ids.device) + input_ids[:, 4] = self.vision_start_token_id + input_ids[:, 5 : 5 + self.num_image_tokens] = self.image_token_id + input_ids[:, -8] = self.vision_end_token_id + input_ids[:, -7:] = torch.tensor( + [93972, 2497, 93963, 23, 92267, 93963, 93919], dtype=input_ids.dtype, device=input_ids.device + ) + + inputs_dict = { + "pixel_values": pixel_values, + "image_grid_thw": torch.tensor([[1, 18, 44]] * self.batch_size, device=torch_device), + "input_ids": input_ids, + "attention_mask": attention_mask, + } + return config, inputs_dict + + +@require_torch +class PaddleOCRVLModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin, unittest.TestCase): + """ + Model tester for `PaddleOCRVLForConditionalGeneration`. + """ + + all_model_classes = (PaddleOCRVLForConditionalGeneration,) if is_torch_available() else () + pipeline_model_mapping = {"image-text-to-text": PaddleOCRVLForConditionalGeneration} + _is_composite = True + + def setUp(self): + self.model_tester = PaddleOCRVLVisionText2TextModelTester(self) + self.config_tester = ConfigTester(self, config_class=PaddleOCRVLConfig, has_text_modality=False) + + def test_config(self): + self.config_tester.run_common_tests() + + def test_mismatching_num_image_tokens(self): + """ + Tests that an explicit error is thrown when the number of image tokens + doesn't match the number of image placeholders in the text. + We also test multi-image cases when one prompt has multiple image tokens. + """ + config, input_dict = self.model_tester.prepare_config_and_inputs_for_common() + for model_class in self.all_model_classes: + model = model_class(config).to(torch_device) + model.eval() + curr_input_dict = copy.deepcopy(input_dict) # in-place modifications further + _ = model(**curr_input_dict) # successful forward with no modifications + + # remove one image but leave all the image tokens in text + patch_size = config.vision_config.patch_size + one_img_length = (self.model_tester.image_height * self.model_tester.image_width) // (patch_size**2) + curr_input_dict["pixel_values"] = curr_input_dict["pixel_values"][-one_img_length:, ...] + curr_input_dict["image_grid_thw"] = curr_input_dict["image_grid_thw"][-1:, ...] + with self.assertRaises(ValueError): + _ = model(**curr_input_dict) + + # simulate multi-image case by concatenating inputs where each has exactly one image/image-token + input_ids = curr_input_dict["input_ids"][:1] + pixel_values = curr_input_dict["pixel_values"][:one_img_length] + image_grid_thw = curr_input_dict["image_grid_thw"][:1] + input_ids = torch.cat([input_ids, input_ids], dim=0) + + # one image and two image tokens raise an error + with self.assertRaises(ValueError): + _ = model(input_ids=input_ids, pixel_values=pixel_values, image_grid_thw=image_grid_thw) + + # two images and two image tokens don't raise an error + pixel_values = torch.cat([pixel_values, pixel_values], dim=0) + image_grid_thw = torch.cat([image_grid_thw, image_grid_thw], dim=0) + _ = model(input_ids=input_ids, pixel_values=pixel_values, image_grid_thw=image_grid_thw) + + def prepare_config_and_inputs_for_generate(self, batch_size=2): + config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() + + # We don't want a few model inputs in our model input dictionary for generation tests + input_keys_to_ignore = [ + # we don't want encoder-decoder models to start from filled decoder ids + "decoder_input_ids", + "decoder_attention_mask", + # we'll set cache use in each test differently + "use_cache", + # Ignore labels if it is in the input dict + "labels", + # model-specific exceptions should overload/overwrite this function + ] + + # The diff from the general `prepare_config_and_inputs_for_generate` lies here + patch_size = config.vision_config.patch_size + filtered_image_length = ( + batch_size * (self.model_tester.image_height * self.model_tester.image_width) // (patch_size**2) + ) + filtered_inputs_dict = { + k: v[:batch_size, ...] if isinstance(v, torch.Tensor) else v + for k, v in inputs_dict.items() + if k not in input_keys_to_ignore + } + filtered_inputs_dict["pixel_values"] = inputs_dict["pixel_values"][:filtered_image_length] + + # It is important set `eos_token_id` to `None` to avoid early stopping (would break for length-based checks) + text_gen_config = config.get_text_config(decoder=True) + if text_gen_config.eos_token_id is not None and text_gen_config.pad_token_id is None: + text_gen_config.pad_token_id = ( + text_gen_config.eos_token_id + if isinstance(text_gen_config.eos_token_id, int) + else text_gen_config.eos_token_id[0] + ) + text_gen_config.eos_token_id = None + text_gen_config.forced_eos_token_id = None + + return config, filtered_inputs_dict + + @unittest.skip(reason="PaddleOCRVL does not support.") + def test_generate_compile_model_forward_fullgraph(self): + pass + + @unittest.skip(reason="PaddleOCRVL does not support.") + def test_multi_gpu_data_parallel_forward(self): + pass + + @pytest.mark.generate + @unittest.skip(reason="PaddleOCRVL does not support beam search.") + def test_beam_sample_generate(self): + pass + + @pytest.mark.generate + @unittest.skip(reason="PaddleOCRVL does not support beam search.") + def test_beam_search_generate(self): + pass + + @pytest.mark.generate + @unittest.skip(reason="PaddleOCRVL does not support beam search.") + def test_beam_search_generate_dict_output(self): + pass + + @pytest.mark.generate + @unittest.skip(reason="PaddleOCRVL does not support beam search.") + def test_beam_search_generate_dict_outputs_use_cache(self): + pass + + @pytest.mark.generate + @unittest.skip(reason="PaddleOCRVL does not support beam search.") + def test_beam_sample_generate_dict_output(self): + pass + + @unittest.skip(reason="PaddleOCRVL needs to apply weight conversions.") + def test_can_load_from_already_mapped_keys(self): + pass + + @unittest.skip(reason="PaddleOCRVL needs to apply weight conversions.") + def test_from_pretrained_no_checkpoint(self): + pass + + @pytest.mark.generate + @unittest.skip(reason="PaddleOCRVL does not support beam search.") + def test_generate_from_inputs_embeds_1_beam_search(self, _, num_beams): + pass + + @parameterized.expand([("random",), ("same",)]) + @pytest.mark.generate + @unittest.skip(reason="PaddleOCRVL does not support assisted decoding.") + def test_assisted_decoding_matches_greedy_search(self, assistant_type): + pass + + @pytest.mark.generate + @unittest.skip(reason="PaddleOCRVL does not support assisted decoding.") + def test_assisted_decoding_sample(self): + pass + + @unittest.skip("PaddleOCRVL does not support this test.") + def test_model_is_small(self): + pass + + @unittest.skip("PaddleOCRVL does not support this test.") + def test_num_layers_is_small(self): + pass + + +@require_torch +@slow +class PaddleOCRVLIntegrationTest(unittest.TestCase): + def setUp(self): + self.processor = AutoProcessor.from_pretrained("PaddlePaddle/PaddleOCR-VL") + self.messages = [ + { + "role": "user", + "content": [ + { + "type": "image", + "url": "https://paddle-model-ecology.bj.bcebos.com/paddlex/imgs/demo_image/ocr_demo2.jpg", + }, + {"type": "text", "text": "OCR:"}, + ], + } + ] + + def tearDown(self): + gc.collect() + backend_empty_cache(torch_device) + + def test_small_model_integration_test(self): + model = ( + PaddleOCRVLForConditionalGeneration.from_pretrained( + "PaddlePaddle/PaddleOCR-VL", + dtype="bfloat16", + ) + .to(torch_device) + .eval() + ) + + inputs = self.processor.apply_chat_template( + self.messages, + add_generation_prompt=True, + tokenize=True, + return_dict=True, + return_tensors="pt", + ) + + expected_input_ids_length = 211 + assert expected_input_ids_length == len(inputs.input_ids[0]) + + expected_input_ids = [100273, 2969, 93963, 93919, 101305, 100295, 100295, 100295, 100295, 100295] # fmt: skip + assert expected_input_ids == inputs.input_ids[0].tolist()[:10] + + expected_pixel_slice = torch.tensor( + [ + [1.0000, 1.0000, 1.0000], + [1.0000, 1.0000, 1.0000], + [0.9922, 0.9922, 0.9922], + [1.0000, 1.0000, 1.0000], + [1.0000, 1.0000, 1.0000], + ], + dtype=torch.float32, + device="cpu", + ) + + assert torch.allclose(expected_pixel_slice, inputs.pixel_values[:5, :, 0, 0], atol=3e-3) + + # verify generation + inputs = inputs.to(torch_device) + output = model.generate(**inputs, max_new_tokens=30) + result = self.processor.decode(output[0][inputs["input_ids"].shape[-1] : -1]) + + EXPECTED_DECODED_TEXT = "生甘草" + + self.assertEqual( + result, + EXPECTED_DECODED_TEXT, + ) + + def test_small_model_integration_test_batch(self): + model = ( + PaddleOCRVLForConditionalGeneration.from_pretrained("PaddlePaddle/PaddleOCR-VL", dtype="bfloat16") + .to(torch_device) + .eval() + ) + + inputs = self.processor.apply_chat_template( + [self.messages, self.messages], + add_generation_prompt=True, + tokenize=True, + return_dict=True, + return_tensors="pt", + padding=True, + padding_side="left", + ).to(torch_device) + + # it should not matter whether two images are the same size or not + output = model.generate(**inputs, max_new_tokens=30) + generated_ids_trimmed = [out_ids[len(in_ids) :] for in_ids, out_ids in zip(inputs.input_ids, output)] + result = self.processor.batch_decode( + generated_ids_trimmed, skip_special_tokens=True, clean_up_tokenization_spaces=False + ) + + EXPECTED_DECODED_TEXT = ["生甘草", "生甘草"] + + self.assertEqual( + result, + EXPECTED_DECODED_TEXT, + ) + + @require_flash_attn + @require_torch_accelerator + @pytest.mark.flash_attn_test + def test_small_model_integration_test_flashatt2(self): + model = ( + PaddleOCRVLForConditionalGeneration.from_pretrained( + "PaddlePaddle/PaddleOCR-VL", dtype="bfloat16", attn_implementation="flash_attention_2" + ) + .to(torch_device) + .eval() + ) + + inputs = self.processor.apply_chat_template( + self.messages, + add_generation_prompt=True, + tokenize=True, + return_dict=True, + return_tensors="pt", + ) + + expected_input_ids_length = 211 + assert expected_input_ids_length == len(inputs.input_ids[0]) + + expected_input_ids = [100273, 2969, 93963, 93919, 101305, 100295, 100295, 100295, 100295, 100295] # fmt: skip + assert expected_input_ids == inputs.input_ids[0].tolist()[:10] + + expected_pixel_slice = torch.tensor( + [ + [1.0000, 1.0000, 1.0000], + [1.0000, 1.0000, 1.0000], + [0.9922, 0.9922, 0.9922], + [1.0000, 1.0000, 1.0000], + [1.0000, 1.0000, 1.0000], + ], + dtype=torch.float32, + device="cpu", + ) + assert torch.allclose(expected_pixel_slice, inputs.pixel_values[:5, :, 0, 0], atol=3e-3) + + # verify generation + inputs = inputs.to(torch_device) + output = model.generate(**inputs, max_new_tokens=30) + result = self.processor.decode(output[0][inputs["input_ids"].shape[-1] : -1]) + + EXPECTED_DECODED_TEXT = "生甘草" + + self.assertEqual( + result, + EXPECTED_DECODED_TEXT, + ) + + @require_flash_attn + @require_torch_accelerator + @pytest.mark.flash_attn_test + def test_small_model_integration_test_batch_flashatt2(self): + model = ( + PaddleOCRVLForConditionalGeneration.from_pretrained( + "PaddlePaddle/PaddleOCR-VL", dtype="bfloat16", attn_implementation="flash_attention_2" + ) + .to(torch_device) + .eval() + ) + + inputs = self.processor.apply_chat_template( + [self.messages, self.messages], + add_generation_prompt=True, + tokenize=True, + return_dict=True, + return_tensors="pt", + padding=True, + padding_side="left", + ).to(torch_device) + + # it should not matter whether two images are the same size or not + output = model.generate(**inputs, max_new_tokens=30) + generated_ids_trimmed = [out_ids[len(in_ids) :] for in_ids, out_ids in zip(inputs.input_ids, output)] + result = self.processor.batch_decode( + generated_ids_trimmed, skip_special_tokens=True, clean_up_tokenization_spaces=False + ) + + EXPECTED_DECODED_TEXT = ["生甘草", "生甘草"] + + self.assertEqual( + result, + EXPECTED_DECODED_TEXT, + ) From a8b624ec562d7d09d851ba17afc9426f54ec8cd1 Mon Sep 17 00:00:00 2001 From: zhangyue66 Date: Wed, 10 Dec 2025 19:58:22 +0800 Subject: [PATCH 13/19] update --- docs/source/en/model_doc/paddleocr_vl.md | 42 ++++++++++++++++++++++-- utils/check_repo.py | 8 +++++ 2 files changed, 47 insertions(+), 3 deletions(-) diff --git a/docs/source/en/model_doc/paddleocr_vl.md b/docs/source/en/model_doc/paddleocr_vl.md index d8be52f2d82d..4077b9a960ce 100644 --- a/docs/source/en/model_doc/paddleocr_vl.md +++ b/docs/source/en/model_doc/paddleocr_vl.md @@ -187,11 +187,47 @@ from transformers import AutoModelForImageTextToText model = AutoModelForImageTextToText.from_pretrained("PaddlePaddle/PaddleOCR-VL", dtype="bfloat16", attn_implementation="flash_attention_2") ``` +## PaddleOCRVLForConditionalGeneration + +[[autodoc]] PaddleOCRVLForConditionalGeneration + - forward + ## PaddleOCRVLConfig [[autodoc]] PaddleOCRVLConfig -## PaddleOCRVLForConditionalGeneration +## PaddleOCRVisionConfig -[[autodoc]] PaddleOCRVLForConditionalGeneration - - forward +[[autodoc]] PaddleOCRVisionConfig + +## PaddleOCRTextConfig + +[[autodoc]] PaddleOCRTextConfig + +## PaddleOCRTextModel + +[[autodoc]] PaddleOCRTextModel + +## PaddleOCRVisionModel + +[[autodoc]] PaddleOCRVisionModel + +## PaddleOCRVLImageProcessor + +[[autodoc]] PaddleOCRVLImageProcessor + +## PaddleOCRVLImageProcessorFast + +[[autodoc]] PaddleOCRVLImageProcessorFast + +## PaddleOCRVLModel + +[[autodoc]] PaddleOCRVLModel + +## PaddleOCRVLProcessor + +[[autodoc]] PaddleOCRVLProcessor + +## PaddleOCRVisionTransformer + +[[autodoc]] PaddleOCRVisionTransformer diff --git a/utils/check_repo.py b/utils/check_repo.py index 80d6a3f3223f..651e6726ec44 100644 --- a/utils/check_repo.py +++ b/utils/check_repo.py @@ -153,6 +153,10 @@ "SeamlessM4TCodeHifiGan", # Building part of bigger (tested) model. "SeamlessM4TTextToUnitForConditionalGeneration", # Building part of bigger (tested) model. "ChameleonVQVAE", # VQVAE here is used only for encoding (discretizing) and is tested as part of bigger model + "PaddleOCRVLModel", # Building part of bigger (tested) model. Tested implicitly through PaddleOCRVLForConditionalGeneration. + "PaddleOCRVisionModel", # Building part of bigger (tested) model. Tested implicitly through PaddleOCRVLForConditionalGeneration. + "PaddleOCRVisionTransformer", # Building part of bigger (tested) model. Tested implicitly through PaddleOCRVLForConditionalGeneration. + "PaddleOCRTextModel", # Building part of bigger (tested) model. Tested implicitly through PaddleOCRVLForConditionalGeneration. "Qwen2VLModel", # Building part of bigger (tested) model. Tested implicitly through Qwen2VLForConditionalGeneration. "Qwen2_5_VLModel", # Building part of bigger (tested) model. Tested implicitly through Qwen2_5_VLForConditionalGeneration. "Qwen3VLModel", # Building part of bigger (tested) model. Tested implicitly through Qwen3VLForConditionalGeneration. @@ -382,6 +386,10 @@ "Emu3TextModel", # Building part of bigger (tested) model "JanusVQVAE", # no autoclass for VQ-VAE models "JanusVisionModel", # Building part of bigger (tested) model + "PaddleOCRVLModel", # Building part of bigger (tested) model + "PaddleOCRVisionModel", # Building part of bigger (tested) model + "PaddleOCRVisionTransformer", # Building part of bigger (tested) model + "PaddleOCRTextModel", # Building part of bigger (tested) model "Qwen2_5OmniTalkerForConditionalGeneration", # Building part of a bigger model "Qwen2_5OmniTalkerModel", # Building part of a bigger model "Qwen2_5OmniThinkerForConditionalGeneration", # Building part of a bigger model From 988ebce1ee9c739c981f40b554f75e980bf58c62 Mon Sep 17 00:00:00 2001 From: zhangyue66 Date: Wed, 10 Dec 2025 20:03:22 +0800 Subject: [PATCH 14/19] add **kwargs --- src/transformers/models/paddleocr_vl/modeling_paddleocr_vl.py | 2 ++ src/transformers/models/paddleocr_vl/modular_paddleocr_vl.py | 2 ++ 2 files changed, 4 insertions(+) diff --git a/src/transformers/models/paddleocr_vl/modeling_paddleocr_vl.py b/src/transformers/models/paddleocr_vl/modeling_paddleocr_vl.py index dbdedbf26257..52d5ce28da06 100644 --- a/src/transformers/models/paddleocr_vl/modeling_paddleocr_vl.py +++ b/src/transformers/models/paddleocr_vl/modeling_paddleocr_vl.py @@ -551,6 +551,7 @@ def forward( pixel_values: torch.FloatTensor, cu_seqlens: torch.Tensor, image_grid_thw: Optional[list[Union[tuple[int, int, int], list[tuple[int, int, int]]]]] = None, + **kwargs, ) -> BaseModelOutputWithPooling: """ Args: @@ -924,6 +925,7 @@ def forward( cu_seqlens: torch.Tensor, attention_mask: Optional[torch.Tensor] = None, image_grid_thw: Optional[list[Union[tuple[int, int, int], list[tuple[int, int, int]]]]] = None, + **kwargs, ) -> BaseModelOutputWithPooling: """ Args: diff --git a/src/transformers/models/paddleocr_vl/modular_paddleocr_vl.py b/src/transformers/models/paddleocr_vl/modular_paddleocr_vl.py index e1e30452f35b..2395361f30fb 100644 --- a/src/transformers/models/paddleocr_vl/modular_paddleocr_vl.py +++ b/src/transformers/models/paddleocr_vl/modular_paddleocr_vl.py @@ -871,6 +871,7 @@ def forward( pixel_values: torch.FloatTensor, cu_seqlens: torch.Tensor, image_grid_thw: Optional[list[Union[tuple[int, int, int], list[tuple[int, int, int]]]]] = None, + **kwargs, ) -> BaseModelOutputWithPooling: """ Args: @@ -1044,6 +1045,7 @@ def forward( cu_seqlens: torch.Tensor, attention_mask: Optional[torch.Tensor] = None, image_grid_thw: Optional[list[Union[tuple[int, int, int], list[tuple[int, int, int]]]]] = None, + **kwargs, ) -> BaseModelOutputWithPooling: """ Args: From f36c7809fd3997bfa89ef727dec439995fbe6f9e Mon Sep 17 00:00:00 2001 From: zhangyue66 Date: Wed, 10 Dec 2025 20:38:31 +0800 Subject: [PATCH 15/19] update --- .../models/paddleocr_vl/configuration_paddleocr_vl.py | 4 ---- src/transformers/models/paddleocr_vl/modular_paddleocr_vl.py | 4 ---- utils/check_config_attributes.py | 1 + 3 files changed, 1 insertion(+), 8 deletions(-) diff --git a/src/transformers/models/paddleocr_vl/configuration_paddleocr_vl.py b/src/transformers/models/paddleocr_vl/configuration_paddleocr_vl.py index 9aceafedb7a3..454b4241c9e5 100644 --- a/src/transformers/models/paddleocr_vl/configuration_paddleocr_vl.py +++ b/src/transformers/models/paddleocr_vl/configuration_paddleocr_vl.py @@ -64,8 +64,6 @@ class PaddleOCRVisionConfig(PreTrainedConfig): The dropout ratio for the attention probabilities. spatial_merge_size (`int`, *optional*, defaults to 2): The size used for merging spatial dimensions. - temporal_patch_size (`int`, *optional*, defaults to 2): - The size used for patches along the temporal dimension. Example: @@ -99,7 +97,6 @@ def __init__( layer_norm_eps=1e-6, attention_dropout=0.0, spatial_merge_size=2, - temporal_patch_size=2, **kwargs, ): super().__init__(**kwargs) @@ -115,7 +112,6 @@ def __init__( self.layer_norm_eps = layer_norm_eps self.hidden_act = hidden_act self.spatial_merge_size = spatial_merge_size - self.temporal_patch_size = temporal_patch_size class PaddleOCRTextConfig(PreTrainedConfig): diff --git a/src/transformers/models/paddleocr_vl/modular_paddleocr_vl.py b/src/transformers/models/paddleocr_vl/modular_paddleocr_vl.py index 2395361f30fb..e3b00a1798d9 100644 --- a/src/transformers/models/paddleocr_vl/modular_paddleocr_vl.py +++ b/src/transformers/models/paddleocr_vl/modular_paddleocr_vl.py @@ -592,8 +592,6 @@ class PaddleOCRVisionConfig(SiglipVisionConfig): The dropout ratio for the attention probabilities. spatial_merge_size (`int`, *optional*, defaults to 2): The size used for merging spatial dimensions. - temporal_patch_size (`int`, *optional*, defaults to 2): - The size used for patches along the temporal dimension. Example: @@ -627,12 +625,10 @@ def __init__( layer_norm_eps=1e-6, attention_dropout=0.0, spatial_merge_size=2, - temporal_patch_size=2, **kwargs, ): super().__init__() self.spatial_merge_size = spatial_merge_size - self.temporal_patch_size = temporal_patch_size class PaddleOCRTextConfig(Ernie4_5Config): diff --git a/utils/check_config_attributes.py b/utils/check_config_attributes.py index ac379f618823..41606b1e1b6b 100644 --- a/utils/check_config_attributes.py +++ b/utils/check_config_attributes.py @@ -58,6 +58,7 @@ "expert_layer_offset", "expert_layer_period", ], + "PaddleOCRTextConfig": ["tie_word_embeddings"], "Qwen2Config": ["use_sliding_window", "max_window_layers"], "Qwen2MoeConfig": ["use_sliding_window", "max_window_layers"], "Qwen2VLTextConfig": ["use_sliding_window", "max_window_layers"], From 0ff29a5697f973f7117c83e867746803b7109720 Mon Sep 17 00:00:00 2001 From: zhangyue66 Date: Wed, 10 Dec 2025 20:50:12 +0800 Subject: [PATCH 16/19] update --- .../models/paddleocr_vl/configuration_paddleocr_vl.py | 4 ++++ src/transformers/models/paddleocr_vl/modular_paddleocr_vl.py | 4 +++- 2 files changed, 7 insertions(+), 1 deletion(-) diff --git a/src/transformers/models/paddleocr_vl/configuration_paddleocr_vl.py b/src/transformers/models/paddleocr_vl/configuration_paddleocr_vl.py index 454b4241c9e5..21e4db0d268c 100644 --- a/src/transformers/models/paddleocr_vl/configuration_paddleocr_vl.py +++ b/src/transformers/models/paddleocr_vl/configuration_paddleocr_vl.py @@ -273,6 +273,8 @@ class PaddleOCRVLConfig(PreTrainedConfig): The config object or dictionary of the vision backbone. image_token_id (`int`, *optional*, defaults to 100295): The image token index to encode the image prompt. + video_token_id (`int`, *optional*, defaults to 100296): + The video token index to encode the image prompt. vision_start_token_id (`int`, *optional*, defaults to 101305): The token index to denote start of vision input. vision_end_token_id (`int`, *optional*, defaults to 101306): @@ -301,6 +303,7 @@ def __init__( text_config=None, vision_config=None, image_token_id=100295, + video_token_id=100296, vision_start_token_id=101305, vision_end_token_id=101306, **kwargs, @@ -321,6 +324,7 @@ def __init__( self.text_config = self.sub_configs["text_config"](**text_config) self.image_token_id = image_token_id + self.video_token_id = video_token_id self.vision_start_token_id = vision_start_token_id self.vision_end_token_id = vision_end_token_id diff --git a/src/transformers/models/paddleocr_vl/modular_paddleocr_vl.py b/src/transformers/models/paddleocr_vl/modular_paddleocr_vl.py index e3b00a1798d9..6598e1b3ff94 100644 --- a/src/transformers/models/paddleocr_vl/modular_paddleocr_vl.py +++ b/src/transformers/models/paddleocr_vl/modular_paddleocr_vl.py @@ -653,6 +653,8 @@ class PaddleOCRVLConfig(Qwen2VLConfig): The config object or dictionary of the vision backbone. image_token_id (`int`, *optional*, defaults to 100295): The image token index to encode the image prompt. + video_token_id (`int`, *optional*, defaults to 100296): + The video token index to encode the image prompt. vision_start_token_id (`int`, *optional*, defaults to 101305): The token index to denote start of vision input. vision_end_token_id (`int`, *optional*, defaults to 101306): @@ -678,12 +680,12 @@ def __init__( text_config=None, vision_config=None, image_token_id=100295, + video_token_id=100296, vision_start_token_id=101305, vision_end_token_id=101306, **kwargs, ): super().__init__() - del self.video_token_id class PaddleOCRProjector(nn.Module): From a59aa8a6407bd6e3af1fce0995c6b385c692e32a Mon Sep 17 00:00:00 2001 From: zhangyue66 Date: Thu, 11 Dec 2025 11:54:47 +0800 Subject: [PATCH 17/19] update --- docs/source/en/model_doc/paddleocr_vl.md | 15 +++++++++++++ .../models/auto/tokenization_auto.py | 1 + .../paddleocr_vl/modeling_paddleocr_vl.py | 2 +- .../paddleocr_vl/modular_paddleocr_vl.py | 2 +- .../test_modeling_paddleocr_vl.py | 22 ++++++------------- 5 files changed, 25 insertions(+), 17 deletions(-) diff --git a/docs/source/en/model_doc/paddleocr_vl.md b/docs/source/en/model_doc/paddleocr_vl.md index 4077b9a960ce..c2e7417424fd 100644 --- a/docs/source/en/model_doc/paddleocr_vl.md +++ b/docs/source/en/model_doc/paddleocr_vl.md @@ -1,3 +1,18 @@ + *This model was released on 2025.10.16 and added to Hugging Face Transformers on 2025.12.10* # PaddleOCR-VL diff --git a/src/transformers/models/auto/tokenization_auto.py b/src/transformers/models/auto/tokenization_auto.py index 1eaf5bad9202..39ef78aa5056 100644 --- a/src/transformers/models/auto/tokenization_auto.py +++ b/src/transformers/models/auto/tokenization_auto.py @@ -273,6 +273,7 @@ ("ovis2", "Qwen2TokenizerFast" if is_tokenizers_available() else None), ("owlv2", "CLIPTokenizerFast" if is_tokenizers_available() else None), ("owlvit", "CLIPTokenizerFast" if is_tokenizers_available() else None), + ("paddleocr_vl", "LlamaTokenizer" if is_tokenizers_available() else None), ("paligemma", "LlamaTokenizer" if is_tokenizers_available() else None), ("pegasus", "PegasusTokenizer" if is_tokenizers_available() else None), ("pegasus_x", "PegasusTokenizer" if is_tokenizers_available() else None), diff --git a/src/transformers/models/paddleocr_vl/modeling_paddleocr_vl.py b/src/transformers/models/paddleocr_vl/modeling_paddleocr_vl.py index 52d5ce28da06..7c6b06caf8d7 100644 --- a/src/transformers/models/paddleocr_vl/modeling_paddleocr_vl.py +++ b/src/transformers/models/paddleocr_vl/modeling_paddleocr_vl.py @@ -1334,7 +1334,7 @@ class PaddleOCRVLForConditionalGeneration(PaddleOCRVLPreTrainedModel, Generation _checkpoint_conversion_mapping = { "^visual": "model.visual", "^mlp_AR": "model.projector", - r"^model(?!(\.visual|\.projector))": "model.language_model", + r"^model(?!(\.visual|\.projector|\.language_model))": "model.language_model", } _tied_weights_keys = {"lm_head.weight": "model.language_model.embed_tokens.weight"} _keys_to_ignore_on_load_unexpected = ["packing_position_embedding", "vision_model.head"] diff --git a/src/transformers/models/paddleocr_vl/modular_paddleocr_vl.py b/src/transformers/models/paddleocr_vl/modular_paddleocr_vl.py index 6598e1b3ff94..9860c21fc1fe 100644 --- a/src/transformers/models/paddleocr_vl/modular_paddleocr_vl.py +++ b/src/transformers/models/paddleocr_vl/modular_paddleocr_vl.py @@ -1235,7 +1235,7 @@ class PaddleOCRVLForConditionalGeneration(Qwen2VLForConditionalGeneration): _checkpoint_conversion_mapping = { "^visual": "model.visual", "^mlp_AR": "model.projector", - r"^model(?!(\.visual|\.projector))": "model.language_model", + r"^model(?!(\.visual|\.projector|\.language_model))": "model.language_model", } _keys_to_ignore_on_load_unexpected = ["packing_position_embedding", "vision_model.head"] diff --git a/tests/models/paddleocr_vl/test_modeling_paddleocr_vl.py b/tests/models/paddleocr_vl/test_modeling_paddleocr_vl.py index ff10d96d9fb9..f82c506bae86 100644 --- a/tests/models/paddleocr_vl/test_modeling_paddleocr_vl.py +++ b/tests/models/paddleocr_vl/test_modeling_paddleocr_vl.py @@ -56,8 +56,8 @@ def __init__( batch_size=7, seq_length=13, num_channels=3, - image_height=252, - image_width=616, + image_height=28, + image_width=28, text_config={ "pad_token_id": 0, "bos_token_id": 1, @@ -73,7 +73,7 @@ def __init__( "max_position_embeddings": 131072, "model_type": "paddleocr_vl", "num_attention_heads": 16, - "num_hidden_layers": 18, + "num_hidden_layers": 2, "num_key_value_heads": 2, "rms_norm_eps": 1e-05, "rope_scaling": {"mrope_section": [16, 24, 24], "rope_type": "default", "type": "default"}, @@ -92,11 +92,10 @@ def __init__( "model_type": "paddleocr_vl", "num_attention_heads": 16, "num_channels": 3, - "num_hidden_layers": 27, + "num_hidden_layers": 2, "pad_token_id": 0, "patch_size": 14, "spatial_merge_size": 2, - # "torch_dtype": "bfloat16" }, ): self.parent = parent @@ -117,7 +116,7 @@ def __init__( self.image_width = image_width self.is_training = is_training self.vocab_size = text_config["vocab_size"] - self.num_image_tokens = 198 + self.num_image_tokens = 1 self.seq_length = seq_length + self.num_image_tokens def get_config(self): @@ -158,7 +157,7 @@ def prepare_config_and_inputs_for_common(self): inputs_dict = { "pixel_values": pixel_values, - "image_grid_thw": torch.tensor([[1, 18, 44]] * self.batch_size, device=torch_device), + "image_grid_thw": torch.tensor([[1, 2, 2]] * self.batch_size, device=torch_device), "input_ids": input_ids, "attention_mask": attention_mask, } @@ -218,6 +217,7 @@ def test_mismatching_num_image_tokens(self): image_grid_thw = torch.cat([image_grid_thw, image_grid_thw], dim=0) _ = model(input_ids=input_ids, pixel_values=pixel_values, image_grid_thw=image_grid_thw) + # PaddleOCRVL has pixel_values shaped as (bs*patch_len, image_channels, patch_size, patch_size) so we can't slice to batches in generate def prepare_config_and_inputs_for_generate(self, batch_size=2): config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() @@ -295,10 +295,6 @@ def test_beam_sample_generate_dict_output(self): def test_can_load_from_already_mapped_keys(self): pass - @unittest.skip(reason="PaddleOCRVL needs to apply weight conversions.") - def test_from_pretrained_no_checkpoint(self): - pass - @pytest.mark.generate @unittest.skip(reason="PaddleOCRVL does not support beam search.") def test_generate_from_inputs_embeds_1_beam_search(self, _, num_beams): @@ -319,10 +315,6 @@ def test_assisted_decoding_sample(self): def test_model_is_small(self): pass - @unittest.skip("PaddleOCRVL does not support this test.") - def test_num_layers_is_small(self): - pass - @require_torch @slow From 35be830687c9759ed4966a0b6356ee5e1a6715dc Mon Sep 17 00:00:00 2001 From: zhangyue66 Date: Thu, 11 Dec 2025 19:39:42 +0800 Subject: [PATCH 18/19] reduce max_position_embeddings in tests --- docs/source/en/model_doc/paddleocr_vl.md | 2 +- tests/models/paddleocr_vl/test_modeling_paddleocr_vl.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/docs/source/en/model_doc/paddleocr_vl.md b/docs/source/en/model_doc/paddleocr_vl.md index c2e7417424fd..cc3b1f4ced6d 100644 --- a/docs/source/en/model_doc/paddleocr_vl.md +++ b/docs/source/en/model_doc/paddleocr_vl.md @@ -1,4 +1,4 @@ -