diff --git a/docs/source/en/_toctree.yml b/docs/source/en/_toctree.yml
index 5d95fc368285..e6b623935880 100644
--- a/docs/source/en/_toctree.yml
+++ b/docs/source/en/_toctree.yml
@@ -1119,6 +1119,8 @@
title: OWL-ViT
- local: model_doc/owlv2
title: OWLv2
+ - local: model_doc/paddleocr_vl
+ title: PaddleOCRVL
- local: model_doc/paligemma
title: PaliGemma
- local: model_doc/perceiver
diff --git a/docs/source/en/model_doc/paddleocr_vl.md b/docs/source/en/model_doc/paddleocr_vl.md
new file mode 100644
index 000000000000..cc3b1f4ced6d
--- /dev/null
+++ b/docs/source/en/model_doc/paddleocr_vl.md
@@ -0,0 +1,248 @@
+
+*This model was released on 2025.10.16 and added to Hugging Face Transformers on 2025.12.10*
+
+# PaddleOCR-VL
+
+
+
+## Overview
+
+**Huggingface Hub**: [PaddleOCR-VL](https://huggingface.co/collections/PaddlePaddle/paddleocr-vl) | **Github Repo**: [PaddleOCR](https://github.com/PaddlePaddle/PaddleOCR)
+
+**Official Website**: [Baidu AI Studio](https://aistudio.baidu.com/paddleocr) | **arXiv**: [Technical Report](https://arxiv.org/pdf/2510.14528)
+
+**PaddleOCR-VL** is a SOTA and resource-efficient model tailored for document parsing. Its core component is PaddleOCR-VL-0.9B, a compact yet powerful vision-language model (VLM) that integrates a NaViT-style dynamic resolution visual encoder with the ERNIE-4.5-0.3B language model to enable accurate element recognition. This innovative model efficiently supports 109 languages and excels in recognizing complex elements (e.g., text, tables, formulas, and charts), while maintaining minimal resource consumption. Through comprehensive evaluations on widely used public benchmarks and in-house benchmarks, PaddleOCR-VL achieves SOTA performance in both page-level document parsing and element-level recognition. It significantly outperforms existing solutions, exhibits strong competitiveness against top-tier VLMs, and delivers fast inference speeds. These strengths make it highly suitable for practical deployment in real-world scenarios.
+
+
+

+
+
+### **Core Features**
+
+1. **Compact yet Powerful VLM Architecture:** We present a novel vision-language model that is specifically designed for resource-efficient inference, achieving outstanding performance in element recognition. By integrating a NaViT-style dynamic high-resolution visual encoder with the lightweight ERNIE-4.5-0.3B language model, we significantly enhance the modelβs recognition capabilities and decoding efficiency. This integration maintains high accuracy while reducing computational demands, making it well-suited for efficient and practical document processing applications.
+
+2. **SOTA Performance on Document Parsing:** PaddleOCR-VL achieves state-of-the-art performance in both page-level document parsing and element-level recognition. It significantly outperforms existing pipeline-based solutions and exhibiting strong competitiveness against leading vision-language models (VLMs) in document parsing. Moreover, it excels in recognizing complex document elements, such as text, tables, formulas, and charts, making it suitable for a wide range of challenging content types, including handwritten text and historical documents. This makes it highly versatile and suitable for a wide range of document types and scenarios.
+
+3. **Multilingual Support:** PaddleOCR-VL Supports 109 languages, covering major global languages, including but not limited to Chinese, English, Japanese, Latin, and Korean, as well as languages with different scripts and structures, such as Russian (Cyrillic script), Arabic, Hindi (Devanagari script), and Thai. This broad language coverage substantially enhances the applicability of our system to multilingual and globalized document processing scenarios.
+
+### **Model Architecture**
+
+
+

+
+
+## Usage
+
+### Usage tips
+
+> [!IMPORTANT]
+> We currently recommend using the [PaddleOCR official method for inference](https://www.paddleocr.ai/latest/en/version3.x/pipeline_usage/PaddleOCR-VL.html), as it is faster and supports page-level document parsing.
+> The example code below only supports element-level recognition.
+
+We have four types of element-level recognition:
+
+- Text recognition, indicated by the prompt `OCR:`.
+- Formula recognition, indicated by the prompt `Formula Recognition:`.
+- Table recognition, indicated by the prompt `Table Recognition:`.
+- Chart recognition, indicated by the prompt `Chart Recognition:`.
+
+The following examples are all based on text recognition, with the prompt `OCR:`.
+
+### Single input inference
+
+The example below demonstrates how to generate text with PaddleOCRVL using [`Pipeline`] or the [`AutoModel`].
+
+
+
+
+```py
+from transformers import pipeline
+
+pipe = pipeline("image-text-to-text", model="PaddlePaddle/PaddleOCR-VL", dtype="bfloat16")
+messages = [
+ {
+ "role": "user",
+ "content": [
+ {"type": "image", "url": "https://paddle-model-ecology.bj.bcebos.com/paddlex/imgs/demo_image/ocr_demo2.jpg"},
+ {"type": "text", "text": "OCR:"},
+ ]
+ }
+]
+result = pipe(text=messages)
+print(result[0]["generated_text"])
+```
+
+
+
+
+
+```py
+from transformers import AutoProcessor, AutoModelForImageTextToText
+
+model = AutoModelForImageTextToText.from_pretrained("PaddlePaddle/PaddleOCR-VL", dtype="bfloat16")
+processor = AutoProcessor.from_pretrained("PaddlePaddle/PaddleOCR-VL")
+messages = [
+ {
+ "role": "user",
+ "content": [
+ {"type": "image", "url": "https://paddle-model-ecology.bj.bcebos.com/paddlex/imgs/demo_image/ocr_demo2.jpg"},
+ {"type": "text", "text": "OCR:"},
+ ]
+ }
+]
+inputs = processor.apply_chat_template(
+ messages,
+ add_generation_prompt=True,
+ tokenize=True,
+ return_dict=True,
+ return_tensors="pt",
+).to(model.device)
+
+outputs = model.generate(**inputs, max_new_tokens=100)
+result = processor.decode(outputs[0][inputs["input_ids"].shape[-1]:-1])
+print(result)
+```
+
+
+
+
+### Batched inference
+
+PaddleOCRVL also supports batched inference. We advise users to use `padding_side="left"` when computing batched generation as it leads to more accurate results. Here is how you can do it with PaddleOCRVL using [`Pipeline`] or the [`AutoModel`]:
+
+
+
+
+```py
+from transformers import pipeline
+
+pipe = pipeline("image-text-to-text", model="PaddlePaddle/PaddleOCR-VL", dtype="bfloat16")
+messages = [
+ {
+ "role": "user",
+ "content": [
+ {"type": "image", "url": "https://paddle-model-ecology.bj.bcebos.com/paddlex/imgs/demo_image/ocr_demo2.jpg"},
+ {"type": "text", "text": "OCR:"},
+ ]
+ }
+]
+result = pipe(text=[messages, messages])
+print(result[0][0]["generated_text"])
+print(result[1][0]["generated_text"])
+```
+
+
+
+
+
+```py
+from transformers import AutoProcessor, AutoModelForImageTextToText
+
+model = AutoModelForImageTextToText.from_pretrained("PaddlePaddle/PaddleOCR-VL", dtype="bfloat16")
+processor = AutoProcessor.from_pretrained("PaddlePaddle/PaddleOCR-VL")
+messages = [
+ {
+ "role": "user",
+ "content": [
+ {"type": "image", "url": "https://paddle-model-ecology.bj.bcebos.com/paddlex/imgs/demo_image/ocr_demo2.jpg"},
+ {"type": "text", "text": "OCR:"},
+ ]
+ }
+]
+batch_messages = [messages, messages]
+inputs = processor.apply_chat_template(
+ batch_messages,
+ add_generation_prompt=True,
+ tokenize=True,
+ return_dict=True,
+ return_tensors="pt",
+ padding=True,
+ padding_side='left',
+).to(model.device)
+
+generated_ids = model.generate(**inputs, max_new_tokens=100)
+generated_ids_trimmed = [out_ids[len(in_ids) :] for in_ids, out_ids in zip(inputs.input_ids, generated_ids)]
+result = processor.batch_decode(generated_ids_trimmed, skip_special_tokens=True, clean_up_tokenization_spaces=False)
+print(result)
+```
+
+
+
+
+### Using Flash Attention 2
+
+Flash Attention 2 is an even faster, optimized version of the previous optimization, please refer to the [FlashAttention](https://huggingface.co/docs/transformers/perf_infer_gpu_one#flashattention).
+
+For example:
+
+```shell
+pip install flash-attn --no-build-isolation
+```
+
+```python
+from transformers import AutoModelForImageTextToText
+model = AutoModelForImageTextToText.from_pretrained("PaddlePaddle/PaddleOCR-VL", dtype="bfloat16", attn_implementation="flash_attention_2")
+```
+
+## PaddleOCRVLForConditionalGeneration
+
+[[autodoc]] PaddleOCRVLForConditionalGeneration
+ - forward
+
+## PaddleOCRVLConfig
+
+[[autodoc]] PaddleOCRVLConfig
+
+## PaddleOCRVisionConfig
+
+[[autodoc]] PaddleOCRVisionConfig
+
+## PaddleOCRTextConfig
+
+[[autodoc]] PaddleOCRTextConfig
+
+## PaddleOCRTextModel
+
+[[autodoc]] PaddleOCRTextModel
+
+## PaddleOCRVisionModel
+
+[[autodoc]] PaddleOCRVisionModel
+
+## PaddleOCRVLImageProcessor
+
+[[autodoc]] PaddleOCRVLImageProcessor
+
+## PaddleOCRVLImageProcessorFast
+
+[[autodoc]] PaddleOCRVLImageProcessorFast
+
+## PaddleOCRVLModel
+
+[[autodoc]] PaddleOCRVLModel
+
+## PaddleOCRVLProcessor
+
+[[autodoc]] PaddleOCRVLProcessor
+
+## PaddleOCRVisionTransformer
+
+[[autodoc]] PaddleOCRVisionTransformer
diff --git a/src/transformers/conversion_mapping.py b/src/transformers/conversion_mapping.py
index 24eab78c14fc..acb99b99f83e 100644
--- a/src/transformers/conversion_mapping.py
+++ b/src/transformers/conversion_mapping.py
@@ -213,6 +213,7 @@ def get_checkpoint_conversion_mapping(model_type):
"sam3",
"sam3_tracker",
"sam3_tracker_video",
+ "paddleocrvl",
]
diff --git a/src/transformers/models/__init__.py b/src/transformers/models/__init__.py
index 71b2155e9bc5..c25d1f2d2987 100644
--- a/src/transformers/models/__init__.py
+++ b/src/transformers/models/__init__.py
@@ -265,6 +265,7 @@
from .ovis2 import *
from .owlv2 import *
from .owlvit import *
+ from .paddleocr_vl import *
from .paligemma import *
from .parakeet import *
from .patchtsmixer import *
diff --git a/src/transformers/models/auto/configuration_auto.py b/src/transformers/models/auto/configuration_auto.py
index 38a0abb9e2d7..281bb0e773f6 100644
--- a/src/transformers/models/auto/configuration_auto.py
+++ b/src/transformers/models/auto/configuration_auto.py
@@ -300,6 +300,7 @@
("ovis2", "Ovis2Config"),
("owlv2", "Owlv2Config"),
("owlvit", "OwlViTConfig"),
+ ("paddleocr_vl", "PaddleOCRVLConfig"),
("paligemma", "PaliGemmaConfig"),
("parakeet_ctc", "ParakeetCTCConfig"),
("parakeet_encoder", "ParakeetEncoderConfig"),
@@ -754,6 +755,7 @@
("ovis2", "Ovis2"),
("owlv2", "OWLv2"),
("owlvit", "OWL-ViT"),
+ ("paddleocr_vl", "PaddleOCRVL"),
("paligemma", "PaliGemma"),
("parakeet", "Parakeet"),
("parakeet_ctc", "Parakeet"),
diff --git a/src/transformers/models/auto/image_processing_auto.py b/src/transformers/models/auto/image_processing_auto.py
index 88949a23b2d0..50bd35a7199e 100644
--- a/src/transformers/models/auto/image_processing_auto.py
+++ b/src/transformers/models/auto/image_processing_auto.py
@@ -153,6 +153,7 @@
("ovis2", ("Ovis2ImageProcessor", "Ovis2ImageProcessorFast")),
("owlv2", ("Owlv2ImageProcessor", "Owlv2ImageProcessorFast")),
("owlvit", ("OwlViTImageProcessor", "OwlViTImageProcessorFast")),
+ ("paddleocr_vl", ("PaddleOCRVLImageProcessor", "PaddleOCRVLImageProcessorFast")),
("paligemma", ("SiglipImageProcessor", "SiglipImageProcessorFast")),
("perceiver", ("PerceiverImageProcessor", "PerceiverImageProcessorFast")),
("perception_lm", (None, "PerceptionLMImageProcessorFast")),
diff --git a/src/transformers/models/auto/modeling_auto.py b/src/transformers/models/auto/modeling_auto.py
index ddd29ad96d5b..7799b11674de 100644
--- a/src/transformers/models/auto/modeling_auto.py
+++ b/src/transformers/models/auto/modeling_auto.py
@@ -1026,6 +1026,7 @@ class _BaseModelWithGenerate(PreTrainedModel, GenerationMixin):
("mistral3", "Mistral3ForConditionalGeneration"),
("mllama", "MllamaForConditionalGeneration"),
("ovis2", "Ovis2ForConditionalGeneration"),
+ ("paddleocr_vl", "PaddleOCRVLForConditionalGeneration"),
("paligemma", "PaliGemmaForConditionalGeneration"),
("perception_lm", "PerceptionLMForConditionalGeneration"),
("pix2struct", "Pix2StructForConditionalGeneration"),
diff --git a/src/transformers/models/auto/processing_auto.py b/src/transformers/models/auto/processing_auto.py
index 6d08bf37ebab..7f8509c64770 100644
--- a/src/transformers/models/auto/processing_auto.py
+++ b/src/transformers/models/auto/processing_auto.py
@@ -114,6 +114,7 @@
("ovis2", "Ovis2Processor"),
("owlv2", "Owlv2Processor"),
("owlvit", "OwlViTProcessor"),
+ ("paddleocr_vl", "PaddleOCRVLProcessor"),
("paligemma", "PaliGemmaProcessor"),
("perception_lm", "PerceptionLMProcessor"),
("phi4_multimodal", "Phi4MultimodalProcessor"),
diff --git a/src/transformers/models/auto/tokenization_auto.py b/src/transformers/models/auto/tokenization_auto.py
index 1eaf5bad9202..bd6ef8f85785 100644
--- a/src/transformers/models/auto/tokenization_auto.py
+++ b/src/transformers/models/auto/tokenization_auto.py
@@ -273,6 +273,7 @@
("ovis2", "Qwen2TokenizerFast" if is_tokenizers_available() else None),
("owlv2", "CLIPTokenizerFast" if is_tokenizers_available() else None),
("owlvit", "CLIPTokenizerFast" if is_tokenizers_available() else None),
+ ("paddleocr_vl", "TokenizersBackend" if is_tokenizers_available() else None),
("paligemma", "LlamaTokenizer" if is_tokenizers_available() else None),
("pegasus", "PegasusTokenizer" if is_tokenizers_available() else None),
("pegasus_x", "PegasusTokenizer" if is_tokenizers_available() else None),
diff --git a/src/transformers/models/paddleocr_vl/__init__.py b/src/transformers/models/paddleocr_vl/__init__.py
new file mode 100644
index 000000000000..d6be277b3f4d
--- /dev/null
+++ b/src/transformers/models/paddleocr_vl/__init__.py
@@ -0,0 +1,32 @@
+# coding=utf-8
+# Copyright 2025 the HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from typing import TYPE_CHECKING
+
+from ...utils import _LazyModule
+from ...utils.import_utils import define_import_structure
+
+
+if TYPE_CHECKING:
+ from .configuration_paddleocr_vl import *
+ from .image_processing_paddleocr_vl import *
+ from .image_processing_paddleocr_vl_fast import *
+ from .modeling_paddleocr_vl import *
+ from .processing_paddleocr_vl import *
+else:
+ import sys
+
+ _file = globals()["__file__"]
+ sys.modules[__name__] = _LazyModule(__name__, _file, define_import_structure(_file), module_spec=__spec__)
diff --git a/src/transformers/models/paddleocr_vl/configuration_paddleocr_vl.py b/src/transformers/models/paddleocr_vl/configuration_paddleocr_vl.py
new file mode 100644
index 000000000000..21e4db0d268c
--- /dev/null
+++ b/src/transformers/models/paddleocr_vl/configuration_paddleocr_vl.py
@@ -0,0 +1,336 @@
+# π¨π¨π¨π¨π¨π¨π¨π¨π¨π¨π¨π¨π¨π¨π¨π¨π¨π¨π¨π¨π¨π¨π¨π¨π¨π¨π¨π¨π¨π¨π¨π¨π¨π¨π¨π¨π¨π¨π¨π¨π¨π¨π¨π¨π¨π¨π¨π¨
+# This file was automatically generated from src/transformers/models/paddleocr_vl/modular_paddleocr_vl.py.
+# Do NOT edit this file manually as any edits will be overwritten by the generation of
+# the file from the modular. If any change should be done, please apply the change to the
+# modular_paddleocr_vl.py file directly. One of our CI enforces this.
+# π¨π¨π¨π¨π¨π¨π¨π¨π¨π¨π¨π¨π¨π¨π¨π¨π¨π¨π¨π¨π¨π¨π¨π¨π¨π¨π¨π¨π¨π¨π¨π¨π¨π¨π¨π¨π¨π¨π¨π¨π¨π¨π¨π¨π¨π¨π¨π¨
+# Copyright 2025 The PaddlePaddle Team and The HuggingFace Inc. team. All rights reserved.
+#
+# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
+# and OPT implementations in this library. It has been modified from its
+# original forms to accommodate minor architectural differences compared
+# to GPT-NeoX and OPT used by the Meta AI team that trained the model.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import inspect
+from typing import Optional
+
+from ...configuration_utils import PreTrainedConfig
+from ...modeling_rope_utils import RopeParameters
+
+
+class PaddleOCRVisionConfig(PreTrainedConfig):
+ r"""
+ This is the configuration class to store the configuration of a [`PaddleOCRVisionModel`]. It is used to instantiate a
+ PaddleOCRVL vision encoder according to the specified arguments, defining the model architecture. Instantiating a
+ configuration with the defaults will yield a similar configuration to that of the vision encoder of the PaddleOCRVL
+ [PaddlePaddle/PaddleOCRVL](https://huggingface.co/PaddlePaddle/PaddleOCR-VL) architecture.
+
+ Configuration objects inherit from [`PreTrainedConfig`] and can be used to control the model outputs. Read the
+ documentation from [`PreTrainedConfig`] for more information.
+
+ Args:
+ hidden_size (`int`, *optional*, defaults to 1152):
+ Dimensionality of the encoder layers and the pooler layer.
+ intermediate_size (`int`, *optional*, defaults to 4304):
+ Dimensionality of the "intermediate" (i.e., feed-forward) layer in the Transformer encoder.
+ num_hidden_layers (`int`, *optional*, defaults to 27):
+ Number of hidden layers in the Transformer encoder.
+ num_attention_heads (`int`, *optional*, defaults to 16):
+ Number of attention heads for each attention layer in the Transformer encoder.
+ num_channels (`int`, *optional*, defaults to 3):
+ Number of channels in the input images.
+ image_size (`int`, *optional*, defaults to 384):
+ The size (resolution) of each image.
+ patch_size (`int`, *optional*, defaults to 14):
+ The size (resolution) of each patch.
+ hidden_act (`str` or `function`, *optional*, defaults to `"gelu_pytorch_tanh"`):
+ The non-linear activation function (function or string) in the encoder and pooler. If string, `"gelu"`,
+ `"relu"`, `"selu"` and `"gelu_new"` `"quick_gelu"` are supported.
+ layer_norm_eps (`float`, *optional*, defaults to 1e-06):
+ The epsilon used by the layer normalization layers.
+ attention_dropout (`float`, *optional*, defaults to 0.0):
+ The dropout ratio for the attention probabilities.
+ spatial_merge_size (`int`, *optional*, defaults to 2):
+ The size used for merging spatial dimensions.
+
+ Example:
+
+ ```python
+ >>> from transformers import PaddleOCRVisionConfig, PaddleOCRVisionModel
+
+ >>> # Initializing a PaddleOCRVisionConfig with PaddlePaddle/PaddleOCR-VL style configuration
+ >>> configuration = PaddleOCRVisionConfig()
+
+ >>> # Initializing a PaddleOCRVisionModel (with random weights) from the PaddlePaddle/PaddleOCR-VL style configuration
+ >>> model = PaddleOCRVisionModel(configuration)
+
+ >>> # Accessing the model configuration
+ >>> configuration = model.config
+ ```
+ """
+
+ model_type = "paddleocr_vl_vision"
+ base_config_key = "vision_config"
+
+ def __init__(
+ self,
+ hidden_size=1152,
+ intermediate_size=4304,
+ num_hidden_layers=27,
+ num_attention_heads=16,
+ num_channels=3,
+ image_size=384,
+ patch_size=14,
+ hidden_act="gelu_pytorch_tanh",
+ layer_norm_eps=1e-6,
+ attention_dropout=0.0,
+ spatial_merge_size=2,
+ **kwargs,
+ ):
+ super().__init__(**kwargs)
+
+ self.hidden_size = hidden_size
+ self.intermediate_size = intermediate_size
+ self.num_hidden_layers = num_hidden_layers
+ self.num_attention_heads = num_attention_heads
+ self.num_channels = num_channels
+ self.patch_size = patch_size
+ self.image_size = image_size
+ self.attention_dropout = attention_dropout
+ self.layer_norm_eps = layer_norm_eps
+ self.hidden_act = hidden_act
+ self.spatial_merge_size = spatial_merge_size
+
+
+class PaddleOCRTextConfig(PreTrainedConfig):
+ r"""
+ This is the configuration class to store the configuration of a [`PaddleOCRTextModel`]. It is used to instantiate an Ernie 4.5
+ model according to the specified arguments, defining the model architecture. Instantiating a configuration with the
+ defaults will yield a similar configuration to that of the Ernie 4.5 0.3B.
+ e.g. [baidu/ERNIE-4.5-0.3B-PT](https://huggingface.co/baidu/ERNIE-4.5-0.3B-PT)
+
+ Configuration objects inherit from [`PreTrainedConfig`] and can be used to control the model outputs. Read the
+ documentation from [`PreTrainedConfig`] for more information.
+
+
+ Args:
+ vocab_size (`int`, *optional*, defaults to 103424):
+ Vocabulary size of the Ernie 4.5 model. Defines the number of different tokens that can be represented by the
+ `inputs_ids` passed when calling [`PaddleOCRTextModel`]
+ hidden_size (`int`, *optional*, defaults to 1024):
+ Dimension of the hidden representations.
+ intermediate_size (`int`, *optional*, defaults to 3072):
+ Dimension of the MLP representations.
+ num_hidden_layers (`int`, *optional*, defaults to 18):
+ Number of hidden layers in the Transformer decoder.
+ num_attention_heads (`int`, *optional*, defaults to 16):
+ Number of attention heads for each attention layer in the Transformer decoder.
+ num_key_value_heads (`int`, *optional*, defaults to 2):
+ This is the number of key_value heads that should be used to implement Grouped Query Attention. If
+ `num_key_value_heads=num_attention_heads`, the model will use Multi Head Attention (MHA), if
+ `num_key_value_heads=1` the model will use Multi Query Attention (MQA) otherwise GQA is used. When
+ converting a multi-head checkpoint to a GQA checkpoint, each group key and value head should be constructed
+ by meanpooling all the original heads within that group. For more details, check out [this
+ paper](https://huggingface.co/papers/2305.13245). If it is not specified, will default to
+ `num_attention_heads`.
+ hidden_act (`str` or `function`, *optional*, defaults to `"silu"`):
+ The non-linear activation function (function or string) in the decoder.
+ max_position_embeddings (`int`, *optional*, defaults to 131072):
+ The maximum sequence length that this model might ever be used with.
+ initializer_range (`float`, *optional*, defaults to 0.02):
+ The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
+ rms_norm_eps (`float`, *optional*, defaults to 1e-05):
+ The epsilon used by the rms normalization layers.
+ use_cache (`bool`, *optional*, defaults to `True`):
+ Whether or not the model should return the last key/values attentions.
+ pad_token_id (`int`, *optional*, defaults to 0):
+ Padding token id.
+ bos_token_id (`int`, *optional*, defaults to 1):
+ Beginning of stream token id.
+ eos_token_id (`int`, *optional*, defaults to 2):
+ End of stream token id.
+ tie_word_embeddings (`bool`, *optional*, defaults to `True`):
+ Whether to tie weight embeddings
+ rope_parameters (`RopeParameters`, *optional*):
+ Dictionary containing the configuration parameters for the RoPE embeddings. The dictionary should contain
+ a value for `rope_theta` and optionally parameters used for scaling in case you want to use RoPE
+ with longer `max_position_embeddings`.
+ use_bias (`bool`, *optional*, defaults to `False`):
+ Whether to use a bias in any of the projections including mlp and attention for example.
+ head_dim (`int`, *optional*, defaults to 128):
+ The attention head dimension. If None, it will default to hidden_size // num_attention_heads
+
+ ```python
+ >>> from transformers import PaddleOCRTextModel, PaddleOCRTextConfig
+
+ >>> # Initializing a PaddleOCRText 0.3B style configuration
+ >>> configuration = PaddleOCRTextConfig()
+
+ >>> # Initializing a model from the 0.3B style configuration
+ >>> model = PaddleOCRTextModel(configuration)
+
+ >>> # Accessing the model configuration
+ >>> configuration = model.config
+ ```"""
+
+ model_type = "paddleocr_vl_text"
+ keys_to_ignore_at_inference = ["past_key_values"]
+ default_theta = 500000.0
+ # Default tensor parallel plan for base model `PaddleOCRTextModel`
+ base_model_tp_plan = {
+ "layers.*.self_attn.q_proj": "colwise",
+ "layers.*.self_attn.k_proj": "colwise",
+ "layers.*.self_attn.v_proj": "colwise",
+ "layers.*.self_attn.o_proj": "rowwise",
+ "layers.*.mlp.gate_proj": "colwise",
+ "layers.*.mlp.up_proj": "colwise",
+ "layers.*.mlp.down_proj": "rowwise",
+ }
+ base_model_pp_plan = {
+ "embed_tokens": (["input_ids"], ["inputs_embeds"]),
+ "layers": (["hidden_states", "attention_mask"], ["hidden_states"]),
+ "norm": (["hidden_states"], ["hidden_states"]),
+ }
+
+ def __init__(
+ self,
+ vocab_size: Optional[int] = 103424,
+ hidden_size: Optional[int] = 1024,
+ intermediate_size: Optional[int] = 3072,
+ num_hidden_layers: Optional[int] = 18,
+ num_attention_heads: Optional[int] = 16,
+ num_key_value_heads: Optional[int] = 2,
+ hidden_act: Optional[str] = "silu",
+ max_position_embeddings: Optional[int] = 131072,
+ initializer_range: Optional[float] = 0.02,
+ rms_norm_eps: Optional[int] = 1e-05,
+ use_cache: Optional[int] = True,
+ pad_token_id: Optional[int] = 0,
+ bos_token_id: Optional[int] = 1,
+ eos_token_id: Optional[int] = 2,
+ tie_word_embeddings: Optional[bool] = True,
+ rope_parameters: Optional[RopeParameters | dict[str, RopeParameters]] = None,
+ use_bias: Optional[bool] = False,
+ head_dim: Optional[int] = 128,
+ **kwargs,
+ ):
+ self.vocab_size = vocab_size
+ self.max_position_embeddings = max_position_embeddings
+ self.hidden_size = hidden_size
+ self.intermediate_size = intermediate_size
+ self.num_hidden_layers = num_hidden_layers
+ self.num_attention_heads = num_attention_heads
+
+ # for backward compatibility
+ if num_key_value_heads is None:
+ num_key_value_heads = num_attention_heads
+
+ self.num_key_value_heads = num_key_value_heads
+ self.hidden_act = hidden_act
+ self.initializer_range = initializer_range
+ self.rms_norm_eps = rms_norm_eps
+ self.use_cache = use_cache
+ self.use_bias = use_bias
+ self.head_dim = head_dim if head_dim is not None else self.hidden_size // self.num_attention_heads
+ self.rope_parameters = rope_parameters
+
+ super().__init__(
+ pad_token_id=pad_token_id,
+ bos_token_id=bos_token_id,
+ eos_token_id=eos_token_id,
+ tie_word_embeddings=tie_word_embeddings,
+ **kwargs,
+ )
+
+
+class PaddleOCRVLConfig(PreTrainedConfig):
+ r"""
+ This is the configuration class to store the configuration of a [`PaddleOCRVLForConditionalGeneration`]. It is used to instantiate a
+ PaddleOCRVL model according to the specified arguments, defining the model architecture. Instantiating a configuration
+ with the defaults will yield a similar configuration to that of
+ PaddleOCRVL [PaddlePaddle/PaddleOCR-VL](https://huggingface.co/PaddlePaddle/PaddleOCR-VL).
+
+ Configuration objects inherit from [`PreTrainedConfig`] and can be used to control the model outputs. Read the
+ documentation from [`PreTrainedConfig`] for more information.
+
+
+ Args:
+ text_config (`Union[PreTrainedConfig, dict]`, *optional*, defaults to `PaddleOCRTextConfig`):
+ The config object or dictionary of the text backbone.
+ vision_config (`Union[PreTrainedConfig, dict]`, *optional*, defaults to `PaddleOCRVisionConfig`):
+ The config object or dictionary of the vision backbone.
+ image_token_id (`int`, *optional*, defaults to 100295):
+ The image token index to encode the image prompt.
+ video_token_id (`int`, *optional*, defaults to 100296):
+ The video token index to encode the image prompt.
+ vision_start_token_id (`int`, *optional*, defaults to 101305):
+ The token index to denote start of vision input.
+ vision_end_token_id (`int`, *optional*, defaults to 101306):
+ The token index to denote end of vision input.
+
+ ```python
+ >>> from transformers import PaddleOCRVLForConditionalGeneration, PaddleOCRVLConfig
+
+ >>> # Initializing a PaddleOCRVL style configuration
+ >>> configuration = PaddleOCRVLConfig()
+
+ >>> # Initializing a model from the PaddleOCRVL style configuration
+ >>> model = PaddleOCRVLForConditionalGeneration(configuration)
+
+ >>> # Accessing the model configuration
+ >>> configuration = model.config
+ ```"""
+
+ model_type = "paddleocr_vl"
+
+ sub_configs = {"vision_config": PaddleOCRVisionConfig, "text_config": PaddleOCRTextConfig}
+ keys_to_ignore_at_inference = ["past_key_values"]
+
+ def __init__(
+ self,
+ text_config=None,
+ vision_config=None,
+ image_token_id=100295,
+ video_token_id=100296,
+ vision_start_token_id=101305,
+ vision_end_token_id=101306,
+ **kwargs,
+ ):
+ if isinstance(vision_config, dict):
+ self.vision_config = self.sub_configs["vision_config"](**vision_config)
+ elif vision_config is None:
+ self.vision_config = self.sub_configs["vision_config"]()
+
+ if isinstance(text_config, dict):
+ self.text_config = self.sub_configs["text_config"](**text_config)
+ elif text_config is None:
+ # Hub configs are saved as flat dicts so we pop some of kwargs to init `TextConfig`
+ text_params = inspect.signature(self.sub_configs["text_config"].__init__).parameters.keys()
+ text_params = list(text_params) + ["rope_scaling", "rope_theta"]
+ text_config = {key: kwargs.pop(key) for key in text_params if key in kwargs}
+ text_config["dtype"] = kwargs.get("torch_dtype", kwargs.get("dtype")) # don't pop the dtype
+ self.text_config = self.sub_configs["text_config"](**text_config)
+
+ self.image_token_id = image_token_id
+ self.video_token_id = video_token_id
+ self.vision_start_token_id = vision_start_token_id
+ self.vision_end_token_id = vision_end_token_id
+
+ # FIXME: arthur/cyril - tying has to be used from the text config
+ kwargs["tie_word_embeddings"] = self.text_config.tie_word_embeddings
+ super().__init__(**kwargs)
+
+
+__all__ = ["PaddleOCRVLConfig", "PaddleOCRVisionConfig", "PaddleOCRTextConfig"]
diff --git a/src/transformers/models/paddleocr_vl/image_processing_paddleocr_vl.py b/src/transformers/models/paddleocr_vl/image_processing_paddleocr_vl.py
new file mode 100644
index 000000000000..101668ea0335
--- /dev/null
+++ b/src/transformers/models/paddleocr_vl/image_processing_paddleocr_vl.py
@@ -0,0 +1,503 @@
+# π¨π¨π¨π¨π¨π¨π¨π¨π¨π¨π¨π¨π¨π¨π¨π¨π¨π¨π¨π¨π¨π¨π¨π¨π¨π¨π¨π¨π¨π¨π¨π¨π¨π¨π¨π¨π¨π¨π¨π¨π¨π¨π¨π¨π¨π¨π¨π¨
+# This file was automatically generated from src/transformers/models/paddleocr_vl/modular_paddleocr_vl.py.
+# Do NOT edit this file manually as any edits will be overwritten by the generation of
+# the file from the modular. If any change should be done, please apply the change to the
+# modular_paddleocr_vl.py file directly. One of our CI enforces this.
+# π¨π¨π¨π¨π¨π¨π¨π¨π¨π¨π¨π¨π¨π¨π¨π¨π¨π¨π¨π¨π¨π¨π¨π¨π¨π¨π¨π¨π¨π¨π¨π¨π¨π¨π¨π¨π¨π¨π¨π¨π¨π¨π¨π¨π¨π¨π¨π¨
+# Copyright 2025 The PaddlePaddle Team and The HuggingFace Inc. team. All rights reserved.
+#
+# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
+# and OPT implementations in this library. It has been modified from its
+# original forms to accommodate minor architectural differences compared
+# to GPT-NeoX and OPT used by the Meta AI team that trained the model.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import math
+from typing import Optional, Union
+
+import numpy as np
+
+from ...image_processing_utils import BaseImageProcessor, BatchFeature
+from ...image_transforms import convert_to_rgb, resize, to_channel_dimension_format
+from ...image_utils import (
+ OPENAI_CLIP_MEAN,
+ OPENAI_CLIP_STD,
+ ChannelDimension,
+ ImageInput,
+ PILImageResampling,
+ get_image_size,
+ infer_channel_dimension_format,
+ is_scaled_image,
+ make_flat_list_of_images,
+ make_list_of_images,
+ to_numpy_array,
+ valid_images,
+ validate_preprocess_arguments,
+)
+from ...processing_utils import ImagesKwargs
+from ...utils import TensorType, logging
+
+
+logger = logging.get_logger(__name__)
+
+
+class PaddleOCRVLImageProcessorKwargs(ImagesKwargs, total=False):
+ r"""
+ min_pixels (`int`, *optional*, defaults to `56 * 56`):
+ The min pixels of the image to resize the image.
+ max_pixels (`int`, *optional*, defaults to `28 * 28 * 1280`):
+ The max pixels of the image to resize the image.
+ patch_size (`int`, *optional*, defaults to 14):
+ The spatial patch size of the vision encoder.
+ temporal_patch_size (`int`, *optional*, defaults to 2):
+ The temporal patch size of the vision encoder.
+ merge_size (`int`, *optional*, defaults to 2):
+ The merge size of the vision encoder to llm encoder.
+ """
+
+ min_pixels: int
+ max_pixels: int
+ patch_size: int
+ temporal_patch_size: int
+ merge_size: int
+
+
+def smart_resize(
+ height: int,
+ width: int,
+ factor: int = 28,
+ min_pixels: int = 384 * 384,
+ max_pixels: int = 1536 * 1536,
+):
+ if height < factor:
+ width = round((width * factor) / height)
+ height = factor
+
+ if width < factor:
+ height = round((height * factor) / width)
+ width = factor
+
+ if max(height, width) / min(height, width) > 200:
+ raise ValueError(
+ f"absolute aspect ratio must be smaller than 200, got {max(height, width) / min(height, width)}"
+ )
+ h_bar = round(height / factor) * factor
+ w_bar = round(width / factor) * factor
+ if h_bar * w_bar > max_pixels:
+ beta = math.sqrt((height * width) / max_pixels)
+ h_bar = math.floor(height / beta / factor) * factor
+ w_bar = math.floor(width / beta / factor) * factor
+ elif h_bar * w_bar < min_pixels:
+ beta = math.sqrt(min_pixels / (height * width))
+ h_bar = math.ceil(height * beta / factor) * factor
+ w_bar = math.ceil(width * beta / factor) * factor
+ return h_bar, w_bar
+
+
+class PaddleOCRVLImageProcessor(BaseImageProcessor):
+ r"""
+ Constructs a PaddleOCRVL image processor that dynamically resizes images based on the original images.
+
+ Args:
+ do_resize (`bool`, *optional*, defaults to `True`):
+ Whether to resize the image's (height, width) dimensions.
+ size (`dict[str, int]`, *optional*):
+ Size of the image after resizing. `shortest_edge` and `longest_edge` keys must be present.
+ resample (`PILImageResampling`, *optional*, defaults to `Resampling.BICUBIC`):
+ Resampling filter to use when resizing the image.
+ do_rescale (`bool`, *optional*, defaults to `True`):
+ Whether to rescale the image by the specified scale `rescale_factor`.
+ rescale_factor (`int` or `float`, *optional*, defaults to `1/255`):
+ Scale factor to use if rescaling the image.
+ do_normalize (`bool`, *optional*, defaults to `True`):
+ Whether to normalize the image.
+ image_mean (`float` or `list[float]`, *optional*):
+ Mean to use if normalizing the image. This is a float or list of floats for each channel in the image.
+ image_std (`float` or `list[float]`, *optional*):
+ Standard deviation to use if normalizing the image. This is a float or list of floats for each channel in the image.
+ do_convert_rgb (`bool`, *optional*, defaults to `True`):
+ Whether to convert the image to RGB.
+ min_pixels (`int`, *optional*, defaults to `384 * 384`):
+ The min pixels of the image to resize the image.
+ max_pixels (`int`, *optional*, defaults to `1536 * 1536`):
+ The max pixels of the image to resize the image.
+ patch_size (`int`, *optional*, defaults to 14):
+ The spatial patch size of the vision encoder.
+ temporal_patch_size (`int`, *optional*, defaults to 1):
+ The temporal patch size of the vision encoder.
+ merge_size (`int`, *optional*, defaults to 2):
+ The merge size of the vision encoder to llm encoder.
+ """
+
+ model_input_names = [
+ "pixel_values",
+ "image_grid_thw",
+ ]
+ valid_kwargs = PaddleOCRVLImageProcessorKwargs
+
+ def __init__(
+ self,
+ do_resize: bool = True,
+ size: Optional[dict[str, int]] = None,
+ resample: PILImageResampling = PILImageResampling.BICUBIC,
+ do_rescale: bool = True,
+ rescale_factor: Union[int, float] = 1 / 255,
+ do_normalize: bool = True,
+ image_mean: Optional[Union[float, list[float]]] = None,
+ image_std: Optional[Union[float, list[float]]] = None,
+ do_convert_rgb: bool = True,
+ min_pixels: int = 384 * 384,
+ max_pixels: int = 1536 * 1536,
+ patch_size: int = 14,
+ temporal_patch_size: int = 1,
+ merge_size: int = 2,
+ **kwargs,
+ ) -> None:
+ super().__init__(**kwargs)
+ if size is not None and ("shortest_edge" not in size or "longest_edge" not in size):
+ raise ValueError("size must contain 'shortest_edge' and 'longest_edge' keys.")
+ else:
+ size = {"shortest_edge": 56 * 56, "longest_edge": 28 * 28 * 1280}
+ # backward compatibility: override size with min_pixels and max_pixels if they are provided
+ if min_pixels is not None:
+ size["shortest_edge"] = min_pixels
+ if max_pixels is not None:
+ size["longest_edge"] = max_pixels
+ self.min_pixels = size["shortest_edge"]
+ self.max_pixels = size["longest_edge"]
+ self.size = size
+
+ self.do_resize = do_resize
+ self.resample = resample
+ self.do_rescale = do_rescale
+ self.rescale_factor = rescale_factor
+ self.do_normalize = do_normalize
+ self.image_mean = image_mean if image_mean is not None else OPENAI_CLIP_MEAN
+ self.image_std = image_std if image_std is not None else OPENAI_CLIP_STD
+
+ self.patch_size = patch_size
+ self.temporal_patch_size = temporal_patch_size
+ self.merge_size = merge_size
+ self.do_convert_rgb = do_convert_rgb
+
+ def _preprocess(
+ self,
+ images: ImageInput,
+ do_resize: Optional[bool] = None,
+ size: Optional[dict[str, int]] = None,
+ resample: PILImageResampling = None,
+ do_rescale: Optional[bool] = None,
+ rescale_factor: Optional[float] = None,
+ do_normalize: Optional[bool] = None,
+ image_mean: Optional[Union[float, list[float]]] = None,
+ image_std: Optional[Union[float, list[float]]] = None,
+ patch_size: Optional[int] = None,
+ temporal_patch_size: Optional[int] = None,
+ merge_size: Optional[int] = None,
+ do_convert_rgb: Optional[bool] = None,
+ data_format: Optional[ChannelDimension] = ChannelDimension.FIRST,
+ input_data_format: Optional[Union[str, ChannelDimension]] = None,
+ ):
+ """
+ Preprocess an image or batch of images. Copy of the `preprocess` method from `CLIPImageProcessor`.
+ Args:
+ images (`ImageInput`):
+ Image or batch of images to preprocess. Expects pixel values ranging from 0 to 255. If pixel values range from 0 to 1, set `do_rescale=False`.
+ do_resize (`bool`, *optional*, defaults to `self.do_resize`):
+ Whether to resize the image.
+ size (`Dict[str, int]`, *optional*, defaults to `self.size`):
+ Size of the image after resizing. `shortest_edge` and `longest_edge` keys must be present.
+ resample (`PILImageResampling`, *optional*, defaults to `self.resample`):
+ Resampling filter to use if resizing the image. This can be one of the `PILImageResampling` enums.
+ do_rescale (`bool`, *optional*, defaults to `self.do_rescale`):
+ Whether to rescale the image.
+ rescale_factor (`float`, *optional*, defaults to `self.rescale_factor`):
+ Scale factor to use if rescaling the image.
+ do_normalize (`bool`, *optional*, defaults to `self.do_normalize`):
+ Whether to normalize the image.
+ image_mean (`float` or `List[float]`, *optional*, defaults to `self.image_mean`):
+ Mean to use if normalizing the image. Can be a float or a list of floats corresponding to the number of channels in the image.
+ image_std (`float` or `List[float]`, *optional*, defaults to `self.image_std`):
+ Standard deviation to use if normalizing the image. Can be a float or a list of floats corresponding to the number of channels in the image.
+ patch_size (`int`, *optional*, defaults to `self.patch_size`):
+ The spatial patch size of the vision encoder.
+ temporal_patch_size (`int`, *optional*, defaults to `self.temporal_patch_size`):
+ The temporal patch size of the vision encoder.
+ merge_size (`int`, *optional*, defaults to `self.merge_size`):
+ The merge size of the vision encoder to llm encoder.
+ do_convert_rgb (`bool`, *optional*, defaults to `self.do_convert_rgb`):
+ Whether to convert the image to RGB.
+ data_format (`ChannelDimension`, *optional*, defaults to `ChannelDimension.FIRST`):
+ The channel dimension format for the output image. Can be one of:
+ - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
+ - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
+ - Unset: Use the channel dimension format of the input image.
+ input_data_format (`ChannelDimension` or `str`, *optional*):
+ The channel dimension format for the input image. Can be one of:
+ - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
+ - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
+ - `"none"` or `ChannelDimension.NONE`: image in (height, width) format. - `"none"` or `ChannelDimension.NONE`: image in (height, width) format.
+ """
+ images = make_list_of_images(images)
+ images = self.fetch_images(images)
+
+ if do_convert_rgb:
+ images = [convert_to_rgb(image) for image in images]
+
+ # All transformations expect numpy arrays.
+ images = [to_numpy_array(image) for image in images]
+
+ if is_scaled_image(images[0]) and do_rescale:
+ logger.warning_once(
+ "It looks like you are trying to rescale already rescaled images. If the input"
+ " images have pixel values between 0 and 1, set `do_rescale=False` to avoid rescaling them again."
+ )
+ if input_data_format is None:
+ # We assume that all images have the same channel dimension format.
+ input_data_format = infer_channel_dimension_format(images[0])
+
+ height, width = get_image_size(images[0], channel_dim=input_data_format)
+ resized_height, resized_width = height, width
+ processed_images = []
+
+ for image in images:
+ if do_resize:
+ resized_height, resized_width = smart_resize(
+ height,
+ width,
+ factor=patch_size * merge_size,
+ min_pixels=size["shortest_edge"],
+ max_pixels=size["longest_edge"],
+ )
+ image = resize(
+ image,
+ size=(resized_height, resized_width),
+ resample=resample,
+ input_data_format=input_data_format,
+ )
+
+ if do_rescale:
+ image = self.rescale(image, scale=rescale_factor, input_data_format=input_data_format)
+
+ if do_normalize:
+ image = self.normalize(
+ image=image,
+ mean=image_mean,
+ std=image_std,
+ input_data_format=input_data_format,
+ )
+ image = to_channel_dimension_format(image, data_format, input_channel_dim=input_data_format)
+ processed_images.append(image)
+
+ patches = np.array(processed_images)
+ if data_format == ChannelDimension.LAST:
+ patches = patches.transpose(0, 3, 1, 2)
+ if patches.shape[0] == 1:
+ patches = np.tile(patches, (temporal_patch_size, 1, 1, 1))
+
+ channel = patches.shape[1]
+ grid_t = patches.shape[0] // temporal_patch_size
+ grid_h, grid_w = (
+ resized_height // patch_size,
+ resized_width // patch_size,
+ )
+ patches = patches.reshape(
+ grid_t,
+ temporal_patch_size,
+ channel,
+ grid_h,
+ patch_size,
+ grid_w,
+ patch_size,
+ )
+ patches = patches.transpose(0, 3, 5, 2, 1, 4, 6)
+ if temporal_patch_size != 1:
+ raise ValueError(f"temporal_patch_size must be 1!, but got {temporal_patch_size}!")
+ flatten_patches = patches.reshape(grid_t * grid_h * grid_w, channel, patch_size, patch_size)
+ return flatten_patches, (grid_t, grid_h, grid_w)
+
+ def preprocess(
+ self,
+ images: ImageInput,
+ do_resize: Optional[bool] = None,
+ size: Optional[dict[str, int]] = None,
+ min_pixels: Optional[int] = None,
+ max_pixels: Optional[int] = None,
+ resample: Optional[PILImageResampling] = None,
+ do_rescale: Optional[bool] = None,
+ rescale_factor: Optional[float] = None,
+ do_normalize: Optional[bool] = None,
+ image_mean: Optional[Union[float, list[float]]] = None,
+ image_std: Optional[Union[float, list[float]]] = None,
+ patch_size: Optional[int] = None,
+ temporal_patch_size: Optional[int] = None,
+ merge_size: Optional[int] = None,
+ do_convert_rgb: Optional[bool] = None,
+ return_tensors: Optional[Union[str, TensorType]] = None,
+ data_format: Optional[ChannelDimension] = ChannelDimension.FIRST,
+ input_data_format: Optional[Union[str, ChannelDimension]] = None,
+ ):
+ """
+ Args:
+ images (`ImageInput`):
+ Image to preprocess. Expects a single or batch of images with pixel values ranging from 0 to 255. If
+ passing in images with pixel values between 0 and 1, set `do_rescale=False`.
+ do_resize (`bool`, *optional*, defaults to `self.do_resize`):
+ Whether to resize the image.
+ size (`dict[str, int]`, *optional*, defaults to `self.size`):
+ Size of the image after resizing. Shortest edge of the image is resized to size["shortest_edge"], with
+ the longest edge resized to keep the input aspect ratio.
+ resample (`int`, *optional*, defaults to `self.resample`):
+ Resampling filter to use if resizing the image. This can be one of the enum `PILImageResampling`. Only
+ has an effect if `do_resize` is set to `True`.
+ do_rescale (`bool`, *optional*, defaults to `self.do_rescale`):
+ Whether to rescale the image.
+ rescale_factor (`float`, *optional*, defaults to `self.rescale_factor`):
+ Rescale factor to rescale the image by if `do_rescale` is set to `True`.
+ do_normalize (`bool`, *optional*, defaults to `self.do_normalize`):
+ Whether to normalize the image.
+ image_mean (`float` or `list[float]`, *optional*, defaults to `self.image_mean`):
+ Image mean to use for normalization. Only has an effect if `do_normalize` is set to `True`.
+ image_std (`float` or `list[float]`, *optional*, defaults to `self.image_std`):
+ Image standard deviation to use for normalization. Only has an effect if `do_normalize` is set to
+ `True`.
+ min_pixels (`int`, *optional*, defaults to `self.min_pixels`):
+ The min pixels of the image to resize the image.
+ max_pixels (`int`, *optional*, defaults to `self.max_pixels`):
+ The max pixels of the image to resize the image.
+ patch_size (`int`, *optional*, defaults to `self.patch_size`):
+ The spatial patch size of the vision encoder.
+ temporal_patch_size (`int`, *optional*, defaults to `self.temporal_patch_size`):
+ The temporal patch size of the vision encoder.
+ merge_size (`int`, *optional*, defaults to `self.merge_size`):
+ The merge size of the vision encoder to llm encoder.
+ do_convert_rgb (`bool`, *optional*, defaults to `self.do_convert_rgb`):
+ Whether to convert the image to RGB.
+ return_tensors (`str` or `TensorType`, *optional*):
+ The type of tensors to return. Can be one of:
+ - Unset: Return a list of `np.ndarray`.
+ - `TensorType.PYTORCH` or `'pt'`: Return a batch of type `torch.Tensor`.
+ - `TensorType.NUMPY` or `'np'`: Return a batch of type `np.ndarray`.
+ data_format (`ChannelDimension` or `str`, *optional*, defaults to `ChannelDimension.FIRST`):
+ The channel dimension format for the output image. Can be one of:
+ - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
+ - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
+ - Unset: Use the channel dimension format of the input image.
+ input_data_format (`ChannelDimension` or `str`, *optional*):
+ The channel dimension format for the input image. If unset, the channel dimension format is inferred
+ from the input image. Can be one of:
+ - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
+ - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
+ - `"none"` or `ChannelDimension.NONE`: image in (height, width) format.
+
+ """
+ min_pixels = min_pixels if min_pixels is not None else self.min_pixels
+ max_pixels = max_pixels if max_pixels is not None else self.max_pixels
+
+ if size is not None:
+ if "shortest_edge" not in size or "longest_edge" not in size:
+ raise ValueError("size must contain 'shortest_edge' and 'longest_edge' keys.")
+ min_pixels = size["shortest_edge"]
+ elif min_pixels is not None and max_pixels is not None:
+ # backward compatibility: override size with min_pixels and max_pixels if they are provided
+ size = {"shortest_edge": min_pixels, "longest_edge": max_pixels}
+ else:
+ size = {**self.size}
+
+ do_resize = do_resize if do_resize is not None else self.do_resize
+
+ resample = resample if resample is not None else self.resample
+ do_rescale = do_rescale if do_rescale is not None else self.do_rescale
+ rescale_factor = rescale_factor if rescale_factor is not None else self.rescale_factor
+ do_normalize = do_normalize if do_normalize is not None else self.do_normalize
+ image_mean = image_mean if image_mean is not None else self.image_mean
+ image_std = image_std if image_std is not None else self.image_std
+ patch_size = patch_size if patch_size is not None else self.patch_size
+ temporal_patch_size = temporal_patch_size if temporal_patch_size is not None else self.temporal_patch_size
+ merge_size = merge_size if merge_size is not None else self.merge_size
+ do_convert_rgb = do_convert_rgb if do_convert_rgb is not None else self.do_convert_rgb
+
+ if images is not None:
+ images = self.fetch_images(images)
+ images = make_flat_list_of_images(images)
+
+ if images is not None and not valid_images(images):
+ raise ValueError("Invalid image type. Must be of type PIL.Image.Image, numpy.ndarray, or torch.Tensor")
+
+ validate_preprocess_arguments(
+ rescale_factor=rescale_factor,
+ do_normalize=do_normalize,
+ image_mean=image_mean,
+ image_std=image_std,
+ do_resize=do_resize,
+ size=size,
+ resample=resample,
+ )
+
+ data = {}
+ pixel_values, vision_grid_thws = [], []
+ for image in images:
+ patches, image_grid_thw = self._preprocess(
+ image,
+ do_resize=do_resize,
+ size=size,
+ resample=resample,
+ do_rescale=do_rescale,
+ rescale_factor=rescale_factor,
+ do_normalize=do_normalize,
+ image_mean=image_mean,
+ image_std=image_std,
+ patch_size=patch_size,
+ temporal_patch_size=temporal_patch_size,
+ merge_size=merge_size,
+ data_format=data_format,
+ do_convert_rgb=do_convert_rgb,
+ input_data_format=input_data_format,
+ )
+ pixel_values.extend(patches)
+ vision_grid_thws.append(image_grid_thw)
+ pixel_values = np.array(pixel_values)
+ vision_grid_thws = np.array(vision_grid_thws)
+ data.update({"pixel_values": pixel_values, "image_grid_thw": vision_grid_thws})
+
+ return BatchFeature(data=data, tensor_type=return_tensors)
+
+ def get_number_of_image_patches(self, height: int, width: int, images_kwargs=None):
+ """
+ A utility that returns number of image patches for a given image size.
+
+ Args:
+ height (`int`):
+ Height of the input image.
+ width (`int`):
+ Width of the input image.
+ images_kwargs (`dict`, *optional*)
+ Any kwargs to override defaults of the image processor.
+ Returns:
+ `int`: Number of image patches per image.
+ """
+ min_pixels = images_kwargs["min_pixels"] if "min_pixels" in images_kwargs else self.size["shortest_edge"]
+ max_pixels = images_kwargs["max_pixels"] if "max_pixels" in images_kwargs else self.size["longest_edge"]
+ patch_size = images_kwargs.get("patch_size", self.patch_size)
+ merge_size = images_kwargs.get("merge_size", self.merge_size)
+
+ factor = patch_size * merge_size
+ resized_height, resized_width = smart_resize(
+ height, width, factor, min_pixels=min_pixels, max_pixels=max_pixels
+ )
+ grid_h, grid_w = resized_height // patch_size, resized_width // patch_size
+ return grid_h * grid_w
+
+
+__all__ = ["PaddleOCRVLImageProcessor"]
diff --git a/src/transformers/models/paddleocr_vl/image_processing_paddleocr_vl_fast.py b/src/transformers/models/paddleocr_vl/image_processing_paddleocr_vl_fast.py
new file mode 100644
index 000000000000..cf6a65889b97
--- /dev/null
+++ b/src/transformers/models/paddleocr_vl/image_processing_paddleocr_vl_fast.py
@@ -0,0 +1,209 @@
+# π¨π¨π¨π¨π¨π¨π¨π¨π¨π¨π¨π¨π¨π¨π¨π¨π¨π¨π¨π¨π¨π¨π¨π¨π¨π¨π¨π¨π¨π¨π¨π¨π¨π¨π¨π¨π¨π¨π¨π¨π¨π¨π¨π¨π¨π¨π¨π¨
+# This file was automatically generated from src/transformers/models/paddleocr_vl/modular_paddleocr_vl.py.
+# Do NOT edit this file manually as any edits will be overwritten by the generation of
+# the file from the modular. If any change should be done, please apply the change to the
+# modular_paddleocr_vl.py file directly. One of our CI enforces this.
+# π¨π¨π¨π¨π¨π¨π¨π¨π¨π¨π¨π¨π¨π¨π¨π¨π¨π¨π¨π¨π¨π¨π¨π¨π¨π¨π¨π¨π¨π¨π¨π¨π¨π¨π¨π¨π¨π¨π¨π¨π¨π¨π¨π¨π¨π¨π¨π¨
+# Copyright 2025 The PaddlePaddle Team and The HuggingFace Inc. team. All rights reserved.
+#
+# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
+# and OPT implementations in this library. It has been modified from its
+# original forms to accommodate minor architectural differences compared
+# to GPT-NeoX and OPT used by the Meta AI team that trained the model.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import math
+from typing import Optional, Union
+
+import torch
+import torch.nn.functional as F
+
+from ...image_processing_utils import BatchFeature
+from ...image_processing_utils_fast import BaseImageProcessorFast, group_images_by_shape, reorder_images
+from ...image_utils import OPENAI_CLIP_MEAN, OPENAI_CLIP_STD, PILImageResampling, SizeDict
+from ...utils import TensorType
+
+
+def smart_resize(
+ height: int,
+ width: int,
+ factor: int = 28,
+ min_pixels: int = 384 * 384,
+ max_pixels: int = 1536 * 1536,
+):
+ if height < factor:
+ width = round((width * factor) / height)
+ height = factor
+
+ if width < factor:
+ height = round((height * factor) / width)
+ width = factor
+
+ if max(height, width) / min(height, width) > 200:
+ raise ValueError(
+ f"absolute aspect ratio must be smaller than 200, got {max(height, width) / min(height, width)}"
+ )
+ h_bar = round(height / factor) * factor
+ w_bar = round(width / factor) * factor
+ if h_bar * w_bar > max_pixels:
+ beta = math.sqrt((height * width) / max_pixels)
+ h_bar = math.floor(height / beta / factor) * factor
+ w_bar = math.floor(width / beta / factor) * factor
+ elif h_bar * w_bar < min_pixels:
+ beta = math.sqrt(min_pixels / (height * width))
+ h_bar = math.ceil(height * beta / factor) * factor
+ w_bar = math.ceil(width * beta / factor) * factor
+ return h_bar, w_bar
+
+
+class PaddleOCRVLImageProcessorFast(BaseImageProcessorFast):
+ def __init__(
+ self,
+ do_resize: bool = True,
+ size: Optional[dict[str, int]] = None,
+ resample: PILImageResampling = PILImageResampling.BICUBIC,
+ do_rescale: bool = True,
+ rescale_factor: Union[int, float] = 1 / 255,
+ do_normalize: bool = True,
+ image_mean: Optional[Union[float, list[float]]] = None,
+ image_std: Optional[Union[float, list[float]]] = None,
+ do_convert_rgb: bool = True,
+ min_pixels: int = 384 * 384,
+ max_pixels: int = 1536 * 1536,
+ patch_size: int = 14,
+ temporal_patch_size: int = 1,
+ merge_size: int = 2,
+ **kwargs,
+ ) -> None:
+ super().__init__(**kwargs)
+ if size is not None and ("shortest_edge" not in size or "longest_edge" not in size):
+ raise ValueError("size must contain 'shortest_edge' and 'longest_edge' keys.")
+ else:
+ size = {"shortest_edge": 384 * 384, "longest_edge": 1536 * 1536}
+ # backward compatibility: override size with min_pixels and max_pixels if they are provided
+ if min_pixels is not None:
+ size["shortest_edge"] = min_pixels
+ if max_pixels is not None:
+ size["longest_edge"] = max_pixels
+ self.min_pixels = size["shortest_edge"]
+ self.max_pixels = size["longest_edge"]
+ self.size = size
+
+ self.do_resize = do_resize
+ self.resample = resample
+ self.do_rescale = do_rescale
+ self.rescale_factor = rescale_factor
+ self.do_normalize = do_normalize
+ self.image_mean = image_mean if image_mean is not None else OPENAI_CLIP_MEAN
+ self.image_std = image_std if image_std is not None else OPENAI_CLIP_STD
+
+ self.patch_size = patch_size
+ self.temporal_patch_size = temporal_patch_size
+ self.merge_size = merge_size
+ self.do_convert_rgb = do_convert_rgb
+
+ def _preprocess(
+ self,
+ images: list["torch.Tensor"],
+ do_resize: bool,
+ size: SizeDict,
+ interpolation: Optional["F.InterpolationMode"],
+ do_rescale: bool,
+ rescale_factor: float,
+ do_normalize: bool,
+ image_mean: Optional[Union[float, list[float]]],
+ image_std: Optional[Union[float, list[float]]],
+ disable_grouping: Optional[bool],
+ return_tensors: Optional[Union[str, TensorType]],
+ patch_size: Optional[int] = None,
+ temporal_patch_size: Optional[int] = None,
+ merge_size: Optional[int] = None,
+ **kwargs,
+ ):
+ patch_size = patch_size if patch_size is not None else self.patch_size
+ temporal_patch_size = temporal_patch_size if temporal_patch_size is not None else self.temporal_patch_size
+ merge_size = merge_size if merge_size is not None else self.merge_size
+
+ grouped_images, grouped_images_index = group_images_by_shape(images, disable_grouping=disable_grouping)
+ resized_images_grouped = {}
+ for shape, stacked_images in grouped_images.items():
+ height, width = stacked_images.shape[-2:]
+ if do_resize:
+ resized_height, resized_width = smart_resize(
+ height,
+ width,
+ factor=patch_size * merge_size,
+ min_pixels=size["shortest_edge"],
+ max_pixels=size["longest_edge"],
+ )
+ stacked_images = self.resize(
+ image=stacked_images,
+ size=SizeDict(height=resized_height, width=resized_width),
+ interpolation=interpolation,
+ )
+ resized_images_grouped[shape] = stacked_images
+ resized_images = reorder_images(resized_images_grouped, grouped_images_index)
+
+ # Group images by size for further processing
+ # Needed in case do_resize is False, or resize returns images with different sizes
+ grouped_images, grouped_images_index = group_images_by_shape(resized_images, disable_grouping=disable_grouping)
+ processed_images_grouped = {}
+ processed_grids = {}
+ for shape, stacked_images in grouped_images.items():
+ resized_height, resized_width = stacked_images.shape[-2:]
+ # Fused rescale and normalize
+ patches = self.rescale_and_normalize(
+ stacked_images, do_rescale, rescale_factor, do_normalize, image_mean, image_std
+ )
+
+ if patches.ndim == 4:
+ # add a temporal dimension if we have images
+ patches = patches.unsqueeze(1)
+ if patches.shape[1] % temporal_patch_size != 0:
+ repeats = patches[:, -1:].repeat(1, temporal_patch_size - 1, 1, 1, 1)
+ patches = torch.cat([patches, repeats], dim=1)
+
+ batch_size, grid_t, channel = patches.shape[:3]
+ grid_t = grid_t // temporal_patch_size
+ grid_h, grid_w = (
+ resized_height // patch_size,
+ resized_width // patch_size,
+ )
+ patches = patches.view(
+ batch_size,
+ grid_t,
+ temporal_patch_size,
+ channel,
+ grid_h,
+ patch_size,
+ grid_w,
+ patch_size,
+ )
+ patches = patches.permute(0, 1, 4, 6, 3, 2, 5, 7)
+ flatten_patches = patches.reshape(batch_size, grid_t * grid_h * grid_w, channel, patch_size, patch_size)
+
+ processed_images_grouped[shape] = flatten_patches
+ processed_grids[shape] = [[grid_t, grid_h, grid_w]] * batch_size
+
+ processed_images = reorder_images(processed_images_grouped, grouped_images_index)
+ processed_grids = reorder_images(processed_grids, grouped_images_index)
+ pixel_values = torch.cat(processed_images, dim=0)
+ image_grid_thw = torch.tensor(processed_grids)
+
+ return BatchFeature(
+ data={"pixel_values": pixel_values, "image_grid_thw": image_grid_thw}, tensor_type=return_tensors
+ )
+
+
+__all__ = ["PaddleOCRVLImageProcessorFast"]
diff --git a/src/transformers/models/paddleocr_vl/modeling_paddleocr_vl.py b/src/transformers/models/paddleocr_vl/modeling_paddleocr_vl.py
new file mode 100644
index 000000000000..ebd48e8d9c69
--- /dev/null
+++ b/src/transformers/models/paddleocr_vl/modeling_paddleocr_vl.py
@@ -0,0 +1,1668 @@
+# π¨π¨π¨π¨π¨π¨π¨π¨π¨π¨π¨π¨π¨π¨π¨π¨π¨π¨π¨π¨π¨π¨π¨π¨π¨π¨π¨π¨π¨π¨π¨π¨π¨π¨π¨π¨π¨π¨π¨π¨π¨π¨π¨π¨π¨π¨π¨π¨
+# This file was automatically generated from src/transformers/models/paddleocr_vl/modular_paddleocr_vl.py.
+# Do NOT edit this file manually as any edits will be overwritten by the generation of
+# the file from the modular. If any change should be done, please apply the change to the
+# modular_paddleocr_vl.py file directly. One of our CI enforces this.
+# π¨π¨π¨π¨π¨π¨π¨π¨π¨π¨π¨π¨π¨π¨π¨π¨π¨π¨π¨π¨π¨π¨π¨π¨π¨π¨π¨π¨π¨π¨π¨π¨π¨π¨π¨π¨π¨π¨π¨π¨π¨π¨π¨π¨π¨π¨π¨π¨
+# Copyright 2025 The PaddlePaddle Team and The HuggingFace Inc. team. All rights reserved.
+#
+# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
+# and OPT implementations in this library. It has been modified from its
+# original forms to accommodate minor architectural differences compared
+# to GPT-NeoX and OPT used by the Meta AI team that trained the model.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from collections.abc import Callable
+from dataclasses import dataclass
+from typing import Any, Optional, Union
+
+import torch
+from torch import nn
+
+from ...activations import ACT2FN, GELUActivation
+from ...cache_utils import Cache, DynamicCache
+from ...generation import GenerationMixin
+from ...integrations import use_kernel_forward_from_hub
+from ...masking_utils import create_bidirectional_mask, create_causal_mask
+from ...modeling_flash_attention_utils import FlashAttentionKwargs
+from ...modeling_layers import GradientCheckpointingLayer
+from ...modeling_outputs import BaseModelOutput, BaseModelOutputWithPast, BaseModelOutputWithPooling, ModelOutput
+from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS
+from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
+from ...processing_utils import Unpack
+from ...utils import TransformersKwargs, auto_docstring, can_return_tuple, logging, torch_int
+from ...utils.generic import check_model_inputs, maybe_autocast
+from .configuration_paddleocr_vl import PaddleOCRTextConfig, PaddleOCRVisionConfig, PaddleOCRVLConfig
+
+
+logger = logging.get_logger(__name__)
+
+
+class PaddleOCRProjector(nn.Module):
+ def __init__(self, config: PaddleOCRVLConfig):
+ super().__init__()
+ self.merge_kernel_size = (config.vision_config.spatial_merge_size, config.vision_config.spatial_merge_size)
+
+ hidden_size = config.vision_config.hidden_size * self.merge_kernel_size[0] * self.merge_kernel_size[1]
+
+ self.pre_norm = torch.nn.LayerNorm(config.vision_config.hidden_size, eps=1e-05)
+ self.linear_1 = nn.Linear(hidden_size, hidden_size, bias=True)
+ self.act = GELUActivation()
+ self.linear_2 = nn.Linear(hidden_size, config.text_config.hidden_size, bias=True)
+
+ def forward(self, image_features: torch.Tensor, image_grid_thw: torch.Tensor) -> torch.Tensor:
+ image_features_chunks = image_features.split(image_grid_thw.prod(dim=1).tolist(), dim=0)
+ m1, m2 = self.merge_kernel_size
+
+ processed_features = []
+ for image_feature, image_grid in zip(image_features_chunks, image_grid_thw):
+ image_feature = self.pre_norm(image_feature)
+ t, h, w = image_grid
+ d = image_feature.shape[-1]
+ h_block = h // m1
+ w_block = w // m2
+
+ image_feature = image_feature.reshape(t, h_block, m1, w_block, m2, d)
+ image_feature = image_feature.transpose(2, 3)
+ image_feature = image_feature.reshape(t * h_block * w_block, m1 * m2 * d)
+
+ hidden_states = self.linear_1(image_feature)
+ hidden_states = self.act(hidden_states)
+ hidden_states = self.linear_2(hidden_states)
+ processed_features.append(hidden_states)
+
+ return torch.cat(processed_features, dim=0)
+
+
+class PaddleOCRVisionRotaryEmbedding(nn.Module):
+ inv_freq: torch.Tensor # fix linting for `register_buffer`
+
+ def __init__(self, dim: int, theta: float = 10000.0) -> None:
+ super().__init__()
+ inv_freq = 1.0 / (theta ** (torch.arange(0, dim, 2, dtype=torch.float) / dim))
+ self.register_buffer("inv_freq", inv_freq, persistent=False)
+
+ def forward(self, seqlen: int) -> torch.Tensor:
+ seq = torch.arange(seqlen, device=self.inv_freq.device, dtype=self.inv_freq.dtype)
+ freqs = torch.outer(seq, self.inv_freq)
+ return freqs
+
+
+class PaddleOCRRotaryEmbedding(nn.Module):
+ inv_freq: torch.Tensor # fix linting for `register_buffer`
+
+ def __init__(self, config: PaddleOCRVLConfig, device=None):
+ super().__init__()
+ self.max_seq_len_cached = config.max_position_embeddings
+ self.original_max_seq_len = config.max_position_embeddings
+
+ self.config = config
+
+ self.rope_type = self.config.rope_parameters["rope_type"]
+ rope_init_fn: Callable = self.compute_default_rope_parameters
+ if self.rope_type != "default":
+ rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]
+ inv_freq, self.attention_scaling = rope_init_fn(self.config, device)
+
+ self.register_buffer("inv_freq", inv_freq, persistent=False)
+ self.original_inv_freq = inv_freq
+
+ @staticmethod
+ def compute_default_rope_parameters(
+ config: Optional[PaddleOCRVLConfig] = None,
+ device: Optional["torch.device"] = None,
+ seq_len: Optional[int] = None,
+ ) -> tuple["torch.Tensor", float]:
+ """
+ Computes the inverse frequencies according to the original RoPE implementation
+ Args:
+ config ([`~transformers.PreTrainedConfig`]):
+ The model configuration.
+ device (`torch.device`):
+ The device to use for initialization of the inverse frequencies.
+ seq_len (`int`, *optional*):
+ The current sequence length. Unused for this type of RoPE.
+ Returns:
+ Tuple of (`torch.Tensor`, `float`), containing the inverse frequencies for the RoPE embeddings and the
+ post-processing scaling factor applied to the computed cos/sin (unused in this type of RoPE).
+ """
+ base = config.rope_parameters["rope_theta"]
+ dim = getattr(config, "head_dim", None) or config.hidden_size // config.num_attention_heads
+
+ attention_factor = 1.0 # Unused in this type of RoPE
+
+ # Compute the inverse frequencies
+ inv_freq = 1.0 / (
+ base ** (torch.arange(0, dim, 2, dtype=torch.int64).to(device=device, dtype=torch.float) / dim)
+ )
+ return inv_freq, attention_factor
+
+ # Ignore copy
+ def forward(self, x, position_ids):
+ # In contrast to other models, PaddleOCR has different position ids for the grids
+ # So we expand the inv_freq to shape (3, ...)
+ inv_freq_expanded = self.inv_freq[None, None, :, None].float().expand(3, position_ids.shape[1], -1, 1)
+ position_ids_expanded = position_ids[:, :, None, :].float() # shape (3, bs, 1, positions)
+
+ device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu"
+ with maybe_autocast(device_type=device_type, enabled=False): # Force float32
+ freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(2, 3)
+ emb = torch.cat((freqs, freqs), dim=-1)
+ cos = emb.cos() * self.attention_scaling
+ sin = emb.sin() * self.attention_scaling
+
+ return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
+
+
+class PaddleOCRMLP(nn.Module):
+ def __init__(self, config: PaddleOCRTextConfig):
+ super().__init__()
+ self.config = config
+ self.hidden_size = config.hidden_size
+ self.intermediate_size = config.intermediate_size
+
+ self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=config.use_bias)
+ self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=config.use_bias)
+ self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=config.use_bias)
+ self.act_fn = ACT2FN[config.hidden_act]
+
+ def forward(self, x):
+ down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))
+ return down_proj
+
+
+def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
+ """
+ This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch,
+ num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim)
+ """
+ batch, num_key_value_heads, slen, head_dim = hidden_states.shape
+ if n_rep == 1:
+ return hidden_states
+ hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim)
+ return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
+
+
+def eager_attention_forward(
+ module: nn.Module,
+ query: torch.Tensor,
+ key: torch.Tensor,
+ value: torch.Tensor,
+ attention_mask: Optional[torch.Tensor],
+ scaling: float,
+ dropout: float = 0.0,
+ **kwargs,
+):
+ key_states = repeat_kv(key, module.num_key_value_groups)
+ value_states = repeat_kv(value, module.num_key_value_groups)
+
+ attn_weights = torch.matmul(query, key_states.transpose(2, 3)) * scaling
+ if attention_mask is not None:
+ causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]
+ attn_weights = attn_weights + causal_mask
+
+ attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype)
+ attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training)
+ attn_output = torch.matmul(attn_weights, value_states)
+ attn_output = attn_output.transpose(1, 2).contiguous()
+
+ return attn_output, attn_weights
+
+
+def rotate_half(x):
+ """Rotates half the hidden dims of the input."""
+ x1 = x[..., : x.shape[-1] // 2]
+ x2 = x[..., x.shape[-1] // 2 :]
+ return torch.cat((-x2, x1), dim=-1)
+
+
+def apply_multimodal_rotary_pos_emb(q, k, cos, sin, mrope_section, unsqueeze_dim=1):
+ """Applies Rotary Position Embedding with Multimodal Sections to the query and key tensors (https://qwenlm.github.io/blog/qwen2-vl/).
+
+ Explanation:
+ Multimodal 3D rotary position embedding is an extension to 1D rotary position embedding. The input embedding
+ sequence contains vision (images / videos) embedding and text embedding or just contains text embedding. For
+ vision embedding part, we apply rotary position embedding on temporal, height and width dimension separately.
+ Here we split the channel dimension to 3 chunks for the temporal, height and width rotary position embedding.
+ For text embedding part, we just apply 1D rotary position embedding. The three rotary position index (temporal,
+ height and width) of text embedding is always the same, so the text embedding rotary position embedding has no
+ difference with modern LLMs.
+
+ Args:
+ q (`torch.Tensor`): The query tensor.
+ k (`torch.Tensor`): The key tensor.
+ cos (`torch.Tensor`): The cosine part of the rotary embedding.
+ sin (`torch.Tensor`): The sine part of the rotary embedding.
+ position_ids (`torch.Tensor`):
+ The position indices of the tokens corresponding to the query and key tensors. For example, this can be
+ used to pass offsetted position ids when working with a KV-cache.
+ mrope_section(`List(int)`):
+ Multimodal rope section is for channel dimension of temporal, height and width in rope calculation.
+ unsqueeze_dim (`int`, *optional*, defaults to 1):
+ The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and
+ sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note
+ that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and
+ k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes
+ cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have
+ the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2.
+ Returns:
+ `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding.
+ """
+ mrope_section = mrope_section * 2
+ cos = torch.cat([m[i % 3] for i, m in enumerate(cos.split(mrope_section, dim=-1))], dim=-1).unsqueeze(
+ unsqueeze_dim
+ )
+ sin = torch.cat([m[i % 3] for i, m in enumerate(sin.split(mrope_section, dim=-1))], dim=-1).unsqueeze(
+ unsqueeze_dim
+ )
+
+ q_embed = (q * cos) + (rotate_half(q) * sin)
+ k_embed = (k * cos) + (rotate_half(k) * sin)
+ return q_embed, k_embed
+
+
+class PaddleOCRAttention(nn.Module):
+ """
+ Multi-headed attention from 'Attention Is All You Need' paper. Modified to use sliding window attention: Longformer
+ and "Generating Long Sequences with Sparse Transformers".
+ """
+
+ def __init__(self, config: PaddleOCRVLConfig, layer_idx: Optional[int] = None):
+ super().__init__()
+ self.config = config
+ self.layer_idx = layer_idx
+ if layer_idx is None:
+ logger.warning_once(
+ f"Instantiating {self.__class__.__name__} without passing `layer_idx` is not recommended and will "
+ "to errors during the forward call, if caching is used. Please make sure to provide a `layer_idx` "
+ "when creating this class."
+ )
+
+ self.hidden_size = config.hidden_size
+ self.num_heads = config.num_attention_heads
+ self.head_dim = getattr(config, "head_dim", self.hidden_size // self.num_heads)
+ self.num_key_value_heads = config.num_key_value_heads
+ self.num_key_value_groups = self.num_heads // self.num_key_value_heads
+ self.is_causal = True
+
+ self.attention_dropout = 0.0
+ self.rope_parameters = config.rope_parameters
+ self.scaling = self.head_dim**-0.5
+ self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=config.use_bias)
+ self.k_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=config.use_bias)
+ self.v_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=config.use_bias)
+ self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=config.use_bias)
+ self.layer_type = config.layer_types[layer_idx] if hasattr(config, "layer_types") else None
+ self.sliding_window = config.sliding_window if self.layer_type == "sliding_attention" else None
+
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ attention_mask: Optional[torch.Tensor] = None,
+ position_ids: Optional[torch.LongTensor] = None,
+ past_key_values: Optional[Cache] = None,
+ output_attentions: bool = False,
+ use_cache: bool = False,
+ cache_position: Optional[torch.LongTensor] = None,
+ position_embeddings: Optional[tuple[torch.Tensor, torch.Tensor]] = None,
+ **kwargs: Unpack[FlashAttentionKwargs],
+ ) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[tuple[torch.Tensor]]]:
+ bsz, q_len, _ = hidden_states.size()
+
+ query_states = self.q_proj(hidden_states)
+ key_states = self.k_proj(hidden_states)
+ value_states = self.v_proj(hidden_states)
+
+ query_states = query_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2)
+ key_states = key_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2)
+ value_states = value_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2)
+
+ cos, sin = position_embeddings
+ query_states, key_states = apply_multimodal_rotary_pos_emb(
+ query_states, key_states, cos, sin, self.config.rope_parameters["mrope_section"]
+ )
+
+ if past_key_values is not None:
+ cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} # Specific to RoPE models
+ key_states, value_states = past_key_values.update(key_states, value_states, self.layer_idx, cache_kwargs)
+
+ attention_interface: Callable = eager_attention_forward
+ if self.config._attn_implementation != "eager":
+ attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]
+
+ attn_output, attn_weights = attention_interface(
+ self,
+ query_states,
+ key_states,
+ value_states,
+ attention_mask,
+ dropout=0.0 if not self.training else self.attention_dropout,
+ scaling=self.scaling,
+ sliding_window=self.sliding_window,
+ position_ids=position_ids, # pass positions for FA2
+ **kwargs,
+ )
+
+ attn_output = attn_output.reshape(bsz, q_len, -1).contiguous()
+ attn_output = self.o_proj(attn_output)
+ return attn_output, attn_weights
+
+
+@use_kernel_forward_from_hub("RMSNorm")
+class PaddleOCRRMSNorm(nn.Module):
+ def __init__(self, hidden_size, eps=1e-6):
+ """
+ PaddleOCRRMSNorm is equivalent to T5LayerNorm
+ """
+ super().__init__()
+ self.weight = nn.Parameter(torch.ones(hidden_size))
+ self.variance_epsilon = eps
+
+ def forward(self, hidden_states):
+ input_dtype = hidden_states.dtype
+ hidden_states = hidden_states.to(torch.float32)
+ variance = hidden_states.pow(2).mean(-1, keepdim=True)
+ hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
+ return self.weight * hidden_states.to(input_dtype)
+
+ def extra_repr(self):
+ return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}"
+
+
+class PaddleOCRDecoderLayer(GradientCheckpointingLayer):
+ def __init__(self, config: PaddleOCRTextConfig, layer_idx: int):
+ super().__init__()
+ self.hidden_size = config.hidden_size
+
+ self.self_attn = PaddleOCRAttention(config=config, layer_idx=layer_idx)
+
+ self.mlp = PaddleOCRMLP(config)
+ self.input_layernorm = PaddleOCRRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+ self.post_attention_layernorm = PaddleOCRRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ attention_mask: Optional[torch.Tensor] = None,
+ position_ids: Optional[torch.LongTensor] = None,
+ past_key_values: Optional[Cache] = None,
+ use_cache: Optional[bool] = False,
+ cache_position: Optional[torch.LongTensor] = None,
+ position_embeddings: Optional[tuple[torch.Tensor, torch.Tensor]] = None,
+ **kwargs: Unpack[TransformersKwargs],
+ ) -> torch.Tensor:
+ residual = hidden_states
+ hidden_states = self.input_layernorm(hidden_states)
+ # Self Attention
+ hidden_states, _ = self.self_attn(
+ hidden_states=hidden_states,
+ attention_mask=attention_mask,
+ position_ids=position_ids,
+ past_key_values=past_key_values,
+ use_cache=use_cache,
+ cache_position=cache_position,
+ position_embeddings=position_embeddings,
+ **kwargs,
+ )
+ hidden_states = residual + hidden_states
+
+ # Fully Connected
+ residual = hidden_states
+ hidden_states = self.post_attention_layernorm(hidden_states)
+ hidden_states = self.mlp(hidden_states)
+ hidden_states = residual + hidden_states
+ return hidden_states
+
+
+@auto_docstring
+class PaddleOCRVLPreTrainedModel(PreTrainedModel):
+ config: PaddleOCRVLConfig
+ base_model_prefix = "model"
+ supports_gradient_checkpointing = True
+ _no_split_modules = ["PaddleOCRDecoderLayer"]
+ _skip_keys_device_placement = ["past_key_values"]
+ _supports_flash_attn = True
+ _supports_sdpa = True
+ _supports_flex_attn = True
+
+ _can_compile_fullgraph = True
+ _supports_attention_backend = True
+
+ _can_record_outputs = {
+ "hidden_states": PaddleOCRDecoderLayer,
+ "attentions": PaddleOCRAttention,
+ }
+
+
+@auto_docstring
+class PaddleOCRTextModel(PaddleOCRVLPreTrainedModel):
+ def __init__(self, config: PaddleOCRTextConfig):
+ super().__init__(config)
+ self.padding_idx = config.pad_token_id
+ self.vocab_size = config.vocab_size
+
+ self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
+ self.layers = nn.ModuleList(
+ [PaddleOCRDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
+ )
+ self.norm = PaddleOCRRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+ self.rotary_emb = PaddleOCRRotaryEmbedding(config=config)
+ self.gradient_checkpointing = False
+
+ # Initialize weights and apply final processing
+ self.post_init()
+
+ @check_model_inputs
+ @auto_docstring
+ def forward(
+ self,
+ input_ids: Optional[torch.LongTensor] = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ position_ids: Optional[torch.LongTensor] = None,
+ past_key_values: Optional[Cache] = None,
+ inputs_embeds: Optional[torch.FloatTensor] = None,
+ cache_position: Optional[torch.LongTensor] = None,
+ use_cache: Optional[bool] = None,
+ **kwargs: Unpack[TransformersKwargs],
+ ) -> BaseModelOutputWithPast:
+ if (input_ids is None) ^ (inputs_embeds is not None):
+ raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
+
+ if inputs_embeds is None:
+ inputs_embeds: torch.Tensor = self.embed_tokens(input_ids)
+
+ if use_cache and past_key_values is None:
+ past_key_values = DynamicCache(config=self.config)
+
+ if cache_position is None:
+ past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
+ cache_position: torch.Tensor = (
+ torch.arange(inputs_embeds.shape[1], device=inputs_embeds.device) + past_seen_tokens
+ )
+
+ if position_ids is None:
+ position_ids = cache_position.view(1, 1, -1).expand(3, inputs_embeds.shape[0], -1)
+ elif position_ids.ndim == 2:
+ position_ids = position_ids[None, ...].expand(3, position_ids.shape[0], -1)
+
+ if position_ids.ndim == 3 and position_ids.shape[0] == 4:
+ text_position_ids = position_ids[0]
+ position_ids = position_ids[1:]
+ else:
+ text_position_ids = None
+
+ causal_mask = create_causal_mask(
+ config=self.config,
+ input_embeds=inputs_embeds,
+ attention_mask=attention_mask,
+ cache_position=cache_position,
+ past_key_values=past_key_values,
+ position_ids=text_position_ids,
+ )
+
+ hidden_states = inputs_embeds
+ position_embeddings = self.rotary_emb(hidden_states, position_ids=position_ids)
+
+ for decoder_layer in self.layers[: self.config.num_hidden_layers]:
+ hidden_states = decoder_layer(
+ hidden_states,
+ attention_mask=causal_mask,
+ position_embeddings=position_embeddings,
+ position_ids=text_position_ids,
+ past_key_values=past_key_values,
+ use_cache=use_cache,
+ cache_position=cache_position,
+ **kwargs,
+ )
+
+ hidden_states = self.norm(hidden_states)
+ return BaseModelOutputWithPast(
+ last_hidden_state=hidden_states,
+ past_key_values=past_key_values,
+ )
+
+
+class PaddleOCRVisionModel(PaddleOCRVLPreTrainedModel):
+ config: PaddleOCRVisionConfig
+ main_input_name = "pixel_values"
+ input_modalities = "image"
+
+ def __init__(self, config: PaddleOCRVisionConfig):
+ super().__init__(config)
+
+ self.vision_model = PaddleOCRVisionTransformer(config)
+
+ # Initialize weights and apply final processing
+ self.post_init()
+
+ def forward(
+ self,
+ pixel_values: torch.FloatTensor,
+ cu_seqlens: torch.Tensor,
+ image_grid_thw: Optional[list[Union[tuple[int, int, int], list[tuple[int, int, int]]]]] = None,
+ **kwargs,
+ ) -> BaseModelOutputWithPooling:
+ """
+ Args:
+ pixel_values (`torch.FloatTensor` of shape `(batch_size, sequence_length, image_channels, patch_size, patch_size)`):
+ The tensors corresponding to the input images.
+ cu_seqlens (`torch.Tensor` of shape `(num_images + 1,)`):
+ The cumulative sequence lengths of each image or video feature.
+ image_grid_thw (`torch.LongTensor` of shape `(num_images, 3)`, *optional*):
+ The temporal, height and width of feature shape of each image in LLM.
+ """
+ return self.vision_model(
+ pixel_values=pixel_values,
+ cu_seqlens=cu_seqlens,
+ image_grid_thw=image_grid_thw,
+ )
+
+
+class PaddleOCRVisionEmbeddings(nn.Module):
+ def __init__(self, config: PaddleOCRVisionConfig):
+ super().__init__()
+ self.config = config
+ self.embed_dim = config.hidden_size
+ self.image_size = config.image_size
+ self.patch_size = config.patch_size
+
+ self.patch_embedding = nn.Conv2d(
+ in_channels=config.num_channels,
+ out_channels=self.embed_dim,
+ kernel_size=self.patch_size,
+ stride=self.patch_size,
+ padding="valid",
+ )
+
+ self.num_patches = (self.image_size // self.patch_size) ** 2
+ self.num_positions = self.num_patches
+ self.position_embedding = nn.Embedding(self.num_positions, self.embed_dim)
+ self.register_buffer("position_ids", torch.arange(self.num_positions).expand((1, -1)), persistent=False)
+
+ def interpolate_pos_encoding(self, embeddings: torch.Tensor, height: int, width: int) -> torch.Tensor:
+ """
+ This method allows to interpolate the pre-trained position encodings, to be able to use the model on higher resolution
+ images. This method is also adapted to support torch.jit tracing and no class embeddings.
+
+ Adapted from:
+ - https://github.com/facebookresearch/dino/blob/de9ee3df6cf39fac952ab558447af1fa1365362a/vision_transformer.py#L174-L194, and
+ - https://github.com/facebookresearch/dinov2/blob/e1277af2ba9496fbadf7aec6eba56e8d882d1e35/dinov2/models/vision_transformer.py#L179-L211
+ """
+ num_positions = self.position_embedding.weight.shape[0]
+
+ patch_pos_embed = self.position_embedding.weight.unsqueeze(0)
+
+ dim = embeddings.shape[-1]
+
+ sqrt_num_positions = torch_int(num_positions**0.5)
+ patch_pos_embed = patch_pos_embed.reshape(1, sqrt_num_positions, sqrt_num_positions, dim)
+ patch_pos_embed = patch_pos_embed.permute(0, 3, 1, 2)
+
+ patch_pos_embed = nn.functional.interpolate(
+ patch_pos_embed,
+ size=(height, width),
+ mode="bilinear",
+ align_corners=False,
+ )
+
+ patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim)
+ return patch_pos_embed
+
+ def forward(
+ self,
+ pixel_values: torch.FloatTensor,
+ image_grid_thw: Optional[list[Union[tuple[int, int, int], list[tuple[int, int, int]]]]] = None,
+ ) -> torch.Tensor:
+ """
+ Args:
+ pixel_values (`torch.FloatTensor` of shape `(batch_size, sequence_length, image_channels, patch_size, patch_size)`):
+ The tensors corresponding to the input images.
+ image_grid_thw (`torch.LongTensor` of shape `(num_images, 3)`, *optional*):
+ The temporal, height and width of feature shape of each image in LLM.
+ """
+ batch_size, squence_len, channel, height, width = pixel_values.shape
+ target_dtype = self.patch_embedding.weight.dtype
+ pixel_values = pixel_values.reshape(batch_size * squence_len, channel, height, width)
+ patch_embeds = self.patch_embedding(pixel_values.to(dtype=target_dtype)) # shape = [*, width, grid, grid]
+ embeddings = patch_embeds.flatten(-2).squeeze(-1)
+ embeddings = embeddings.reshape(batch_size, squence_len, -1)
+
+ start = 0
+ embeddings = embeddings.squeeze(0)
+ tmp_embeddings = []
+ for image_grid in image_grid_thw:
+ t, h, w = image_grid
+ end = start + t * h * w
+ image_embeddings = embeddings[start:end, :]
+ position_embedding = self.interpolate_pos_encoding(image_embeddings, h, w).squeeze(0).repeat(t, 1)
+ image_embeddings = image_embeddings + position_embedding
+ tmp_embeddings.append(image_embeddings)
+ start = end
+ embeddings = torch.concat(tmp_embeddings, dim=0)
+
+ return embeddings
+
+
+def apply_rotary_pos_emb_vision(
+ q: torch.Tensor, k: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor
+) -> tuple[torch.Tensor, torch.Tensor]:
+ orig_q_dtype = q.dtype
+ orig_k_dtype = k.dtype
+ q, k = q.float(), k.float()
+ cos, sin = cos.unsqueeze(-2).float(), sin.unsqueeze(-2).float()
+ q_embed = (q * cos) + (rotate_half(q) * sin)
+ k_embed = (k * cos) + (rotate_half(k) * sin)
+ q_embed = q_embed.to(orig_q_dtype)
+ k_embed = k_embed.to(orig_k_dtype)
+ return q_embed, k_embed
+
+
+class PaddleOCRVisionAttention(nn.Module):
+ """Multi-headed attention from 'Attention Is All You Need' paper"""
+
+ def __init__(self, config: PaddleOCRVisionConfig):
+ super().__init__()
+ self.config = config
+ self.embed_dim = config.hidden_size
+ self.num_heads = config.num_attention_heads
+ self.head_dim = self.embed_dim // self.num_heads
+ if self.head_dim * self.num_heads != self.embed_dim:
+ raise ValueError(
+ f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim} and `num_heads`:"
+ f" {self.num_heads})."
+ )
+ self.is_causal = False
+
+ self.k_proj = nn.Linear(self.embed_dim, self.embed_dim)
+ self.v_proj = nn.Linear(self.embed_dim, self.embed_dim)
+ self.q_proj = nn.Linear(self.embed_dim, self.embed_dim)
+ self.out_proj = nn.Linear(self.embed_dim, self.embed_dim)
+ self.num_key_value_groups = 1
+ self.scaling = self.head_dim**-0.5
+ self.attention_dropout = config.attention_dropout
+
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ cu_seqlens: torch.Tensor,
+ position_embeddings: tuple[torch.Tensor, torch.Tensor],
+ **kwargs: Unpack[TransformersKwargs],
+ ) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
+ """
+ Args:
+ hidden_states (`torch.Tensor`):
+ Input to the layer of shape `(seq_len, embed_dim)`.
+ cu_seqlens (`torch.Tensor` of shape `(num_images_or_videos + 1,)`):
+ The cumulative sequence lengths of each image or video feature.
+ position_embeddings (`tuple(torch.Tensor, torch.Tensor)` of shape `(num_patches, head_dim // 2)`):
+ The cosine and sine position embeddings for vision attention.
+ """
+ seq_length = hidden_states.shape[0]
+ query_states = self.q_proj(hidden_states).view(seq_length, self.num_heads, self.head_dim)
+ key_states = self.k_proj(hidden_states).view(seq_length, self.num_heads, self.head_dim)
+ value_states = self.v_proj(hidden_states).view(seq_length, self.num_heads, self.head_dim)
+
+ cos, sin = position_embeddings
+ query_states, key_states = apply_rotary_pos_emb_vision(query_states, key_states, cos, sin)
+
+ query_states = query_states.transpose(0, 1).unsqueeze(0)
+ key_states = key_states.transpose(0, 1).unsqueeze(0)
+ value_states = value_states.transpose(0, 1).unsqueeze(0)
+
+ attention_interface: Callable = eager_attention_forward
+ if self.config._attn_implementation != "eager":
+ attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]
+
+ if self.config._attn_implementation == "flash_attention_2":
+ # Flash Attention 2: Use cu_seqlens for variable length attention
+ max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max()
+ attn_output, attn_weights = attention_interface(
+ self,
+ query_states,
+ key_states,
+ value_states,
+ attention_mask=None,
+ scaling=self.scaling,
+ dropout=0.0 if not self.training else self.attention_dropout,
+ cu_seq_lens_q=cu_seqlens,
+ cu_seq_lens_k=cu_seqlens,
+ max_length_q=max_seqlen,
+ max_length_k=max_seqlen,
+ is_causal=False,
+ **kwargs,
+ )
+ else:
+ # Other implementations: Process each chunk separately
+ lengths = cu_seqlens[1:] - cu_seqlens[:-1]
+ splits = [
+ torch.split(tensor, lengths.tolist(), dim=2) for tensor in (query_states, key_states, value_states)
+ ]
+
+ attn_outputs, attn_weights = [], []
+ for q, k, v in zip(*splits):
+ attn_output, attn_weight = attention_interface(
+ self,
+ q,
+ k,
+ v,
+ attention_mask=None,
+ scaling=self.scaling,
+ dropout=0.0 if not self.training else self.attention_dropout,
+ is_causal=False,
+ **kwargs,
+ )
+ attn_outputs.append(attn_output)
+ attn_weights.append(attn_weight)
+
+ attn_output = torch.cat(attn_outputs, dim=1)
+
+ attn_output = attn_output.reshape(seq_length, -1).contiguous()
+ attn_output = self.out_proj(attn_output)
+
+ return attn_output, attn_weights
+
+
+class PaddleOCRVisionMLP(nn.Module):
+ def __init__(self, config: PaddleOCRVisionConfig):
+ super().__init__()
+ self.config = config
+ self.activation_fn = ACT2FN[config.hidden_act]
+ self.fc1 = nn.Linear(config.hidden_size, config.intermediate_size)
+ self.fc2 = nn.Linear(config.intermediate_size, config.hidden_size)
+
+ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
+ hidden_states = self.fc1(hidden_states)
+ hidden_states = self.activation_fn(hidden_states)
+ hidden_states = self.fc2(hidden_states)
+ return hidden_states
+
+
+class PaddleOCRVisionEncoderLayer(GradientCheckpointingLayer):
+ def __init__(self, config: PaddleOCRVisionConfig):
+ super().__init__()
+ self.embed_dim = config.hidden_size
+ self.layer_norm1 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps)
+ self.self_attn = PaddleOCRVisionAttention(config=config)
+ self.layer_norm2 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps)
+ self.mlp = PaddleOCRVisionMLP(config=config)
+
+ @auto_docstring
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ cu_seqlens: torch.Tensor,
+ position_embeddings: tuple[torch.Tensor, torch.Tensor],
+ **kwargs: Unpack[TransformersKwargs],
+ ) -> torch.Tensor:
+ r"""
+ cu_seqlens (`torch.Tensor` of shape `(num_images_or_videos + 1,)`):
+ The cumulative sequence lengths of each image or video feature.
+ position_embeddings (`tuple(torch.Tensor, torch.Tensor)` of shape `(num_patches, head_dim // 2)`):
+ The cosine and sine position embeddings for vision attention.
+ """
+ residual = hidden_states
+
+ hidden_states = self.layer_norm1(hidden_states)
+ hidden_states, _ = self.self_attn(
+ hidden_states,
+ cu_seqlens=cu_seqlens,
+ position_embeddings=position_embeddings,
+ **kwargs,
+ )
+ hidden_states = residual + hidden_states
+
+ residual = hidden_states
+ hidden_states = self.layer_norm2(hidden_states)
+ hidden_states = self.mlp(hidden_states)
+ hidden_states = residual + hidden_states
+
+ return hidden_states
+
+
+class PaddleOCRVisionEncoder(nn.Module):
+ """
+ Transformer encoder consisting of `config.num_hidden_layers` self attention layers. Each layer is a
+ [`PaddleOCRVisionEncoderLayer`].
+
+ Args:
+ config: PaddleOCRVisionConfig
+ """
+
+ def __init__(self, config: PaddleOCRVisionConfig):
+ super().__init__()
+ self.config = config
+ self.layers = nn.ModuleList([PaddleOCRVisionEncoderLayer(config) for _ in range(config.num_hidden_layers)])
+ self.gradient_checkpointing = False
+ embed_dim = config.hidden_size
+ num_heads = config.num_attention_heads
+ head_dim = embed_dim // num_heads
+ self.rotary_pos_emb = PaddleOCRVisionRotaryEmbedding(head_dim // 2)
+
+ # Ignore copy
+ @can_return_tuple
+ @auto_docstring
+ def forward(
+ self,
+ inputs_embeds: torch.FloatTensor,
+ cu_seqlens: torch.Tensor,
+ attention_mask: Optional[torch.Tensor] = None,
+ image_grid_thw: Optional[list[Union[tuple[int, int, int], list[tuple[int, int, int]]]]] = None,
+ ) -> BaseModelOutput:
+ """
+ Args:
+ inputs_embeds (`torch.FloatTensor` of shape `(sequence_length, hidden_size)`, *optional*):
+ Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation.
+ This is useful if you want more control over how to convert `input_ids` indices into associated vectors
+ than the model's internal embedding lookup matrix.
+ cu_seqlens (`torch.Tensor` of shape `(num_images + 1,)`):
+ The cumulative sequence lengths of each image or video feature.
+ attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
+ The attention_mask used in forward function shape [batch_size X sequence_length] if not None.
+ image_grid_thw (`torch.LongTensor` of shape `(num_images, 3)`, *optional*):
+ The temporal, height and width of feature shape of each image in LLM.
+ """
+ device = inputs_embeds.device
+ hidden_states = inputs_embeds
+ attention_mask = create_bidirectional_mask(
+ config=self.config,
+ input_embeds=inputs_embeds,
+ attention_mask=attention_mask,
+ )
+ split_hids = []
+ split_wids = []
+ for t, h, w in image_grid_thw:
+ image_pids = torch.arange(t * h * w, device=device) % (h * w)
+ sample_hids = image_pids // w
+ sample_wids = image_pids % w
+ split_hids.append(sample_hids)
+ split_wids.append(sample_wids)
+ width_position_ids = torch.concat(split_wids, dim=0)
+ height_position_ids = torch.concat(split_hids, dim=0)
+
+ pids = torch.stack([height_position_ids, width_position_ids], dim=-1)
+ max_grid_size = pids.max() + 1
+ rotary_embeddings_max_grid = self.rotary_pos_emb(max_grid_size)
+ rotary_embeddings = rotary_embeddings_max_grid[pids].flatten(1)
+ rotary_embeddings = rotary_embeddings.repeat(1, 2)
+ position_embeddings = (rotary_embeddings.cos(), rotary_embeddings.sin())
+
+ for encoder_layer in self.layers:
+ hidden_states = encoder_layer(
+ hidden_states,
+ cu_seqlens=cu_seqlens,
+ position_embeddings=position_embeddings,
+ )
+
+ return BaseModelOutput(
+ last_hidden_state=hidden_states,
+ )
+
+
+class PaddleOCRVisionTransformer(PaddleOCRVLPreTrainedModel):
+ def __init__(self, config: PaddleOCRVisionConfig):
+ super().__init__(config)
+ self.config = config
+ embed_dim = config.hidden_size
+
+ self.embeddings = PaddleOCRVisionEmbeddings(config)
+ self.encoder = PaddleOCRVisionEncoder(config)
+ self.post_layernorm = nn.LayerNorm(embed_dim, eps=config.layer_norm_eps)
+
+ def forward(
+ self,
+ pixel_values: torch.FloatTensor,
+ cu_seqlens: torch.Tensor,
+ attention_mask: Optional[torch.Tensor] = None,
+ image_grid_thw: Optional[list[Union[tuple[int, int, int], list[tuple[int, int, int]]]]] = None,
+ **kwargs,
+ ) -> BaseModelOutputWithPooling:
+ """
+ Args:
+ pixel_values (`torch.FloatTensor` of shape `(batch_size, sequence_length, patch_size * patch_size * image_channels)`):
+ The tensors corresponding to the input images.
+ cu_seqlens (`torch.Tensor` of shape `(num_images + 1,)`):
+ The cumulative sequence lengths of each image or video feature.
+ attention_mask (`torch.Tensor`, *optional*):
+ The attention_mask used in forward function shape [batch_size X sequence_length] if not None.
+ image_grid_thw (`torch.LongTensor` of shape `(num_images, 3)`, *optional*):
+ The temporal, height and width of feature shape of each image in LLM.
+ """
+ hidden_states = self.embeddings(pixel_values, image_grid_thw=image_grid_thw)
+
+ encoder_outputs: BaseModelOutput = self.encoder(
+ inputs_embeds=hidden_states,
+ cu_seqlens=cu_seqlens,
+ attention_mask=attention_mask,
+ image_grid_thw=image_grid_thw,
+ )
+
+ last_hidden_state = encoder_outputs.last_hidden_state
+ last_hidden_state = self.post_layernorm(last_hidden_state)
+
+ return BaseModelOutputWithPooling(
+ last_hidden_state=last_hidden_state,
+ pooler_output=None,
+ hidden_states=encoder_outputs.hidden_states,
+ attentions=encoder_outputs.attentions,
+ )
+
+
+@dataclass
+@auto_docstring(
+ custom_intro="""
+ Base class for Llava outputs, with hidden states and attentions.
+ """
+)
+class PaddleOCRVLModelOutputWithPast(ModelOutput):
+ r"""
+ past_key_values (`Cache`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
+ It is a [`~cache_utils.Cache`] instance. For more details, see our [kv cache guide](https://huggingface.co/docs/transformers/en/kv_cache).
+
+ Contains pre-computed hidden-states (key and values in the self-attention blocks) that can be used (see
+ `past_key_values` input) to speed up sequential decoding.
+ rope_deltas (`torch.LongTensor` of shape `(batch_size, )`, *optional*):
+ The rope index difference between sequence length and multimodal rope.
+ """
+
+ last_hidden_state: Optional[torch.FloatTensor] = None
+ past_key_values: Optional[Cache] = None
+ hidden_states: Optional[tuple[torch.FloatTensor]] = None
+ attentions: Optional[tuple[torch.FloatTensor]] = None
+ rope_deltas: Optional[torch.LongTensor] = None
+
+
+@dataclass
+@auto_docstring(
+ custom_intro="""
+ Base class for PaddleOCRVL causal language model (or autoregressive) outputs.
+ """
+)
+class PaddleOCRVLCausalLMOutputWithPast(ModelOutput):
+ r"""
+ loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided):
+ Language modeling loss (for next-token prediction).
+ logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.vocab_size)`):
+ Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
+ past_key_values (`Cache`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
+ It is a [`~cache_utils.Cache`] instance. For more details, see our [kv cache guide](https://huggingface.co/docs/transformers/en/kv_cache).
+
+ Contains pre-computed hidden-states (key and values in the self-attention blocks) that can be used (see
+ `past_key_values` input) to speed up sequential decoding.
+ rope_deltas (`torch.LongTensor` of shape `(batch_size, )`, *optional*):
+ The rope index difference between sequence length and multimodal rope.
+ """
+
+ loss: Optional[torch.FloatTensor] = None
+ logits: Optional[torch.FloatTensor] = None
+ past_key_values: Optional[Cache] = None
+ hidden_states: Optional[tuple[torch.FloatTensor]] = None
+ attentions: Optional[tuple[torch.FloatTensor]] = None
+ rope_deltas: Optional[torch.LongTensor] = None
+
+
+@auto_docstring
+class PaddleOCRVLModel(PaddleOCRVLPreTrainedModel):
+ base_model_prefix = "model"
+ _checkpoint_conversion_mapping = {"^model": "language_model"}
+ # Reference: fix gemma3 grad acc #37208
+ accepts_loss_kwargs = False
+ _keys_to_ignore_on_load_unexpected = ["packing_position_embedding", "vision_model.head"]
+
+ def __init__(self, config: PaddleOCRVLConfig):
+ super().__init__(config)
+ self.visual = PaddleOCRVisionModel._from_config(config.vision_config)
+ self.language_model = PaddleOCRTextModel._from_config(config.text_config)
+ self.rope_deltas = None
+ self.projector = PaddleOCRProjector(config)
+
+ # Initialize weights and apply final processing
+ self.post_init()
+
+ def get_input_embeddings(self):
+ return self.language_model.embed_tokens
+
+ def set_input_embeddings(self, value):
+ self.language_model.embed_tokens = value
+
+ def get_rope_index(
+ self,
+ input_ids: Optional[torch.LongTensor] = None,
+ image_grid_thw: Optional[torch.LongTensor] = None,
+ video_grid_thw: Optional[torch.LongTensor] = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ ) -> tuple[torch.Tensor, torch.Tensor]:
+ """
+ Calculate the 3D rope index based on image and video's temporal, height and width in LLM.
+
+ Explanation:
+ Each embedding sequence contains vision embedding and text embedding or just contains text embedding.
+
+ For pure text embedding sequence, the rotary position embedding has no difference with modern LLMs.
+ Examples:
+ input_ids: [T T T T T], here T is for text.
+ temporal position_ids: [0, 1, 2, 3, 4]
+ height position_ids: [0, 1, 2, 3, 4]
+ width position_ids: [0, 1, 2, 3, 4]
+
+ For vision and text embedding sequence, we calculate 3D rotary position embedding for vision part
+ and 1D rotary position embedding for text part.
+ Examples:
+ Assume we have a video input with 3 temporal patches, 2 height patches and 2 width patches.
+ input_ids: [V V V V V V V V V V V V T T T T T], here V is for vision.
+ vision temporal position_ids: [0, 0, 0, 0, 1, 1, 1, 1, 2, 2, 2, 2]
+ vision height position_ids: [0, 0, 1, 1, 0, 0, 1, 1, 0, 0, 1, 1]
+ vision width position_ids: [0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1]
+ text temporal position_ids: [3, 4, 5, 6, 7]
+ text height position_ids: [3, 4, 5, 6, 7]
+ text width position_ids: [3, 4, 5, 6, 7]
+ Here we calculate the text start position_ids as the max vision position_ids plus 1.
+
+ Args:
+ input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
+ Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide
+ it.
+ image_grid_thw (`torch.LongTensor` of shape `(num_images, 3)`, *optional*):
+ The temporal, height and width of feature shape of each image in LLM.
+ video_grid_thw (`torch.LongTensor` of shape `(num_videos, 3)`, *optional*):
+ The temporal, height and width of feature shape of each video in LLM.
+ attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
+ Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
+
+ - 1 for tokens that are **not masked**,
+ - 0 for tokens that are **masked**.
+
+ Returns:
+ position_ids (`torch.LongTensor` of shape `(3, batch_size, sequence_length)`)
+ mrope_position_deltas (`torch.Tensor` of shape `(batch_size)`)
+ """
+ spatial_merge_size = self.config.vision_config.spatial_merge_size
+ image_token_id = self.config.image_token_id
+ video_token_id = self.config.video_token_id
+ vision_start_token_id = self.config.vision_start_token_id
+ mrope_position_deltas = []
+ if input_ids is not None and (image_grid_thw is not None or video_grid_thw is not None):
+ total_input_ids = input_ids
+ if attention_mask is None:
+ attention_mask = torch.ones_like(total_input_ids)
+ position_ids = torch.ones(
+ 3, input_ids.shape[0], input_ids.shape[1], dtype=input_ids.dtype, device=input_ids.device
+ )
+ image_index, video_index = 0, 0
+ for i, input_ids in enumerate(total_input_ids):
+ input_ids = input_ids[attention_mask[i].to(input_ids.device) == 1]
+ image_nums, video_nums = 0, 0
+ vision_start_indices = torch.argwhere(input_ids == vision_start_token_id).squeeze(1)
+ vision_tokens = input_ids[vision_start_indices + 1]
+ image_nums = (vision_tokens == image_token_id).sum()
+ video_nums = (vision_tokens == video_token_id).sum()
+ input_tokens = input_ids.tolist()
+ llm_pos_ids_list: list = []
+ st = 0
+ remain_images, remain_videos = image_nums, video_nums
+ for _ in range(image_nums + video_nums):
+ if image_token_id in input_tokens and remain_images > 0:
+ ed_image = input_tokens.index(image_token_id, st)
+ else:
+ ed_image = len(input_tokens) + 1
+ if video_token_id in input_tokens and remain_videos > 0:
+ ed_video = input_tokens.index(video_token_id, st)
+ else:
+ ed_video = len(input_tokens) + 1
+ if ed_image < ed_video:
+ t, h, w = (
+ image_grid_thw[image_index][0],
+ image_grid_thw[image_index][1],
+ image_grid_thw[image_index][2],
+ )
+ image_index += 1
+ remain_images -= 1
+ ed = ed_image
+ else:
+ t, h, w = (
+ video_grid_thw[video_index][0],
+ video_grid_thw[video_index][1],
+ video_grid_thw[video_index][2],
+ )
+ video_index += 1
+ remain_videos -= 1
+ ed = ed_video
+ llm_grid_t, llm_grid_h, llm_grid_w = (
+ t.item(),
+ h.item() // spatial_merge_size,
+ w.item() // spatial_merge_size,
+ )
+ text_len = ed - st
+
+ st_idx = llm_pos_ids_list[-1].max() + 1 if len(llm_pos_ids_list) > 0 else 0
+ llm_pos_ids_list.append(torch.arange(text_len).view(1, -1).expand(3, -1) + st_idx)
+
+ t_index = torch.arange(llm_grid_t).view(-1, 1).expand(-1, llm_grid_h * llm_grid_w).flatten()
+ h_index = torch.arange(llm_grid_h).view(1, -1, 1).expand(llm_grid_t, -1, llm_grid_w).flatten()
+ w_index = torch.arange(llm_grid_w).view(1, 1, -1).expand(llm_grid_t, llm_grid_h, -1).flatten()
+ llm_pos_ids_list.append(torch.stack([t_index, h_index, w_index]) + text_len + st_idx)
+ st = ed + llm_grid_t * llm_grid_h * llm_grid_w
+
+ if st < len(input_tokens):
+ st_idx = llm_pos_ids_list[-1].max() + 1 if len(llm_pos_ids_list) > 0 else 0
+ text_len = len(input_tokens) - st
+ llm_pos_ids_list.append(torch.arange(text_len).view(1, -1).expand(3, -1) + st_idx)
+
+ llm_positions = torch.cat(llm_pos_ids_list, dim=1).reshape(3, -1)
+ position_ids[..., i, attention_mask[i] == 1] = llm_positions.to(position_ids.device)
+ mrope_position_deltas.append(llm_positions.max() + 1 - len(total_input_ids[i]))
+ mrope_position_deltas = torch.tensor(mrope_position_deltas, device=input_ids.device).unsqueeze(1)
+ return position_ids, mrope_position_deltas
+ else:
+ if attention_mask is not None:
+ position_ids = attention_mask.long().cumsum(-1) - 1
+ position_ids.masked_fill_(attention_mask == 0, 1)
+ position_ids = position_ids.unsqueeze(0).expand(3, -1, -1).to(attention_mask.device)
+ max_position_ids = position_ids.max(0, keepdim=False)[0].max(-1, keepdim=True)[0]
+ mrope_position_deltas = max_position_ids + 1 - attention_mask.shape[-1]
+ else:
+ position_ids = (
+ torch.arange(input_ids.shape[1], device=input_ids.device)
+ .view(1, 1, -1)
+ .expand(3, input_ids.shape[0], -1)
+ )
+ mrope_position_deltas = torch.zeros(
+ [input_ids.shape[0], 1],
+ device=input_ids.device,
+ dtype=input_ids.dtype,
+ )
+
+ return position_ids, mrope_position_deltas
+
+ def get_video_features(
+ self, pixel_values_videos: torch.FloatTensor, video_grid_thw: Optional[torch.LongTensor] = None
+ ):
+ """
+ Encodes videos into continuous embeddings that can be forwarded to the language model.
+
+ Args:
+ pixel_values_videos (`torch.FloatTensor` of shape `(batch_size, num_channels, image_size, image_size)`):
+ The tensors corresponding to the input videos.
+ video_grid_thw (`torch.LongTensor` of shape `(num_videos, 3)`, *optional*):
+ The temporal, height and width of feature shape of each video in LLM.
+ """
+ pixel_values_videos = pixel_values_videos.type(self.visual.dtype)
+ video_embeds = self.visual(pixel_values_videos, grid_thw=video_grid_thw)
+ split_sizes = (video_grid_thw.prod(-1) // self.visual.spatial_merge_size**2).tolist()
+ video_embeds = torch.split(video_embeds, split_sizes)
+ return video_embeds
+
+ def get_image_features(self, pixel_values: torch.FloatTensor, image_grid_thw: Optional[torch.LongTensor] = None):
+ """
+ Encodes images into continuous embeddings that can be forwarded to the language model.
+
+ Args:
+ pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, image_size, image_size)`):
+ The tensors corresponding to the input images.
+ image_grid_thw (`torch.LongTensor` of shape `(num_images, 3)`, *optional*):
+ The temporal, height and width of feature shape of each image in LLM.
+ """
+ pixel_values = pixel_values.type(self.visual.dtype).unsqueeze(0)
+ cu_seqlens = torch.repeat_interleave(image_grid_thw[:, 1] * image_grid_thw[:, 2], image_grid_thw[:, 0]).cumsum(
+ dim=0,
+ # Select dtype based on the following factors:
+ # - FA2 requires that cu_seqlens_q must have dtype int32
+ # - torch.onnx.export requires that cu_seqlens_q must have same dtype as grid_thw
+ # See https://github.com/huggingface/transformers/pull/34852 for more information
+ dtype=image_grid_thw.dtype if torch.jit.is_tracing() else torch.int32,
+ )
+ cu_seqlens = torch.nn.functional.pad(cu_seqlens, (1, 0), value=0)
+ vision_outputs = self.visual(
+ pixel_values=pixel_values,
+ image_grid_thw=image_grid_thw,
+ cu_seqlens=cu_seqlens,
+ )
+ image_embeds = vision_outputs.last_hidden_state
+ image_embeds = self.projector(image_embeds, image_grid_thw)
+ return image_embeds
+
+ def get_placeholder_mask(
+ self, input_ids: torch.LongTensor, inputs_embeds: torch.FloatTensor, image_features: torch.FloatTensor
+ ):
+ """
+ Obtains multimodal placeholder mask from `input_ids` or `inputs_embeds`, and checks that the placeholder token count is
+ equal to the length of multimodal features. If the lengths are different, an error is raised.
+ """
+ if input_ids is None:
+ special_image_mask = inputs_embeds == self.get_input_embeddings()(
+ torch.tensor(self.config.image_token_id, dtype=torch.long, device=inputs_embeds.device)
+ )
+ special_image_mask = special_image_mask.all(-1)
+ else:
+ special_image_mask = input_ids == self.config.image_token_id
+
+ n_image_tokens = special_image_mask.sum()
+ special_image_mask = special_image_mask.unsqueeze(-1).expand_as(inputs_embeds).to(inputs_embeds.device)
+ n_image_features = image_features.shape[0] * image_features.shape[1]
+ if inputs_embeds[special_image_mask].numel() != image_features.numel():
+ raise ValueError(
+ f"Image features and image tokens do not match: tokens: {n_image_tokens}, features {n_image_features}"
+ )
+ return special_image_mask
+
+ @can_return_tuple
+ def forward(
+ self,
+ input_ids: torch.LongTensor = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ position_ids: Optional[torch.LongTensor] = None,
+ past_key_values: Optional[list[torch.FloatTensor]] = None,
+ inputs_embeds: Optional[torch.FloatTensor] = None,
+ use_cache: Optional[bool] = None,
+ pixel_values: Optional[torch.Tensor] = None,
+ image_grid_thw: Optional[torch.LongTensor] = None,
+ rope_deltas: Optional[torch.LongTensor] = None,
+ cache_position: Optional[torch.LongTensor] = None,
+ **kwargs,
+ ) -> Union[tuple, PaddleOCRVLModelOutputWithPast]:
+ r"""
+ image_grid_thw (`torch.LongTensor` of shape `(num_images, 3)`, *optional*):
+ The temporal, height and width of feature shape of each image in LLM.
+ rope_deltas (`torch.LongTensor` of shape `(batch_size, )`, *optional*):
+ The rope index difference between sequence length and multimodal rope.
+ """
+ if inputs_embeds is None:
+ inputs_embeds = self.language_model.embed_tokens(input_ids)
+
+ if pixel_values is not None:
+ image_embeds = self.get_image_features(pixel_values, image_grid_thw).to(
+ inputs_embeds.device, inputs_embeds.dtype
+ )
+ image_mask = self.get_placeholder_mask(input_ids, inputs_embeds=inputs_embeds, image_features=image_embeds)
+ inputs_embeds = inputs_embeds.masked_scatter(image_mask, image_embeds)
+
+ if position_ids is None:
+ past_key_values_length = 0 if past_key_values is None else past_key_values.get_seq_length()
+ if self.rope_deltas is None or past_key_values_length == 0:
+ position_ids, rope_deltas = self.get_rope_index(
+ input_ids=input_ids,
+ image_grid_thw=image_grid_thw,
+ attention_mask=attention_mask,
+ )
+ self.rope_deltas = rope_deltas
+ # then use the prev pre-calculated rope-deltas to get the correct position ids
+ else:
+ batch_size, seq_length, _ = inputs_embeds.shape
+ position_ids = torch.arange(seq_length, device=inputs_embeds.device)
+ position_ids = position_ids.view(1, 1, -1).expand(3, batch_size, -1)
+ delta = (past_key_values_length + self.rope_deltas).to(inputs_embeds.device)
+ delta = delta.repeat_interleave(batch_size // delta.shape[0], dim=0)
+ position_ids = position_ids + delta.to(position_ids.device)
+
+ outputs = self.language_model(
+ input_ids=None,
+ position_ids=position_ids,
+ attention_mask=attention_mask,
+ past_key_values=past_key_values,
+ inputs_embeds=inputs_embeds,
+ use_cache=use_cache,
+ cache_position=cache_position,
+ **kwargs,
+ )
+
+ output = PaddleOCRVLModelOutputWithPast(
+ last_hidden_state=outputs.last_hidden_state,
+ past_key_values=outputs.past_key_values,
+ hidden_states=outputs.hidden_states,
+ attentions=outputs.attentions,
+ rope_deltas=self.rope_deltas,
+ )
+
+ return output
+
+
+class PaddleOCRVLForConditionalGeneration(PaddleOCRVLPreTrainedModel, GenerationMixin):
+ _checkpoint_conversion_mapping = {
+ "^visual": "model.visual",
+ "^mlp_AR": "model.projector",
+ r"^model(?!(\.visual|\.projector|\.language_model))": "model.language_model",
+ }
+ _tied_weights_keys = {"lm_head.weight": "model.language_model.embed_tokens.weight"}
+ _keys_to_ignore_on_load_unexpected = ["packing_position_embedding", "vision_model.head"]
+
+ def __init__(self, config):
+ super().__init__(config)
+ self.model = PaddleOCRVLModel(config)
+ self.lm_head = nn.Linear(config.text_config.hidden_size, config.text_config.vocab_size, bias=False)
+
+ self.post_init()
+
+ def get_input_embeddings(self):
+ return self.model.get_input_embeddings()
+
+ def set_input_embeddings(self, value):
+ self.model.set_input_embeddings(value)
+
+ def get_video_features(
+ self, pixel_values_videos: torch.FloatTensor, video_grid_thw: Optional[torch.LongTensor] = None
+ ):
+ return self.model.get_video_features(pixel_values_videos, video_grid_thw)
+
+ def get_image_features(self, pixel_values: torch.FloatTensor, image_grid_thw: Optional[torch.LongTensor] = None):
+ return self.model.get_image_features(pixel_values, image_grid_thw)
+
+ @can_return_tuple
+ @auto_docstring
+ def forward(
+ self,
+ input_ids: Optional[torch.LongTensor] = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ position_ids: Optional[torch.LongTensor] = None,
+ past_key_values: Optional[Cache] = None,
+ inputs_embeds: Optional[torch.FloatTensor] = None,
+ labels: Optional[torch.LongTensor] = None,
+ use_cache: Optional[bool] = None,
+ pixel_values: Optional[torch.Tensor] = None,
+ image_grid_thw: Optional[torch.LongTensor] = None,
+ rope_deltas: Optional[torch.LongTensor] = None,
+ cache_position: Optional[torch.LongTensor] = None,
+ logits_to_keep: Union[int, torch.Tensor] = 0,
+ **kwargs: Unpack[TransformersKwargs],
+ ) -> Union[tuple, PaddleOCRVLCausalLMOutputWithPast]:
+ r"""
+ labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
+ Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
+ config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
+ (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
+ image_grid_thw (`torch.LongTensor` of shape `(num_images, 3)`, *optional*):
+ The temporal, height and width of feature shape of each image in LLM.
+ rope_deltas (`torch.LongTensor` of shape `(batch_size, )`, *optional*):
+ The rope index difference between sequence length and multimodal rope.
+
+ Example:
+
+ ```python
+ >>> from transformers import AutoProcessor, PaddleOCRVLForConditionalGeneration
+
+ >>> model = PaddleOCRVLForConditionalGeneration.from_pretrained("PaddlePaddle/PaddleOCR-VL", dtype="bfloat16")
+ >>> processor = AutoProcessor.from_pretrained("PaddlePaddle/PaddleOCR-VL")
+
+ >>> messages = [
+ {
+ "role": "user",
+ "content": [
+ {
+ "type": "image",
+ "image": "https://paddle-model-ecology.bj.bcebos.com/paddlex/imgs/demo_image/ocr_demo.jpg",
+ },
+ {"type": "text", "text": "OCR:"},
+ ],
+ }
+ ]
+
+ >>> inputs = processor.apply_chat_template(
+ messages,
+ tokenize=True,
+ add_generation_prompt=True,
+ return_dict=True,
+ return_tensors="pt"
+ ).to(model.device)
+
+ >>> # Generate
+ >>> generated_ids = model.generate(**inputs, max_new_tokens=1024)
+ >>> generated_ids_trimmed = [out_ids[len(in_ids) :] for in_ids, out_ids in zip(inputs.input_ids, generated_ids)]
+ >>> output_text = processor.batch_decode(generated_ids_trimmed, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
+ >>> print(output_text)
+ ```
+ """
+ outputs: PaddleOCRVLModelOutputWithPast = self.model(
+ input_ids=input_ids,
+ attention_mask=attention_mask,
+ position_ids=position_ids,
+ image_grid_thw=image_grid_thw,
+ past_key_values=past_key_values,
+ inputs_embeds=inputs_embeds,
+ use_cache=use_cache,
+ pixel_values=pixel_values,
+ rope_deltas=rope_deltas,
+ cache_position=cache_position,
+ **kwargs,
+ )
+ hidden_states = outputs.last_hidden_state
+
+ slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
+ logits = self.lm_head(hidden_states[:, slice_indices, :])
+
+ loss = None
+ if labels is not None:
+ loss = self.loss_function(
+ logits=logits, labels=labels, vocab_size=self.config.text_config.vocab_size, **kwargs
+ )
+
+ return PaddleOCRVLCausalLMOutputWithPast(
+ loss=loss,
+ logits=logits,
+ past_key_values=outputs.past_key_values,
+ hidden_states=outputs.hidden_states,
+ attentions=outputs.attentions,
+ rope_deltas=outputs.rope_deltas,
+ )
+
+ def prepare_inputs_for_generation(
+ self,
+ input_ids,
+ past_key_values=None,
+ attention_mask=None,
+ inputs_embeds=None,
+ cache_position=None,
+ position_ids=None,
+ use_cache=True,
+ pixel_values=None,
+ pixel_values_videos=None,
+ image_grid_thw=None,
+ video_grid_thw=None,
+ **kwargs,
+ ):
+ # Overwritten -- in specific circumstances we don't want to forward image inputs to the model
+
+ model_inputs = super().prepare_inputs_for_generation(
+ input_ids,
+ past_key_values=past_key_values,
+ attention_mask=attention_mask,
+ inputs_embeds=inputs_embeds,
+ cache_position=cache_position,
+ position_ids=position_ids,
+ pixel_values=pixel_values,
+ pixel_values_videos=pixel_values_videos,
+ image_grid_thw=image_grid_thw,
+ video_grid_thw=video_grid_thw,
+ use_cache=use_cache,
+ **kwargs,
+ )
+
+ # Qwen2-VL position_ids are prepareed with rope_deltas in forward
+ if position_ids is None:
+ # Calculate RoPE index once per generation in the pre-fill stage only.
+ # When compiling, we can't check tensor values thus we check only input length
+ # It is safe to assume that `length!=1` means we're in pre-fill because compiled
+ # models currently cannot do asssisted decoding
+ if model_inputs["cache_position"][0] == 0 or self.model.rope_deltas is None:
+ vision_positions, rope_deltas = self.model.get_rope_index(
+ model_inputs.get("input_ids", None),
+ image_grid_thw=image_grid_thw,
+ video_grid_thw=video_grid_thw,
+ attention_mask=attention_mask,
+ )
+ self.model.rope_deltas = rope_deltas
+ # then use the prev pre-calculated rope-deltas to get the correct position ids
+ elif "position_ids" in model_inputs:
+ batch_size, seq_length = model_inputs["position_ids"].shape
+ device = model_inputs["position_ids"].device
+ position_ids = torch.arange(seq_length, device=device)
+ position_ids = position_ids.view(1, 1, -1).expand(3, batch_size, -1)
+ delta = cache_position[0] + self.model.rope_deltas
+ delta = delta.repeat_interleave(batch_size // delta.shape[0], dim=0)
+ vision_positions = position_ids + delta.expand_as(position_ids)
+
+ # Concatenate "text + vision" positions into [4, bs, seq-len]
+ text_positions = model_inputs["position_ids"][None, ...]
+ model_inputs["position_ids"] = torch.cat([text_positions, vision_positions], dim=0)
+
+ if model_inputs["cache_position"][0] != 0:
+ model_inputs["pixel_values"] = None
+ model_inputs["pixel_values_videos"] = None
+
+ return model_inputs
+
+ def _get_image_nums_and_video_nums(
+ self,
+ input_ids: Optional[torch.LongTensor],
+ inputs_embeds: Optional[torch.Tensor] = None,
+ ) -> tuple[torch.Tensor, torch.Tensor]:
+ """
+ Get the number of images and videos for each sample to calculate the separation length of the sample tensor.
+ These parameters are not passed through the processor to avoid unpredictable impacts from interface modifications.
+
+ Args:
+ input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
+ Indices of input sequence tokens in the vocabulary.
+
+ Returns:
+ image_nums (`torch.LongTensor` of shape `(batch_size, num_images_sample)`)
+ video_nums (`torch.LongTensor` of shape `(batch_size, num_videos_sample)`)
+ """
+ image_token_id = self.config.image_token_id
+ video_token_id = self.config.video_token_id
+ vision_start_token_id = self.config.vision_start_token_id
+
+ if inputs_embeds is not None:
+ vision_start_mask = (
+ inputs_embeds
+ == self.get_input_embeddings()(
+ torch.tensor(vision_start_token_id, dtype=torch.long, device=inputs_embeds.device)
+ )
+ )[..., 0]
+ image_mask = (
+ inputs_embeds
+ == self.get_input_embeddings()(
+ torch.tensor(image_token_id, dtype=torch.long, device=inputs_embeds.device)
+ )
+ )[..., 0]
+ video_mask = (
+ inputs_embeds
+ == self.get_input_embeddings()(
+ torch.tensor(video_token_id, dtype=torch.long, device=inputs_embeds.device)
+ )
+ )[..., 0]
+ else:
+ vision_start_mask = input_ids == vision_start_token_id
+ image_mask = input_ids == image_token_id
+ video_mask = input_ids == video_token_id
+
+ vision_first_mask = torch.roll(vision_start_mask, shifts=1, dims=1)
+ image_nums = torch.sum(vision_first_mask & image_mask, dim=1)
+ video_nums = torch.sum(vision_first_mask & video_mask, dim=1)
+
+ return image_nums, video_nums
+
+ def _expand_inputs_for_generation(
+ self,
+ expand_size: int = 1,
+ is_encoder_decoder: bool = False,
+ input_ids: Optional[torch.LongTensor] = None,
+ **model_kwargs,
+ ) -> tuple[torch.LongTensor, dict[str, Any]]:
+ # Overwritten -- Support for expanding tensors without a batch size dimension
+ # e.g., pixel_values, image_grid_thw, pixel_values_videos, video_grid_thw, second_per_grid_t
+ # pixel_values.shape[0] is sum(seqlen_images for samples)
+ # image_grid_thw.shape[0] is sum(num_images for samples)
+
+ if expand_size == 1:
+ return input_ids, model_kwargs
+
+ visual_keys = ["pixel_values", "image_grid_thw", "pixel_values_videos", "video_grid_thw", "second_per_grid_ts"]
+
+ def _expand_dict_for_generation_visual(dict_to_expand):
+ image_grid_thw = model_kwargs.get("image_grid_thw", None)
+ video_grid_thw = model_kwargs.get("video_grid_thw", None)
+ image_nums, video_nums = self._get_image_nums_and_video_nums(
+ input_ids, inputs_embeds=model_kwargs.get("inputs_embeds", None)
+ )
+
+ def _repeat_interleave_samples(x, lengths, repeat_times):
+ samples = torch.split(x, lengths)
+ repeat_args = [repeat_times] + [1] * (x.dim() - 1)
+ result = torch.cat([sample.repeat(*repeat_args) for sample in samples], dim=0)
+ return result
+
+ for key in dict_to_expand:
+ if key == "pixel_values":
+ # split images into samples
+ samples = torch.split(image_grid_thw, list(image_nums))
+ # compute the sequence length of images for each sample
+ lengths = [torch.prod(sample, dim=1).sum() for sample in samples]
+ dict_to_expand[key] = _repeat_interleave_samples(
+ dict_to_expand[key], lengths=lengths, repeat_times=expand_size
+ )
+ elif key == "image_grid_thw":
+ # get the num of images for each sample
+ lengths = list(image_nums)
+ dict_to_expand[key] = _repeat_interleave_samples(
+ dict_to_expand[key], lengths=lengths, repeat_times=expand_size
+ )
+ elif key == "pixel_values_videos":
+ samples = torch.split(video_grid_thw, list(video_nums))
+ lengths = [torch.prod(sample, dim=1).sum() for sample in samples]
+ dict_to_expand[key] = _repeat_interleave_samples(
+ dict_to_expand[key], lengths=lengths, repeat_times=expand_size
+ )
+ elif key == "video_grid_thw":
+ lengths = list(video_nums)
+ dict_to_expand[key] = _repeat_interleave_samples(
+ dict_to_expand[key], lengths=lengths, repeat_times=expand_size
+ )
+ elif key == "second_per_grid_ts":
+ dict_to_expand[key] = _repeat_interleave_samples(
+ dict_to_expand[key], lengths=list(video_nums), repeat_times=expand_size
+ )
+ return dict_to_expand
+
+ def _expand_dict_for_generation(dict_to_expand):
+ for key in dict_to_expand:
+ if (
+ key != "cache_position"
+ and dict_to_expand[key] is not None
+ and isinstance(dict_to_expand[key], torch.Tensor)
+ and key not in visual_keys
+ ):
+ dict_to_expand[key] = dict_to_expand[key].repeat_interleave(expand_size, dim=0)
+ return dict_to_expand
+
+ model_kwargs = _expand_dict_for_generation_visual(model_kwargs)
+
+ if input_ids is not None:
+ input_ids = input_ids.repeat_interleave(expand_size, dim=0)
+
+ model_kwargs = _expand_dict_for_generation(model_kwargs)
+
+ if is_encoder_decoder:
+ if model_kwargs.get("encoder_outputs") is None:
+ raise ValueError("If `is_encoder_decoder` is True, make sure that `encoder_outputs` is defined.")
+ model_kwargs["encoder_outputs"] = _expand_dict_for_generation(model_kwargs["encoder_outputs"])
+
+ return input_ids, model_kwargs
+
+
+__all__ = [
+ "PaddleOCRVLForConditionalGeneration",
+ "PaddleOCRVLModel",
+ "PaddleOCRVLPreTrainedModel",
+ "PaddleOCRVisionTransformer",
+ "PaddleOCRTextModel",
+ "PaddleOCRVisionModel",
+]
diff --git a/src/transformers/models/paddleocr_vl/modular_paddleocr_vl.py b/src/transformers/models/paddleocr_vl/modular_paddleocr_vl.py
new file mode 100644
index 000000000000..79e6fc45d8bf
--- /dev/null
+++ b/src/transformers/models/paddleocr_vl/modular_paddleocr_vl.py
@@ -0,0 +1,1349 @@
+# Copyright 2025 The PaddlePaddle Team and The HuggingFace Inc. team. All rights reserved.
+#
+# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
+# and OPT implementations in this library. It has been modified from its
+# original forms to accommodate minor architectural differences compared
+# to GPT-NeoX and OPT used by the Meta AI team that trained the model.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import math
+from typing import Optional, Union
+
+import numpy as np
+import torch
+import torch.nn.functional as F
+from torch import nn
+
+from ...activations import GELUActivation
+from ...cache_utils import Cache, DynamicCache
+from ...image_processing_utils import BatchFeature
+from ...image_processing_utils_fast import BaseImageProcessorFast, group_images_by_shape, reorder_images
+from ...image_transforms import convert_to_rgb, resize, to_channel_dimension_format
+from ...image_utils import (
+ OPENAI_CLIP_MEAN,
+ OPENAI_CLIP_STD,
+ ChannelDimension,
+ ImageInput,
+ PILImageResampling,
+ SizeDict,
+ get_image_size,
+ infer_channel_dimension_format,
+ is_scaled_image,
+ make_list_of_images,
+ to_numpy_array,
+)
+from ...masking_utils import create_bidirectional_mask, create_causal_mask
+from ...modeling_outputs import BaseModelOutput, BaseModelOutputWithPast, BaseModelOutputWithPooling
+from ...modeling_utils import PreTrainedModel
+from ...models.qwen2_vl.image_processing_qwen2_vl import Qwen2VLImageProcessor
+from ...processing_utils import (
+ ProcessingKwargs,
+ ProcessorMixin,
+ Unpack,
+)
+from ...tokenization_utils_base import PreTokenizedInput, TextInput
+from ...utils import TensorType, TransformersKwargs, auto_docstring, can_return_tuple, logging, torch_int
+from ...utils.generic import check_model_inputs
+from ..ernie4_5.configuration_ernie4_5 import Ernie4_5Config
+from ..ernie4_5.modeling_ernie4_5 import (
+ Ernie4_5DecoderLayer,
+ Ernie4_5MLP,
+ Ernie4_5Model,
+ Ernie4_5RMSNorm,
+)
+from ..qwen2_5_omni.modeling_qwen2_5_omni import (
+ Qwen2_5OmniAttention,
+)
+from ..qwen2_vl.configuration_qwen2_vl import Qwen2VLConfig
+from ..qwen2_vl.modeling_qwen2_vl import (
+ Qwen2VLCausalLMOutputWithPast,
+ Qwen2VLForConditionalGeneration,
+ Qwen2VLModel,
+ Qwen2VLModelOutputWithPast,
+ Qwen2VLRotaryEmbedding,
+ VisionRotaryEmbedding,
+)
+from ..siglip.configuration_siglip import SiglipVisionConfig
+from ..siglip.modeling_siglip import (
+ SiglipMLP,
+ SiglipVisionEmbeddings,
+)
+from ..video_llama_3.modeling_video_llama_3 import (
+ VideoLlama3VisionAttention,
+ VideoLlama3VisionEncoder,
+ VideoLlama3VisionEncoderLayer,
+)
+
+
+logger = logging.get_logger(__name__)
+
+
+def smart_resize(
+ height: int,
+ width: int,
+ factor: int = 28,
+ min_pixels: int = 384 * 384,
+ max_pixels: int = 1536 * 1536,
+):
+ if height < factor:
+ width = round((width * factor) / height)
+ height = factor
+
+ if width < factor:
+ height = round((height * factor) / width)
+ width = factor
+
+ if max(height, width) / min(height, width) > 200:
+ raise ValueError(
+ f"absolute aspect ratio must be smaller than 200, got {max(height, width) / min(height, width)}"
+ )
+ h_bar = round(height / factor) * factor
+ w_bar = round(width / factor) * factor
+ if h_bar * w_bar > max_pixels:
+ beta = math.sqrt((height * width) / max_pixels)
+ h_bar = math.floor(height / beta / factor) * factor
+ w_bar = math.floor(width / beta / factor) * factor
+ elif h_bar * w_bar < min_pixels:
+ beta = math.sqrt(min_pixels / (height * width))
+ h_bar = math.ceil(height * beta / factor) * factor
+ w_bar = math.ceil(width * beta / factor) * factor
+ return h_bar, w_bar
+
+
+class PaddleOCRVLImageProcessor(Qwen2VLImageProcessor):
+ r"""
+ Constructs a PaddleOCRVL image processor that dynamically resizes images based on the original images.
+
+ Args:
+ do_resize (`bool`, *optional*, defaults to `True`):
+ Whether to resize the image's (height, width) dimensions.
+ size (`dict[str, int]`, *optional*):
+ Size of the image after resizing. `shortest_edge` and `longest_edge` keys must be present.
+ resample (`PILImageResampling`, *optional*, defaults to `Resampling.BICUBIC`):
+ Resampling filter to use when resizing the image.
+ do_rescale (`bool`, *optional*, defaults to `True`):
+ Whether to rescale the image by the specified scale `rescale_factor`.
+ rescale_factor (`int` or `float`, *optional*, defaults to `1/255`):
+ Scale factor to use if rescaling the image.
+ do_normalize (`bool`, *optional*, defaults to `True`):
+ Whether to normalize the image.
+ image_mean (`float` or `list[float]`, *optional*):
+ Mean to use if normalizing the image. This is a float or list of floats for each channel in the image.
+ image_std (`float` or `list[float]`, *optional*):
+ Standard deviation to use if normalizing the image. This is a float or list of floats for each channel in the image.
+ do_convert_rgb (`bool`, *optional*, defaults to `True`):
+ Whether to convert the image to RGB.
+ min_pixels (`int`, *optional*, defaults to `384 * 384`):
+ The min pixels of the image to resize the image.
+ max_pixels (`int`, *optional*, defaults to `1536 * 1536`):
+ The max pixels of the image to resize the image.
+ patch_size (`int`, *optional*, defaults to 14):
+ The spatial patch size of the vision encoder.
+ temporal_patch_size (`int`, *optional*, defaults to 1):
+ The temporal patch size of the vision encoder.
+ merge_size (`int`, *optional*, defaults to 2):
+ The merge size of the vision encoder to llm encoder.
+ """
+
+ model_input_names = [
+ "pixel_values",
+ "image_grid_thw",
+ ]
+
+ def __init__(
+ self,
+ do_resize: bool = True,
+ size: Optional[dict[str, int]] = None,
+ resample: PILImageResampling = PILImageResampling.BICUBIC,
+ do_rescale: bool = True,
+ rescale_factor: Union[int, float] = 1 / 255,
+ do_normalize: bool = True,
+ image_mean: Optional[Union[float, list[float]]] = None,
+ image_std: Optional[Union[float, list[float]]] = None,
+ do_convert_rgb: bool = True,
+ min_pixels: int = 384 * 384,
+ max_pixels: int = 1536 * 1536,
+ patch_size: int = 14,
+ temporal_patch_size: int = 1,
+ merge_size: int = 2,
+ **kwargs,
+ ) -> None:
+ super().__init__()
+
+ def _preprocess(
+ self,
+ images: ImageInput,
+ do_resize: Optional[bool] = None,
+ size: Optional[dict[str, int]] = None,
+ resample: PILImageResampling = None,
+ do_rescale: Optional[bool] = None,
+ rescale_factor: Optional[float] = None,
+ do_normalize: Optional[bool] = None,
+ image_mean: Optional[Union[float, list[float]]] = None,
+ image_std: Optional[Union[float, list[float]]] = None,
+ patch_size: Optional[int] = None,
+ temporal_patch_size: Optional[int] = None,
+ merge_size: Optional[int] = None,
+ do_convert_rgb: Optional[bool] = None,
+ data_format: Optional[ChannelDimension] = ChannelDimension.FIRST,
+ input_data_format: Optional[Union[str, ChannelDimension]] = None,
+ ):
+ """
+ Preprocess an image or batch of images. Copy of the `preprocess` method from `CLIPImageProcessor`.
+ Args:
+ images (`ImageInput`):
+ Image or batch of images to preprocess. Expects pixel values ranging from 0 to 255. If pixel values range from 0 to 1, set `do_rescale=False`.
+ do_resize (`bool`, *optional*, defaults to `self.do_resize`):
+ Whether to resize the image.
+ size (`Dict[str, int]`, *optional*, defaults to `self.size`):
+ Size of the image after resizing. `shortest_edge` and `longest_edge` keys must be present.
+ resample (`PILImageResampling`, *optional*, defaults to `self.resample`):
+ Resampling filter to use if resizing the image. This can be one of the `PILImageResampling` enums.
+ do_rescale (`bool`, *optional*, defaults to `self.do_rescale`):
+ Whether to rescale the image.
+ rescale_factor (`float`, *optional*, defaults to `self.rescale_factor`):
+ Scale factor to use if rescaling the image.
+ do_normalize (`bool`, *optional*, defaults to `self.do_normalize`):
+ Whether to normalize the image.
+ image_mean (`float` or `List[float]`, *optional*, defaults to `self.image_mean`):
+ Mean to use if normalizing the image. Can be a float or a list of floats corresponding to the number of channels in the image.
+ image_std (`float` or `List[float]`, *optional*, defaults to `self.image_std`):
+ Standard deviation to use if normalizing the image. Can be a float or a list of floats corresponding to the number of channels in the image.
+ patch_size (`int`, *optional*, defaults to `self.patch_size`):
+ The spatial patch size of the vision encoder.
+ temporal_patch_size (`int`, *optional*, defaults to `self.temporal_patch_size`):
+ The temporal patch size of the vision encoder.
+ merge_size (`int`, *optional*, defaults to `self.merge_size`):
+ The merge size of the vision encoder to llm encoder.
+ do_convert_rgb (`bool`, *optional*, defaults to `self.do_convert_rgb`):
+ Whether to convert the image to RGB.
+ data_format (`ChannelDimension`, *optional*, defaults to `ChannelDimension.FIRST`):
+ The channel dimension format for the output image. Can be one of:
+ - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
+ - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
+ - Unset: Use the channel dimension format of the input image.
+ input_data_format (`ChannelDimension` or `str`, *optional*):
+ The channel dimension format for the input image. Can be one of:
+ - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
+ - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
+ - `"none"` or `ChannelDimension.NONE`: image in (height, width) format. - `"none"` or `ChannelDimension.NONE`: image in (height, width) format.
+ """
+ images = make_list_of_images(images)
+ images = self.fetch_images(images)
+
+ if do_convert_rgb:
+ images = [convert_to_rgb(image) for image in images]
+
+ # All transformations expect numpy arrays.
+ images = [to_numpy_array(image) for image in images]
+
+ if is_scaled_image(images[0]) and do_rescale:
+ logger.warning_once(
+ "It looks like you are trying to rescale already rescaled images. If the input"
+ " images have pixel values between 0 and 1, set `do_rescale=False` to avoid rescaling them again."
+ )
+ if input_data_format is None:
+ # We assume that all images have the same channel dimension format.
+ input_data_format = infer_channel_dimension_format(images[0])
+
+ height, width = get_image_size(images[0], channel_dim=input_data_format)
+ resized_height, resized_width = height, width
+ processed_images = []
+
+ for image in images:
+ if do_resize:
+ resized_height, resized_width = smart_resize(
+ height,
+ width,
+ factor=patch_size * merge_size,
+ min_pixels=size["shortest_edge"],
+ max_pixels=size["longest_edge"],
+ )
+ image = resize(
+ image,
+ size=(resized_height, resized_width),
+ resample=resample,
+ input_data_format=input_data_format,
+ )
+
+ if do_rescale:
+ image = self.rescale(image, scale=rescale_factor, input_data_format=input_data_format)
+
+ if do_normalize:
+ image = self.normalize(
+ image=image,
+ mean=image_mean,
+ std=image_std,
+ input_data_format=input_data_format,
+ )
+ image = to_channel_dimension_format(image, data_format, input_channel_dim=input_data_format)
+ processed_images.append(image)
+
+ patches = np.array(processed_images)
+ if data_format == ChannelDimension.LAST:
+ patches = patches.transpose(0, 3, 1, 2)
+ if patches.shape[0] == 1:
+ patches = np.tile(patches, (temporal_patch_size, 1, 1, 1))
+
+ channel = patches.shape[1]
+ grid_t = patches.shape[0] // temporal_patch_size
+ grid_h, grid_w = (
+ resized_height // patch_size,
+ resized_width // patch_size,
+ )
+ patches = patches.reshape(
+ grid_t,
+ temporal_patch_size,
+ channel,
+ grid_h,
+ patch_size,
+ grid_w,
+ patch_size,
+ )
+ patches = patches.transpose(0, 3, 5, 2, 1, 4, 6)
+ if temporal_patch_size != 1:
+ raise ValueError(f"temporal_patch_size must be 1!, but got {temporal_patch_size}!")
+ flatten_patches = patches.reshape(grid_t * grid_h * grid_w, channel, patch_size, patch_size)
+ return flatten_patches, (grid_t, grid_h, grid_w)
+
+
+class PaddleOCRVLImageProcessorFast(BaseImageProcessorFast):
+ def __init__(
+ self,
+ do_resize: bool = True,
+ size: Optional[dict[str, int]] = None,
+ resample: PILImageResampling = PILImageResampling.BICUBIC,
+ do_rescale: bool = True,
+ rescale_factor: Union[int, float] = 1 / 255,
+ do_normalize: bool = True,
+ image_mean: Optional[Union[float, list[float]]] = None,
+ image_std: Optional[Union[float, list[float]]] = None,
+ do_convert_rgb: bool = True,
+ min_pixels: int = 384 * 384,
+ max_pixels: int = 1536 * 1536,
+ patch_size: int = 14,
+ temporal_patch_size: int = 1,
+ merge_size: int = 2,
+ **kwargs,
+ ) -> None:
+ super().__init__(**kwargs)
+ if size is not None and ("shortest_edge" not in size or "longest_edge" not in size):
+ raise ValueError("size must contain 'shortest_edge' and 'longest_edge' keys.")
+ else:
+ size = {"shortest_edge": 384 * 384, "longest_edge": 1536 * 1536}
+ # backward compatibility: override size with min_pixels and max_pixels if they are provided
+ if min_pixels is not None:
+ size["shortest_edge"] = min_pixels
+ if max_pixels is not None:
+ size["longest_edge"] = max_pixels
+ self.min_pixels = size["shortest_edge"]
+ self.max_pixels = size["longest_edge"]
+ self.size = size
+
+ self.do_resize = do_resize
+ self.resample = resample
+ self.do_rescale = do_rescale
+ self.rescale_factor = rescale_factor
+ self.do_normalize = do_normalize
+ self.image_mean = image_mean if image_mean is not None else OPENAI_CLIP_MEAN
+ self.image_std = image_std if image_std is not None else OPENAI_CLIP_STD
+
+ self.patch_size = patch_size
+ self.temporal_patch_size = temporal_patch_size
+ self.merge_size = merge_size
+ self.do_convert_rgb = do_convert_rgb
+
+ def _preprocess(
+ self,
+ images: list["torch.Tensor"],
+ do_resize: bool,
+ size: SizeDict,
+ interpolation: Optional["F.InterpolationMode"],
+ do_rescale: bool,
+ rescale_factor: float,
+ do_normalize: bool,
+ image_mean: Optional[Union[float, list[float]]],
+ image_std: Optional[Union[float, list[float]]],
+ disable_grouping: Optional[bool],
+ return_tensors: Optional[Union[str, TensorType]],
+ patch_size: Optional[int] = None,
+ temporal_patch_size: Optional[int] = None,
+ merge_size: Optional[int] = None,
+ **kwargs,
+ ):
+ patch_size = patch_size if patch_size is not None else self.patch_size
+ temporal_patch_size = temporal_patch_size if temporal_patch_size is not None else self.temporal_patch_size
+ merge_size = merge_size if merge_size is not None else self.merge_size
+
+ grouped_images, grouped_images_index = group_images_by_shape(images, disable_grouping=disable_grouping)
+ resized_images_grouped = {}
+ for shape, stacked_images in grouped_images.items():
+ height, width = stacked_images.shape[-2:]
+ if do_resize:
+ resized_height, resized_width = smart_resize(
+ height,
+ width,
+ factor=patch_size * merge_size,
+ min_pixels=size["shortest_edge"],
+ max_pixels=size["longest_edge"],
+ )
+ stacked_images = self.resize(
+ image=stacked_images,
+ size=SizeDict(height=resized_height, width=resized_width),
+ interpolation=interpolation,
+ )
+ resized_images_grouped[shape] = stacked_images
+ resized_images = reorder_images(resized_images_grouped, grouped_images_index)
+
+ # Group images by size for further processing
+ # Needed in case do_resize is False, or resize returns images with different sizes
+ grouped_images, grouped_images_index = group_images_by_shape(resized_images, disable_grouping=disable_grouping)
+ processed_images_grouped = {}
+ processed_grids = {}
+ for shape, stacked_images in grouped_images.items():
+ resized_height, resized_width = stacked_images.shape[-2:]
+ # Fused rescale and normalize
+ patches = self.rescale_and_normalize(
+ stacked_images, do_rescale, rescale_factor, do_normalize, image_mean, image_std
+ )
+
+ if patches.ndim == 4:
+ # add a temporal dimension if we have images
+ patches = patches.unsqueeze(1)
+ if patches.shape[1] % temporal_patch_size != 0:
+ repeats = patches[:, -1:].repeat(1, temporal_patch_size - 1, 1, 1, 1)
+ patches = torch.cat([patches, repeats], dim=1)
+
+ batch_size, grid_t, channel = patches.shape[:3]
+ grid_t = grid_t // temporal_patch_size
+ grid_h, grid_w = (
+ resized_height // patch_size,
+ resized_width // patch_size,
+ )
+ patches = patches.view(
+ batch_size,
+ grid_t,
+ temporal_patch_size,
+ channel,
+ grid_h,
+ patch_size,
+ grid_w,
+ patch_size,
+ )
+ patches = patches.permute(0, 1, 4, 6, 3, 2, 5, 7)
+ flatten_patches = patches.reshape(batch_size, grid_t * grid_h * grid_w, channel, patch_size, patch_size)
+
+ processed_images_grouped[shape] = flatten_patches
+ processed_grids[shape] = [[grid_t, grid_h, grid_w]] * batch_size
+
+ processed_images = reorder_images(processed_images_grouped, grouped_images_index)
+ processed_grids = reorder_images(processed_grids, grouped_images_index)
+ pixel_values = torch.cat(processed_images, dim=0)
+ image_grid_thw = torch.tensor(processed_grids)
+
+ return BatchFeature(
+ data={"pixel_values": pixel_values, "image_grid_thw": image_grid_thw}, tensor_type=return_tensors
+ )
+
+
+class PaddleOCRVLProcessorKwargs(ProcessingKwargs, total=False):
+ _defaults = {
+ "text_kwargs": {
+ "padding": False,
+ },
+ }
+
+
+class PaddleOCRVLProcessor(ProcessorMixin):
+ r"""
+ [`PaddleOCRVLProcessor`] offers all the functionalities of [`PaddleOCRVLImageProcessor`] and [`LLamaTokenizerFast`]. See the
+ [`~PaddleOCRVLProcessor.__call__`] and [`~PaddleOCRVLProcessor.decode`] for more information.
+ Args:
+ image_processor ([`PaddleOCRVLImageProcessor`], *optional*):
+ The image processor is a required input.
+ tokenizer ([`LLamaTokenizerFast`], *optional*):
+ The tokenizer is a required input.
+ chat_template (`str`, *optional*): A Jinja template which will be used to convert lists of messages
+ in a chat into a tokenizable string.
+ """
+
+ image_processor_class = "AutoImageProcessor"
+ tokenizer_class = "AutoTokenizer"
+
+ def __init__(self, image_processor=None, tokenizer=None, chat_template=None, **kwargs):
+ self.image_token = tokenizer.image_token
+ super().__init__(image_processor, tokenizer, chat_template=chat_template)
+
+ def __call__(
+ self,
+ images: ImageInput = None,
+ text: Union[TextInput, PreTokenizedInput, list[TextInput], list[PreTokenizedInput]] = None,
+ **kwargs: Unpack[PaddleOCRVLProcessorKwargs],
+ ) -> BatchFeature:
+ """
+ Args:
+ images (`PIL.Image.Image`, `np.ndarray`, `torch.Tensor`, `List[PIL.Image.Image]`, `List[np.ndarray]`, `List[torch.Tensor]`):
+ The image or batch of images to be prepared. Each image can be a PIL image, NumPy array or PyTorch
+ tensor. Both channels-first and channels-last formats are supported.
+ text (`str`, `List[str]`, `List[List[str]]`):
+ The sequence or batch of sequences to be encoded. Each sequence can be a string or a list of strings
+ (pretokenized string). If the sequences are provided as list of strings (pretokenized), you must set
+ `is_split_into_words=True` (to lift the ambiguity with a batch of sequences).
+ return_tensors (`str` or [`~utils.TensorType`], *optional*):
+ If set, will return tensors of a particular framework. Acceptable values are:
+ - `'tf'`: Return TensorFlow `tf.constant` objects.
+ - `'pt'`: Return PyTorch `torch.Tensor` objects.
+ - `'np'`: Return NumPy `np.ndarray` objects.
+ - `'jax'`: Return JAX `jnp.ndarray` objects.
+
+ Returns:
+ [`BatchFeature`]: A [`BatchFeature`] with the following fields:
+
+ - **input_ids** -- List of token ids to be fed to a model. Returned when `text` is not `None`.
+ - **attention_mask** -- List of indices specifying which tokens should be attended to by the model (when
+ `return_attention_mask=True` or if *"attention_mask"* is in `self.model_input_names` and if `text` is not
+ `None`).
+ - **pixel_values** -- Pixel values to be fed to a model. Returned when `images` is not `None`.
+ - **image_grid_thw** -- List of image 3D grid in LLM. Returned when `images` is not `None`.
+ """
+ output_kwargs = self._merge_kwargs(
+ PaddleOCRVLProcessorKwargs,
+ tokenizer_init_kwargs=self.tokenizer.init_kwargs,
+ **kwargs,
+ )
+
+ if images is not None:
+ image_inputs = self.image_processor(images=images, **output_kwargs["images_kwargs"])
+ image_grid_thw = image_inputs["image_grid_thw"]
+
+ else:
+ image_inputs = {}
+ image_grid_thw = None
+
+ if not isinstance(text, list):
+ text = [text]
+
+ text = text.copy()
+
+ if image_grid_thw is not None:
+ index = 0
+ for i in range(len(text)):
+ while self.image_token in text[i]:
+ text[i] = text[i].replace(
+ self.image_token,
+ "<|placeholder|>"
+ * (
+ image_grid_thw[index].prod()
+ // self.image_processor.merge_size
+ // self.image_processor.merge_size
+ ),
+ 1,
+ )
+ index += 1
+ text[i] = text[i].replace("<|placeholder|>", self.image_token)
+
+ text_inputs = self.tokenizer(text, **output_kwargs["text_kwargs"])
+
+ return BatchFeature(data={**text_inputs, **image_inputs})
+
+
+class PaddleOCRVisionConfig(SiglipVisionConfig):
+ r"""
+ This is the configuration class to store the configuration of a [`PaddleOCRVisionModel`]. It is used to instantiate a
+ PaddleOCRVL vision encoder according to the specified arguments, defining the model architecture. Instantiating a
+ configuration with the defaults will yield a similar configuration to that of the vision encoder of the PaddleOCRVL
+ [PaddlePaddle/PaddleOCRVL](https://huggingface.co/PaddlePaddle/PaddleOCR-VL) architecture.
+
+ Configuration objects inherit from [`PreTrainedConfig`] and can be used to control the model outputs. Read the
+ documentation from [`PreTrainedConfig`] for more information.
+
+ Args:
+ hidden_size (`int`, *optional*, defaults to 1152):
+ Dimensionality of the encoder layers and the pooler layer.
+ intermediate_size (`int`, *optional*, defaults to 4304):
+ Dimensionality of the "intermediate" (i.e., feed-forward) layer in the Transformer encoder.
+ num_hidden_layers (`int`, *optional*, defaults to 27):
+ Number of hidden layers in the Transformer encoder.
+ num_attention_heads (`int`, *optional*, defaults to 16):
+ Number of attention heads for each attention layer in the Transformer encoder.
+ num_channels (`int`, *optional*, defaults to 3):
+ Number of channels in the input images.
+ image_size (`int`, *optional*, defaults to 384):
+ The size (resolution) of each image.
+ patch_size (`int`, *optional*, defaults to 14):
+ The size (resolution) of each patch.
+ hidden_act (`str` or `function`, *optional*, defaults to `"gelu_pytorch_tanh"`):
+ The non-linear activation function (function or string) in the encoder and pooler. If string, `"gelu"`,
+ `"relu"`, `"selu"` and `"gelu_new"` `"quick_gelu"` are supported.
+ layer_norm_eps (`float`, *optional*, defaults to 1e-06):
+ The epsilon used by the layer normalization layers.
+ attention_dropout (`float`, *optional*, defaults to 0.0):
+ The dropout ratio for the attention probabilities.
+ spatial_merge_size (`int`, *optional*, defaults to 2):
+ The size used for merging spatial dimensions.
+
+ Example:
+
+ ```python
+ >>> from transformers import PaddleOCRVisionConfig, PaddleOCRVisionModel
+
+ >>> # Initializing a PaddleOCRVisionConfig with PaddlePaddle/PaddleOCR-VL style configuration
+ >>> configuration = PaddleOCRVisionConfig()
+
+ >>> # Initializing a PaddleOCRVisionModel (with random weights) from the PaddlePaddle/PaddleOCR-VL style configuration
+ >>> model = PaddleOCRVisionModel(configuration)
+
+ >>> # Accessing the model configuration
+ >>> configuration = model.config
+ ```
+ """
+
+ model_type = "paddleocr_vl_vision"
+ base_config_key = "vision_config"
+
+ def __init__(
+ self,
+ hidden_size=1152,
+ intermediate_size=4304,
+ num_hidden_layers=27,
+ num_attention_heads=16,
+ num_channels=3,
+ image_size=384,
+ patch_size=14,
+ hidden_act="gelu_pytorch_tanh",
+ layer_norm_eps=1e-6,
+ attention_dropout=0.0,
+ spatial_merge_size=2,
+ **kwargs,
+ ):
+ super().__init__()
+ self.spatial_merge_size = spatial_merge_size
+
+
+class PaddleOCRTextConfig(Ernie4_5Config):
+ model_type = "paddleocr_vl_text"
+
+
+class PaddleOCRVLConfig(Qwen2VLConfig):
+ r"""
+ This is the configuration class to store the configuration of a [`PaddleOCRVLForConditionalGeneration`]. It is used to instantiate a
+ PaddleOCRVL model according to the specified arguments, defining the model architecture. Instantiating a configuration
+ with the defaults will yield a similar configuration to that of
+ PaddleOCRVL [PaddlePaddle/PaddleOCR-VL](https://huggingface.co/PaddlePaddle/PaddleOCR-VL).
+
+ Configuration objects inherit from [`PreTrainedConfig`] and can be used to control the model outputs. Read the
+ documentation from [`PreTrainedConfig`] for more information.
+
+
+ Args:
+ text_config (`Union[PreTrainedConfig, dict]`, *optional*, defaults to `PaddleOCRTextConfig`):
+ The config object or dictionary of the text backbone.
+ vision_config (`Union[PreTrainedConfig, dict]`, *optional*, defaults to `PaddleOCRVisionConfig`):
+ The config object or dictionary of the vision backbone.
+ image_token_id (`int`, *optional*, defaults to 100295):
+ The image token index to encode the image prompt.
+ video_token_id (`int`, *optional*, defaults to 100296):
+ The video token index to encode the image prompt.
+ vision_start_token_id (`int`, *optional*, defaults to 101305):
+ The token index to denote start of vision input.
+ vision_end_token_id (`int`, *optional*, defaults to 101306):
+ The token index to denote end of vision input.
+
+ ```python
+ >>> from transformers import PaddleOCRVLForConditionalGeneration, PaddleOCRVLConfig
+
+ >>> # Initializing a PaddleOCRVL style configuration
+ >>> configuration = PaddleOCRVLConfig()
+
+ >>> # Initializing a model from the PaddleOCRVL style configuration
+ >>> model = PaddleOCRVLForConditionalGeneration(configuration)
+
+ >>> # Accessing the model configuration
+ >>> configuration = model.config
+ ```"""
+
+ sub_configs = {"vision_config": PaddleOCRVisionConfig, "text_config": PaddleOCRTextConfig}
+
+ def __init__(
+ self,
+ text_config=None,
+ vision_config=None,
+ image_token_id=100295,
+ video_token_id=100296,
+ vision_start_token_id=101305,
+ vision_end_token_id=101306,
+ **kwargs,
+ ):
+ super().__init__()
+
+
+class PaddleOCRProjector(nn.Module):
+ def __init__(self, config: PaddleOCRVLConfig):
+ super().__init__()
+ self.merge_kernel_size = (config.vision_config.spatial_merge_size, config.vision_config.spatial_merge_size)
+
+ hidden_size = config.vision_config.hidden_size * self.merge_kernel_size[0] * self.merge_kernel_size[1]
+
+ self.pre_norm = torch.nn.LayerNorm(config.vision_config.hidden_size, eps=1e-05)
+ self.linear_1 = nn.Linear(hidden_size, hidden_size, bias=True)
+ self.act = GELUActivation()
+ self.linear_2 = nn.Linear(hidden_size, config.text_config.hidden_size, bias=True)
+
+ def forward(self, image_features: torch.Tensor, image_grid_thw: torch.Tensor) -> torch.Tensor:
+ image_features_chunks = image_features.split(image_grid_thw.prod(dim=1).tolist(), dim=0)
+ m1, m2 = self.merge_kernel_size
+
+ processed_features = []
+ for image_feature, image_grid in zip(image_features_chunks, image_grid_thw):
+ image_feature = self.pre_norm(image_feature)
+ t, h, w = image_grid
+ d = image_feature.shape[-1]
+ h_block = h // m1
+ w_block = w // m2
+
+ image_feature = image_feature.reshape(t, h_block, m1, w_block, m2, d)
+ image_feature = image_feature.transpose(2, 3)
+ image_feature = image_feature.reshape(t * h_block * w_block, m1 * m2 * d)
+
+ hidden_states = self.linear_1(image_feature)
+ hidden_states = self.act(hidden_states)
+ hidden_states = self.linear_2(hidden_states)
+ processed_features.append(hidden_states)
+
+ return torch.cat(processed_features, dim=0)
+
+
+class PaddleOCRVisionRotaryEmbedding(VisionRotaryEmbedding):
+ pass
+
+
+class PaddleOCRRotaryEmbedding(Qwen2VLRotaryEmbedding):
+ pass
+
+
+class PaddleOCRMLP(Ernie4_5MLP):
+ def __init__(self, config: PaddleOCRTextConfig):
+ super().__init__()
+
+
+class PaddleOCRAttention(Qwen2_5OmniAttention):
+ def __init__(self, config: PaddleOCRVLConfig, layer_idx: Optional[int] = None):
+ super().__init__()
+
+ self.attention_dropout = 0.0
+ self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=config.use_bias)
+ self.k_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=config.use_bias)
+ self.v_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=config.use_bias)
+ self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=config.use_bias)
+
+
+class PaddleOCRRMSNorm(Ernie4_5RMSNorm):
+ pass
+
+
+class PaddleOCRDecoderLayer(Ernie4_5DecoderLayer):
+ def __init__(self, config: PaddleOCRTextConfig, layer_idx: int):
+ super().__init__()
+
+
+@auto_docstring
+class PaddleOCRVLPreTrainedModel(PreTrainedModel):
+ config: PaddleOCRVLConfig
+ base_model_prefix = "model"
+ supports_gradient_checkpointing = True
+ _no_split_modules = ["PaddleOCRDecoderLayer"]
+ _skip_keys_device_placement = ["past_key_values"]
+ _supports_flash_attn = True
+ _supports_sdpa = True
+ _supports_flex_attn = True
+
+ _can_compile_fullgraph = True
+ _supports_attention_backend = True
+
+ _can_record_outputs = {
+ "hidden_states": PaddleOCRDecoderLayer,
+ "attentions": PaddleOCRAttention,
+ }
+
+
+class PaddleOCRTextModel(PaddleOCRVLPreTrainedModel, Ernie4_5Model):
+ def __init__(self, config: PaddleOCRTextConfig):
+ super().__init__(config)
+
+ @check_model_inputs
+ @auto_docstring
+ def forward(
+ self,
+ input_ids: Optional[torch.LongTensor] = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ position_ids: Optional[torch.LongTensor] = None,
+ past_key_values: Optional[Cache] = None,
+ inputs_embeds: Optional[torch.FloatTensor] = None,
+ cache_position: Optional[torch.LongTensor] = None,
+ use_cache: Optional[bool] = None,
+ **kwargs: Unpack[TransformersKwargs],
+ ) -> BaseModelOutputWithPast:
+ if (input_ids is None) ^ (inputs_embeds is not None):
+ raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
+
+ if inputs_embeds is None:
+ inputs_embeds: torch.Tensor = self.embed_tokens(input_ids)
+
+ if use_cache and past_key_values is None:
+ past_key_values = DynamicCache(config=self.config)
+
+ if cache_position is None:
+ past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
+ cache_position: torch.Tensor = (
+ torch.arange(inputs_embeds.shape[1], device=inputs_embeds.device) + past_seen_tokens
+ )
+
+ if position_ids is None:
+ position_ids = cache_position.view(1, 1, -1).expand(3, inputs_embeds.shape[0], -1)
+ elif position_ids.ndim == 2:
+ position_ids = position_ids[None, ...].expand(3, position_ids.shape[0], -1)
+
+ if position_ids.ndim == 3 and position_ids.shape[0] == 4:
+ text_position_ids = position_ids[0]
+ position_ids = position_ids[1:]
+ else:
+ text_position_ids = None
+
+ causal_mask = create_causal_mask(
+ config=self.config,
+ input_embeds=inputs_embeds,
+ attention_mask=attention_mask,
+ cache_position=cache_position,
+ past_key_values=past_key_values,
+ position_ids=text_position_ids,
+ )
+
+ hidden_states = inputs_embeds
+ position_embeddings = self.rotary_emb(hidden_states, position_ids=position_ids)
+
+ for decoder_layer in self.layers[: self.config.num_hidden_layers]:
+ hidden_states = decoder_layer(
+ hidden_states,
+ attention_mask=causal_mask,
+ position_embeddings=position_embeddings,
+ position_ids=text_position_ids,
+ past_key_values=past_key_values,
+ use_cache=use_cache,
+ cache_position=cache_position,
+ **kwargs,
+ )
+
+ hidden_states = self.norm(hidden_states)
+ return BaseModelOutputWithPast(
+ last_hidden_state=hidden_states,
+ past_key_values=past_key_values,
+ )
+
+
+class PaddleOCRVisionModel(PaddleOCRVLPreTrainedModel):
+ config: PaddleOCRVisionConfig
+ main_input_name = "pixel_values"
+ input_modalities = "image"
+
+ def __init__(self, config: PaddleOCRVisionConfig):
+ super().__init__(config)
+
+ self.vision_model = PaddleOCRVisionTransformer(config)
+
+ # Initialize weights and apply final processing
+ self.post_init()
+
+ def forward(
+ self,
+ pixel_values: torch.FloatTensor,
+ cu_seqlens: torch.Tensor,
+ image_grid_thw: Optional[list[Union[tuple[int, int, int], list[tuple[int, int, int]]]]] = None,
+ **kwargs,
+ ) -> BaseModelOutputWithPooling:
+ """
+ Args:
+ pixel_values (`torch.FloatTensor` of shape `(batch_size, sequence_length, image_channels, patch_size, patch_size)`):
+ The tensors corresponding to the input images.
+ cu_seqlens (`torch.Tensor` of shape `(num_images + 1,)`):
+ The cumulative sequence lengths of each image or video feature.
+ image_grid_thw (`torch.LongTensor` of shape `(num_images, 3)`, *optional*):
+ The temporal, height and width of feature shape of each image in LLM.
+ """
+ return self.vision_model(
+ pixel_values=pixel_values,
+ cu_seqlens=cu_seqlens,
+ image_grid_thw=image_grid_thw,
+ )
+
+
+class PaddleOCRVisionEmbeddings(SiglipVisionEmbeddings):
+ def __init__(self, config: PaddleOCRVisionConfig):
+ super().__init__()
+
+ def interpolate_pos_encoding(self, embeddings: torch.Tensor, height: int, width: int) -> torch.Tensor:
+ num_positions = self.position_embedding.weight.shape[0]
+
+ patch_pos_embed = self.position_embedding.weight.unsqueeze(0)
+
+ dim = embeddings.shape[-1]
+
+ sqrt_num_positions = torch_int(num_positions**0.5)
+ patch_pos_embed = patch_pos_embed.reshape(1, sqrt_num_positions, sqrt_num_positions, dim)
+ patch_pos_embed = patch_pos_embed.permute(0, 3, 1, 2)
+
+ patch_pos_embed = nn.functional.interpolate(
+ patch_pos_embed,
+ size=(height, width),
+ mode="bilinear",
+ align_corners=False,
+ )
+
+ patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim)
+ return patch_pos_embed
+
+ def forward(
+ self,
+ pixel_values: torch.FloatTensor,
+ image_grid_thw: Optional[list[Union[tuple[int, int, int], list[tuple[int, int, int]]]]] = None,
+ ) -> torch.Tensor:
+ """
+ Args:
+ pixel_values (`torch.FloatTensor` of shape `(batch_size, sequence_length, image_channels, patch_size, patch_size)`):
+ The tensors corresponding to the input images.
+ image_grid_thw (`torch.LongTensor` of shape `(num_images, 3)`, *optional*):
+ The temporal, height and width of feature shape of each image in LLM.
+ """
+ batch_size, squence_len, channel, height, width = pixel_values.shape
+ target_dtype = self.patch_embedding.weight.dtype
+ pixel_values = pixel_values.reshape(batch_size * squence_len, channel, height, width)
+ patch_embeds = self.patch_embedding(pixel_values.to(dtype=target_dtype)) # shape = [*, width, grid, grid]
+ embeddings = patch_embeds.flatten(-2).squeeze(-1)
+ embeddings = embeddings.reshape(batch_size, squence_len, -1)
+
+ start = 0
+ embeddings = embeddings.squeeze(0)
+ tmp_embeddings = []
+ for image_grid in image_grid_thw:
+ t, h, w = image_grid
+ end = start + t * h * w
+ image_embeddings = embeddings[start:end, :]
+ position_embedding = self.interpolate_pos_encoding(image_embeddings, h, w).squeeze(0).repeat(t, 1)
+ image_embeddings = image_embeddings + position_embedding
+ tmp_embeddings.append(image_embeddings)
+ start = end
+ embeddings = torch.concat(tmp_embeddings, dim=0)
+
+ return embeddings
+
+
+class PaddleOCRVisionAttention(VideoLlama3VisionAttention):
+ def __init__(self, config: PaddleOCRVisionConfig):
+ super().__init__()
+
+
+class PaddleOCRVisionMLP(SiglipMLP):
+ def __init__(self, config: PaddleOCRVisionConfig):
+ super().__init__()
+
+
+class PaddleOCRVisionEncoderLayer(VideoLlama3VisionEncoderLayer):
+ def __init__(self, config: PaddleOCRVisionConfig):
+ super().__init__()
+
+
+class PaddleOCRVisionEncoder(VideoLlama3VisionEncoder):
+ def __init__(self, config: PaddleOCRVisionConfig):
+ super().__init__()
+ embed_dim = config.hidden_size
+ num_heads = config.num_attention_heads
+ head_dim = embed_dim // num_heads
+ self.rotary_pos_emb = PaddleOCRVisionRotaryEmbedding(head_dim // 2)
+
+ def forward(
+ self,
+ inputs_embeds: torch.FloatTensor,
+ cu_seqlens: torch.Tensor,
+ attention_mask: Optional[torch.Tensor] = None,
+ image_grid_thw: Optional[list[Union[tuple[int, int, int], list[tuple[int, int, int]]]]] = None,
+ ) -> BaseModelOutput:
+ """
+ Args:
+ inputs_embeds (`torch.FloatTensor` of shape `(sequence_length, hidden_size)`, *optional*):
+ Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation.
+ This is useful if you want more control over how to convert `input_ids` indices into associated vectors
+ than the model's internal embedding lookup matrix.
+ cu_seqlens (`torch.Tensor` of shape `(num_images + 1,)`):
+ The cumulative sequence lengths of each image or video feature.
+ attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
+ The attention_mask used in forward function shape [batch_size X sequence_length] if not None.
+ image_grid_thw (`torch.LongTensor` of shape `(num_images, 3)`, *optional*):
+ The temporal, height and width of feature shape of each image in LLM.
+ """
+ device = inputs_embeds.device
+ hidden_states = inputs_embeds
+ attention_mask = create_bidirectional_mask(
+ config=self.config,
+ input_embeds=inputs_embeds,
+ attention_mask=attention_mask,
+ )
+ split_hids = []
+ split_wids = []
+ for t, h, w in image_grid_thw:
+ image_pids = torch.arange(t * h * w, device=device) % (h * w)
+ sample_hids = image_pids // w
+ sample_wids = image_pids % w
+ split_hids.append(sample_hids)
+ split_wids.append(sample_wids)
+ width_position_ids = torch.concat(split_wids, dim=0)
+ height_position_ids = torch.concat(split_hids, dim=0)
+
+ pids = torch.stack([height_position_ids, width_position_ids], dim=-1)
+ max_grid_size = pids.max() + 1
+ rotary_embeddings_max_grid = self.rotary_pos_emb(max_grid_size)
+ rotary_embeddings = rotary_embeddings_max_grid[pids].flatten(1)
+ rotary_embeddings = rotary_embeddings.repeat(1, 2)
+ position_embeddings = (rotary_embeddings.cos(), rotary_embeddings.sin())
+
+ for encoder_layer in self.layers:
+ hidden_states = encoder_layer(
+ hidden_states,
+ cu_seqlens=cu_seqlens,
+ position_embeddings=position_embeddings,
+ )
+
+ return BaseModelOutput(
+ last_hidden_state=hidden_states,
+ )
+
+
+class PaddleOCRVisionTransformer(PaddleOCRVLPreTrainedModel):
+ def __init__(self, config: PaddleOCRVisionConfig):
+ super().__init__(config)
+ self.config = config
+ embed_dim = config.hidden_size
+
+ self.embeddings = PaddleOCRVisionEmbeddings(config)
+ self.encoder = PaddleOCRVisionEncoder(config)
+ self.post_layernorm = nn.LayerNorm(embed_dim, eps=config.layer_norm_eps)
+
+ def forward(
+ self,
+ pixel_values: torch.FloatTensor,
+ cu_seqlens: torch.Tensor,
+ attention_mask: Optional[torch.Tensor] = None,
+ image_grid_thw: Optional[list[Union[tuple[int, int, int], list[tuple[int, int, int]]]]] = None,
+ **kwargs,
+ ) -> BaseModelOutputWithPooling:
+ """
+ Args:
+ pixel_values (`torch.FloatTensor` of shape `(batch_size, sequence_length, patch_size * patch_size * image_channels)`):
+ The tensors corresponding to the input images.
+ cu_seqlens (`torch.Tensor` of shape `(num_images + 1,)`):
+ The cumulative sequence lengths of each image or video feature.
+ attention_mask (`torch.Tensor`, *optional*):
+ The attention_mask used in forward function shape [batch_size X sequence_length] if not None.
+ image_grid_thw (`torch.LongTensor` of shape `(num_images, 3)`, *optional*):
+ The temporal, height and width of feature shape of each image in LLM.
+ """
+ hidden_states = self.embeddings(pixel_values, image_grid_thw=image_grid_thw)
+
+ encoder_outputs: BaseModelOutput = self.encoder(
+ inputs_embeds=hidden_states,
+ cu_seqlens=cu_seqlens,
+ attention_mask=attention_mask,
+ image_grid_thw=image_grid_thw,
+ )
+
+ last_hidden_state = encoder_outputs.last_hidden_state
+ last_hidden_state = self.post_layernorm(last_hidden_state)
+
+ return BaseModelOutputWithPooling(
+ last_hidden_state=last_hidden_state,
+ pooler_output=None,
+ hidden_states=encoder_outputs.hidden_states,
+ attentions=encoder_outputs.attentions,
+ )
+
+
+class PaddleOCRVLModelOutputWithPast(Qwen2VLModelOutputWithPast):
+ pass
+
+
+class PaddleOCRVLCausalLMOutputWithPast(Qwen2VLCausalLMOutputWithPast):
+ pass
+
+
+class PaddleOCRVLModel(Qwen2VLModel):
+ _checkpoint_conversion_mapping = {"^model": "language_model"}
+ _keys_to_ignore_on_load_unexpected = ["packing_position_embedding", "vision_model.head"]
+
+ def __init__(self, config: PaddleOCRVLConfig):
+ super().__init__(config)
+ self.visual = PaddleOCRVisionModel._from_config(config.vision_config)
+ self.projector = PaddleOCRProjector(config)
+ self.language_model = PaddleOCRTextModel._from_config(config.text_config)
+ self.rope_deltas = None
+
+ self.post_init()
+
+ def get_input_embeddings(self):
+ return self.language_model.embed_tokens
+
+ def set_input_embeddings(self, value):
+ self.language_model.embed_tokens = value
+
+ def get_image_features(self, pixel_values: torch.FloatTensor, image_grid_thw: Optional[torch.LongTensor] = None):
+ """
+ Encodes images into continuous embeddings that can be forwarded to the language model.
+
+ Args:
+ pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, image_size, image_size)`):
+ The tensors corresponding to the input images.
+ image_grid_thw (`torch.LongTensor` of shape `(num_images, 3)`, *optional*):
+ The temporal, height and width of feature shape of each image in LLM.
+ """
+ pixel_values = pixel_values.type(self.visual.dtype).unsqueeze(0)
+ cu_seqlens = torch.repeat_interleave(image_grid_thw[:, 1] * image_grid_thw[:, 2], image_grid_thw[:, 0]).cumsum(
+ dim=0,
+ # Select dtype based on the following factors:
+ # - FA2 requires that cu_seqlens_q must have dtype int32
+ # - torch.onnx.export requires that cu_seqlens_q must have same dtype as grid_thw
+ # See https://github.com/huggingface/transformers/pull/34852 for more information
+ dtype=image_grid_thw.dtype if torch.jit.is_tracing() else torch.int32,
+ )
+ cu_seqlens = torch.nn.functional.pad(cu_seqlens, (1, 0), value=0)
+ vision_outputs = self.visual(
+ pixel_values=pixel_values,
+ image_grid_thw=image_grid_thw,
+ cu_seqlens=cu_seqlens,
+ )
+ image_embeds = vision_outputs.last_hidden_state
+ image_embeds = self.projector(image_embeds, image_grid_thw)
+ return image_embeds
+
+ def get_placeholder_mask(
+ self, input_ids: torch.LongTensor, inputs_embeds: torch.FloatTensor, image_features: torch.FloatTensor
+ ):
+ """
+ Obtains multimodal placeholder mask from `input_ids` or `inputs_embeds`, and checks that the placeholder token count is
+ equal to the length of multimodal features. If the lengths are different, an error is raised.
+ """
+ if input_ids is None:
+ special_image_mask = inputs_embeds == self.get_input_embeddings()(
+ torch.tensor(self.config.image_token_id, dtype=torch.long, device=inputs_embeds.device)
+ )
+ special_image_mask = special_image_mask.all(-1)
+ else:
+ special_image_mask = input_ids == self.config.image_token_id
+
+ n_image_tokens = special_image_mask.sum()
+ special_image_mask = special_image_mask.unsqueeze(-1).expand_as(inputs_embeds).to(inputs_embeds.device)
+ n_image_features = image_features.shape[0] * image_features.shape[1]
+ if inputs_embeds[special_image_mask].numel() != image_features.numel():
+ raise ValueError(
+ f"Image features and image tokens do not match: tokens: {n_image_tokens}, features {n_image_features}"
+ )
+ return special_image_mask
+
+ @can_return_tuple
+ def forward(
+ self,
+ input_ids: torch.LongTensor = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ position_ids: Optional[torch.LongTensor] = None,
+ past_key_values: Optional[list[torch.FloatTensor]] = None,
+ inputs_embeds: Optional[torch.FloatTensor] = None,
+ use_cache: Optional[bool] = None,
+ pixel_values: Optional[torch.Tensor] = None,
+ image_grid_thw: Optional[torch.LongTensor] = None,
+ rope_deltas: Optional[torch.LongTensor] = None,
+ cache_position: Optional[torch.LongTensor] = None,
+ **kwargs,
+ ) -> Union[tuple, PaddleOCRVLModelOutputWithPast]:
+ r"""
+ image_grid_thw (`torch.LongTensor` of shape `(num_images, 3)`, *optional*):
+ The temporal, height and width of feature shape of each image in LLM.
+ rope_deltas (`torch.LongTensor` of shape `(batch_size, )`, *optional*):
+ The rope index difference between sequence length and multimodal rope.
+ """
+ if inputs_embeds is None:
+ inputs_embeds = self.language_model.embed_tokens(input_ids)
+
+ if pixel_values is not None:
+ image_embeds = self.get_image_features(pixel_values, image_grid_thw).to(
+ inputs_embeds.device, inputs_embeds.dtype
+ )
+ image_mask = self.get_placeholder_mask(input_ids, inputs_embeds=inputs_embeds, image_features=image_embeds)
+ inputs_embeds = inputs_embeds.masked_scatter(image_mask, image_embeds)
+
+ if position_ids is None:
+ past_key_values_length = 0 if past_key_values is None else past_key_values.get_seq_length()
+ if self.rope_deltas is None or past_key_values_length == 0:
+ position_ids, rope_deltas = self.get_rope_index(
+ input_ids=input_ids,
+ image_grid_thw=image_grid_thw,
+ attention_mask=attention_mask,
+ )
+ self.rope_deltas = rope_deltas
+ # then use the prev pre-calculated rope-deltas to get the correct position ids
+ else:
+ batch_size, seq_length, _ = inputs_embeds.shape
+ position_ids = torch.arange(seq_length, device=inputs_embeds.device)
+ position_ids = position_ids.view(1, 1, -1).expand(3, batch_size, -1)
+ delta = (past_key_values_length + self.rope_deltas).to(inputs_embeds.device)
+ delta = delta.repeat_interleave(batch_size // delta.shape[0], dim=0)
+ position_ids = position_ids + delta.to(position_ids.device)
+
+ outputs = self.language_model(
+ input_ids=None,
+ position_ids=position_ids,
+ attention_mask=attention_mask,
+ past_key_values=past_key_values,
+ inputs_embeds=inputs_embeds,
+ use_cache=use_cache,
+ cache_position=cache_position,
+ **kwargs,
+ )
+
+ output = PaddleOCRVLModelOutputWithPast(
+ last_hidden_state=outputs.last_hidden_state,
+ past_key_values=outputs.past_key_values,
+ hidden_states=outputs.hidden_states,
+ attentions=outputs.attentions,
+ rope_deltas=self.rope_deltas,
+ )
+
+ return output
+
+
+class PaddleOCRVLForConditionalGeneration(Qwen2VLForConditionalGeneration):
+ _checkpoint_conversion_mapping = {
+ "^visual": "model.visual",
+ "^mlp_AR": "model.projector",
+ r"^model(?!(\.visual|\.projector|\.language_model))": "model.language_model",
+ }
+ _keys_to_ignore_on_load_unexpected = ["packing_position_embedding", "vision_model.head"]
+
+ @can_return_tuple
+ @auto_docstring
+ def forward(
+ self,
+ input_ids: Optional[torch.LongTensor] = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ position_ids: Optional[torch.LongTensor] = None,
+ past_key_values: Optional[Cache] = None,
+ inputs_embeds: Optional[torch.FloatTensor] = None,
+ labels: Optional[torch.LongTensor] = None,
+ use_cache: Optional[bool] = None,
+ pixel_values: Optional[torch.Tensor] = None,
+ image_grid_thw: Optional[torch.LongTensor] = None,
+ rope_deltas: Optional[torch.LongTensor] = None,
+ cache_position: Optional[torch.LongTensor] = None,
+ logits_to_keep: Union[int, torch.Tensor] = 0,
+ **kwargs: Unpack[TransformersKwargs],
+ ) -> Union[tuple, PaddleOCRVLCausalLMOutputWithPast]:
+ r"""
+ labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
+ Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
+ config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
+ (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
+ image_grid_thw (`torch.LongTensor` of shape `(num_images, 3)`, *optional*):
+ The temporal, height and width of feature shape of each image in LLM.
+ rope_deltas (`torch.LongTensor` of shape `(batch_size, )`, *optional*):
+ The rope index difference between sequence length and multimodal rope.
+
+ Example:
+
+ ```python
+ >>> from transformers import AutoProcessor, PaddleOCRVLForConditionalGeneration
+
+ >>> model = PaddleOCRVLForConditionalGeneration.from_pretrained("PaddlePaddle/PaddleOCR-VL", dtype="bfloat16")
+ >>> processor = AutoProcessor.from_pretrained("PaddlePaddle/PaddleOCR-VL")
+
+ >>> messages = [
+ {
+ "role": "user",
+ "content": [
+ {
+ "type": "image",
+ "image": "https://paddle-model-ecology.bj.bcebos.com/paddlex/imgs/demo_image/ocr_demo.jpg",
+ },
+ {"type": "text", "text": "OCR:"},
+ ],
+ }
+ ]
+
+ >>> inputs = processor.apply_chat_template(
+ messages,
+ tokenize=True,
+ add_generation_prompt=True,
+ return_dict=True,
+ return_tensors="pt"
+ ).to(model.device)
+
+ >>> # Generate
+ >>> generated_ids = model.generate(**inputs, max_new_tokens=1024)
+ >>> generated_ids_trimmed = [out_ids[len(in_ids) :] for in_ids, out_ids in zip(inputs.input_ids, generated_ids)]
+ >>> output_text = processor.batch_decode(generated_ids_trimmed, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
+ >>> print(output_text)
+ ```
+ """
+ outputs: PaddleOCRVLModelOutputWithPast = self.model(
+ input_ids=input_ids,
+ attention_mask=attention_mask,
+ position_ids=position_ids,
+ image_grid_thw=image_grid_thw,
+ past_key_values=past_key_values,
+ inputs_embeds=inputs_embeds,
+ use_cache=use_cache,
+ pixel_values=pixel_values,
+ rope_deltas=rope_deltas,
+ cache_position=cache_position,
+ **kwargs,
+ )
+ hidden_states = outputs.last_hidden_state
+
+ slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
+ logits = self.lm_head(hidden_states[:, slice_indices, :])
+
+ loss = None
+ if labels is not None:
+ loss = self.loss_function(
+ logits=logits, labels=labels, vocab_size=self.config.text_config.vocab_size, **kwargs
+ )
+
+ return PaddleOCRVLCausalLMOutputWithPast(
+ loss=loss,
+ logits=logits,
+ past_key_values=outputs.past_key_values,
+ hidden_states=outputs.hidden_states,
+ attentions=outputs.attentions,
+ rope_deltas=outputs.rope_deltas,
+ )
+
+
+__all__ = [
+ "PaddleOCRVLForConditionalGeneration",
+ "PaddleOCRVLModel",
+ "PaddleOCRVLPreTrainedModel",
+ "PaddleOCRVisionTransformer",
+ "PaddleOCRVLConfig",
+ "PaddleOCRTextModel",
+ "PaddleOCRVisionModel",
+ "PaddleOCRVisionConfig",
+ "PaddleOCRTextConfig",
+ "PaddleOCRVLImageProcessor",
+ "PaddleOCRVLImageProcessorFast",
+ "PaddleOCRVLProcessor",
+]
diff --git a/src/transformers/models/paddleocr_vl/processing_paddleocr_vl.py b/src/transformers/models/paddleocr_vl/processing_paddleocr_vl.py
new file mode 100644
index 000000000000..e7bd822d3feb
--- /dev/null
+++ b/src/transformers/models/paddleocr_vl/processing_paddleocr_vl.py
@@ -0,0 +1,135 @@
+# π¨π¨π¨π¨π¨π¨π¨π¨π¨π¨π¨π¨π¨π¨π¨π¨π¨π¨π¨π¨π¨π¨π¨π¨π¨π¨π¨π¨π¨π¨π¨π¨π¨π¨π¨π¨π¨π¨π¨π¨π¨π¨π¨π¨π¨π¨π¨π¨
+# This file was automatically generated from src/transformers/models/paddleocr_vl/modular_paddleocr_vl.py.
+# Do NOT edit this file manually as any edits will be overwritten by the generation of
+# the file from the modular. If any change should be done, please apply the change to the
+# modular_paddleocr_vl.py file directly. One of our CI enforces this.
+# π¨π¨π¨π¨π¨π¨π¨π¨π¨π¨π¨π¨π¨π¨π¨π¨π¨π¨π¨π¨π¨π¨π¨π¨π¨π¨π¨π¨π¨π¨π¨π¨π¨π¨π¨π¨π¨π¨π¨π¨π¨π¨π¨π¨π¨π¨π¨π¨
+# Copyright 2025 The PaddlePaddle Team and The HuggingFace Inc. team. All rights reserved.
+#
+# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
+# and OPT implementations in this library. It has been modified from its
+# original forms to accommodate minor architectural differences compared
+# to GPT-NeoX and OPT used by the Meta AI team that trained the model.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from typing import Union
+
+from ...image_processing_utils import BatchFeature
+from ...image_utils import ImageInput
+from ...processing_utils import ProcessingKwargs, ProcessorMixin, Unpack
+from ...tokenization_utils_base import PreTokenizedInput, TextInput
+
+
+class PaddleOCRVLProcessorKwargs(ProcessingKwargs, total=False):
+ _defaults = {
+ "text_kwargs": {
+ "padding": False,
+ },
+ }
+
+
+class PaddleOCRVLProcessor(ProcessorMixin):
+ r"""
+ [`PaddleOCRVLProcessor`] offers all the functionalities of [`PaddleOCRVLImageProcessor`] and [`LLamaTokenizerFast`]. See the
+ [`~PaddleOCRVLProcessor.__call__`] and [`~PaddleOCRVLProcessor.decode`] for more information.
+ Args:
+ image_processor ([`PaddleOCRVLImageProcessor`], *optional*):
+ The image processor is a required input.
+ tokenizer ([`LLamaTokenizerFast`], *optional*):
+ The tokenizer is a required input.
+ chat_template (`str`, *optional*): A Jinja template which will be used to convert lists of messages
+ in a chat into a tokenizable string.
+ """
+
+ image_processor_class = "AutoImageProcessor"
+ tokenizer_class = "AutoTokenizer"
+
+ def __init__(self, image_processor=None, tokenizer=None, chat_template=None, **kwargs):
+ self.image_token = tokenizer.image_token
+ super().__init__(image_processor, tokenizer, chat_template=chat_template)
+
+ def __call__(
+ self,
+ images: ImageInput = None,
+ text: Union[TextInput, PreTokenizedInput, list[TextInput], list[PreTokenizedInput]] = None,
+ **kwargs: Unpack[PaddleOCRVLProcessorKwargs],
+ ) -> BatchFeature:
+ """
+ Args:
+ images (`PIL.Image.Image`, `np.ndarray`, `torch.Tensor`, `List[PIL.Image.Image]`, `List[np.ndarray]`, `List[torch.Tensor]`):
+ The image or batch of images to be prepared. Each image can be a PIL image, NumPy array or PyTorch
+ tensor. Both channels-first and channels-last formats are supported.
+ text (`str`, `List[str]`, `List[List[str]]`):
+ The sequence or batch of sequences to be encoded. Each sequence can be a string or a list of strings
+ (pretokenized string). If the sequences are provided as list of strings (pretokenized), you must set
+ `is_split_into_words=True` (to lift the ambiguity with a batch of sequences).
+ return_tensors (`str` or [`~utils.TensorType`], *optional*):
+ If set, will return tensors of a particular framework. Acceptable values are:
+ - `'tf'`: Return TensorFlow `tf.constant` objects.
+ - `'pt'`: Return PyTorch `torch.Tensor` objects.
+ - `'np'`: Return NumPy `np.ndarray` objects.
+ - `'jax'`: Return JAX `jnp.ndarray` objects.
+
+ Returns:
+ [`BatchFeature`]: A [`BatchFeature`] with the following fields:
+
+ - **input_ids** -- List of token ids to be fed to a model. Returned when `text` is not `None`.
+ - **attention_mask** -- List of indices specifying which tokens should be attended to by the model (when
+ `return_attention_mask=True` or if *"attention_mask"* is in `self.model_input_names` and if `text` is not
+ `None`).
+ - **pixel_values** -- Pixel values to be fed to a model. Returned when `images` is not `None`.
+ - **image_grid_thw** -- List of image 3D grid in LLM. Returned when `images` is not `None`.
+ """
+ output_kwargs = self._merge_kwargs(
+ PaddleOCRVLProcessorKwargs,
+ tokenizer_init_kwargs=self.tokenizer.init_kwargs,
+ **kwargs,
+ )
+
+ if images is not None:
+ image_inputs = self.image_processor(images=images, **output_kwargs["images_kwargs"])
+ image_grid_thw = image_inputs["image_grid_thw"]
+
+ else:
+ image_inputs = {}
+ image_grid_thw = None
+
+ if not isinstance(text, list):
+ text = [text]
+
+ text = text.copy()
+
+ if image_grid_thw is not None:
+ index = 0
+ for i in range(len(text)):
+ while self.image_token in text[i]:
+ text[i] = text[i].replace(
+ self.image_token,
+ "<|placeholder|>"
+ * (
+ image_grid_thw[index].prod()
+ // self.image_processor.merge_size
+ // self.image_processor.merge_size
+ ),
+ 1,
+ )
+ index += 1
+ text[i] = text[i].replace("<|placeholder|>", self.image_token)
+
+ text_inputs = self.tokenizer(text, **output_kwargs["text_kwargs"])
+
+ return BatchFeature(data={**text_inputs, **image_inputs})
+
+
+__all__ = ["PaddleOCRVLProcessor"]
diff --git a/tests/models/paddleocr_vl/__init__.py b/tests/models/paddleocr_vl/__init__.py
new file mode 100644
index 000000000000..e69de29bb2d1
diff --git a/tests/models/paddleocr_vl/test_modeling_paddleocr_vl.py b/tests/models/paddleocr_vl/test_modeling_paddleocr_vl.py
new file mode 100644
index 000000000000..65949ce402f0
--- /dev/null
+++ b/tests/models/paddleocr_vl/test_modeling_paddleocr_vl.py
@@ -0,0 +1,507 @@
+# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Testing suite for the PaddleOCRVL model."""
+
+import copy
+import gc
+import unittest
+
+import pytest
+from parameterized import parameterized
+
+from transformers import (
+ AutoProcessor,
+ PaddleOCRVLConfig,
+ PaddleOCRVLForConditionalGeneration,
+ is_torch_available,
+)
+from transformers.testing_utils import (
+ backend_empty_cache,
+ require_flash_attn,
+ require_torch,
+ require_torch_accelerator,
+ slow,
+ torch_device,
+)
+
+from ...generation.test_utils import GenerationTesterMixin
+from ...test_configuration_common import ConfigTester
+from ...test_modeling_common import (
+ ModelTesterMixin,
+ floats_tensor,
+ ids_tensor,
+)
+from ...test_pipeline_mixin import PipelineTesterMixin
+
+
+if is_torch_available():
+ import torch
+
+
+class PaddleOCRVLVisionText2TextModelTester:
+ def __init__(
+ self,
+ parent,
+ batch_size=7,
+ seq_length=13,
+ num_channels=3,
+ image_height=28,
+ image_width=28,
+ text_config={
+ "pad_token_id": 0,
+ "bos_token_id": 1,
+ "eos_token_id": 2,
+ "vocab_size": 103424,
+ "head_dim": 128,
+ "hidden_act": "silu",
+ "hidden_dropout_prob": 0.0,
+ "hidden_size": 32,
+ "ignored_index": -100,
+ "image_token_id": 100295,
+ "intermediate_size": 32,
+ "max_position_embeddings": 512,
+ "model_type": "paddleocr_vl",
+ "num_attention_heads": 4,
+ "num_hidden_layers": 2,
+ "num_key_value_heads": 2,
+ "rms_norm_eps": 1e-05,
+ "rope_scaling": {"mrope_section": [16, 24, 24], "rope_type": "default", "type": "default"},
+ "rope_theta": 500000,
+ "tie_word_embeddings": False,
+ },
+ vision_start_token_id=101305,
+ vision_end_token_id=101306,
+ image_token_id=100295,
+ is_training=True,
+ vision_config={
+ "hidden_act": "gelu_pytorch_tanh",
+ "hidden_size": 144,
+ "intermediate_size": 32,
+ "layer_norm_eps": 1e-06,
+ "model_type": "paddleocr_vl",
+ "num_attention_heads": 4,
+ "num_channels": 3,
+ "num_hidden_layers": 2,
+ "pad_token_id": 0,
+ "patch_size": 14,
+ "spatial_merge_size": 2,
+ },
+ ):
+ self.parent = parent
+ self.bos_token_id = text_config["bos_token_id"]
+ self.eos_token_id = text_config["eos_token_id"]
+ self.pad_token_id = text_config["pad_token_id"]
+ self.num_hidden_layers = text_config["num_hidden_layers"]
+ self.num_attention_heads = text_config["num_attention_heads"]
+ self.hidden_size = text_config["hidden_size"]
+ self.vision_start_token_id = vision_start_token_id
+ self.vision_end_token_id = vision_end_token_id
+ self.image_token_id = image_token_id
+ self.text_config = text_config
+ self.vision_config = vision_config
+ self.batch_size = batch_size
+ self.num_channels = num_channels
+ self.image_height = image_height
+ self.image_width = image_width
+ self.is_training = is_training
+ self.vocab_size = text_config["vocab_size"]
+ self.num_image_tokens = 1
+ self.seq_length = seq_length + self.num_image_tokens
+
+ def get_config(self):
+ return PaddleOCRVLConfig(
+ text_config=self.text_config,
+ vision_config=self.vision_config,
+ vision_start_token_id=self.vision_start_token_id,
+ image_token_id=self.image_token_id,
+ )
+
+ def prepare_config_and_inputs(self):
+ config = self.get_config()
+ patch_size = config.vision_config.patch_size
+ pixel_values = floats_tensor(
+ [
+ self.batch_size * (self.image_height * self.image_width) // (patch_size**2),
+ config.vision_config.num_channels,
+ patch_size,
+ patch_size,
+ ]
+ )
+
+ return config, pixel_values
+
+ def prepare_config_and_inputs_for_common(self):
+ config_and_inputs = self.prepare_config_and_inputs()
+ config, pixel_values = config_and_inputs
+ input_ids = ids_tensor([self.batch_size, self.seq_length], self.vocab_size)
+ attention_mask = torch.ones(input_ids.shape, dtype=torch.long, device=torch_device)
+
+ input_ids[:, :4] = torch.tensor([100273, 2969, 93963, 93919], dtype=input_ids.dtype, device=input_ids.device)
+ input_ids[:, 4] = self.vision_start_token_id
+ input_ids[:, 5 : 5 + self.num_image_tokens] = self.image_token_id
+ input_ids[:, -8] = self.vision_end_token_id
+ input_ids[:, -7:] = torch.tensor(
+ [93972, 2497, 93963, 23, 92267, 93963, 93919], dtype=input_ids.dtype, device=input_ids.device
+ )
+
+ inputs_dict = {
+ "pixel_values": pixel_values,
+ "image_grid_thw": torch.tensor([[1, 2, 2]] * self.batch_size, device=torch_device),
+ "input_ids": input_ids,
+ "attention_mask": attention_mask,
+ }
+ return config, inputs_dict
+
+
+@require_torch
+class PaddleOCRVLModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin, unittest.TestCase):
+ """
+ Model tester for `PaddleOCRVLForConditionalGeneration`.
+ """
+
+ all_model_classes = (PaddleOCRVLForConditionalGeneration,) if is_torch_available() else ()
+ pipeline_model_mapping = {"image-text-to-text": PaddleOCRVLForConditionalGeneration}
+ _is_composite = True
+
+ def setUp(self):
+ self.model_tester = PaddleOCRVLVisionText2TextModelTester(self)
+ self.config_tester = ConfigTester(self, config_class=PaddleOCRVLConfig, has_text_modality=False)
+
+ def test_config(self):
+ self.config_tester.run_common_tests()
+
+ def test_mismatching_num_image_tokens(self):
+ """
+ Tests that an explicit error is thrown when the number of image tokens
+ doesn't match the number of image placeholders in the text.
+ We also test multi-image cases when one prompt has multiple image tokens.
+ """
+ config, input_dict = self.model_tester.prepare_config_and_inputs_for_common()
+ for model_class in self.all_model_classes:
+ model = model_class(config).to(torch_device)
+ model.eval()
+ curr_input_dict = copy.deepcopy(input_dict) # in-place modifications further
+ _ = model(**curr_input_dict) # successful forward with no modifications
+
+ # remove one image but leave all the image tokens in text
+ patch_size = config.vision_config.patch_size
+ one_img_length = (self.model_tester.image_height * self.model_tester.image_width) // (patch_size**2)
+ curr_input_dict["pixel_values"] = curr_input_dict["pixel_values"][-one_img_length:, ...]
+ curr_input_dict["image_grid_thw"] = curr_input_dict["image_grid_thw"][-1:, ...]
+ with self.assertRaises(ValueError):
+ _ = model(**curr_input_dict)
+
+ # simulate multi-image case by concatenating inputs where each has exactly one image/image-token
+ input_ids = curr_input_dict["input_ids"][:1]
+ pixel_values = curr_input_dict["pixel_values"][:one_img_length]
+ image_grid_thw = curr_input_dict["image_grid_thw"][:1]
+ input_ids = torch.cat([input_ids, input_ids], dim=0)
+
+ # one image and two image tokens raise an error
+ with self.assertRaises(ValueError):
+ _ = model(input_ids=input_ids, pixel_values=pixel_values, image_grid_thw=image_grid_thw)
+
+ # two images and two image tokens don't raise an error
+ pixel_values = torch.cat([pixel_values, pixel_values], dim=0)
+ image_grid_thw = torch.cat([image_grid_thw, image_grid_thw], dim=0)
+ _ = model(input_ids=input_ids, pixel_values=pixel_values, image_grid_thw=image_grid_thw)
+
+ # PaddleOCRVL has pixel_values shaped as (bs*patch_len, image_channels, patch_size, patch_size) so we can't slice to batches in generate
+ def prepare_config_and_inputs_for_generate(self, batch_size=2):
+ config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
+
+ # We don't want a few model inputs in our model input dictionary for generation tests
+ input_keys_to_ignore = [
+ # we don't want encoder-decoder models to start from filled decoder ids
+ "decoder_input_ids",
+ "decoder_attention_mask",
+ # we'll set cache use in each test differently
+ "use_cache",
+ # Ignore labels if it is in the input dict
+ "labels",
+ # model-specific exceptions should overload/overwrite this function
+ ]
+
+ # The diff from the general `prepare_config_and_inputs_for_generate` lies here
+ patch_size = config.vision_config.patch_size
+ filtered_image_length = (
+ batch_size * (self.model_tester.image_height * self.model_tester.image_width) // (patch_size**2)
+ )
+ filtered_inputs_dict = {
+ k: v[:batch_size, ...] if isinstance(v, torch.Tensor) else v
+ for k, v in inputs_dict.items()
+ if k not in input_keys_to_ignore
+ }
+ filtered_inputs_dict["pixel_values"] = inputs_dict["pixel_values"][:filtered_image_length]
+
+ # It is important set `eos_token_id` to `None` to avoid early stopping (would break for length-based checks)
+ text_gen_config = config.get_text_config(decoder=True)
+ if text_gen_config.eos_token_id is not None and text_gen_config.pad_token_id is None:
+ text_gen_config.pad_token_id = (
+ text_gen_config.eos_token_id
+ if isinstance(text_gen_config.eos_token_id, int)
+ else text_gen_config.eos_token_id[0]
+ )
+ text_gen_config.eos_token_id = None
+ text_gen_config.forced_eos_token_id = None
+
+ return config, filtered_inputs_dict
+
+ @unittest.skip(reason="PaddleOCRVL does not support.")
+ def test_generate_compile_model_forward_fullgraph(self):
+ pass
+
+ @unittest.skip(reason="PaddleOCRVL does not support.")
+ def test_multi_gpu_data_parallel_forward(self):
+ pass
+
+ @pytest.mark.generate
+ @unittest.skip(reason="PaddleOCRVL does not support beam search.")
+ def test_beam_sample_generate(self):
+ pass
+
+ @pytest.mark.generate
+ @unittest.skip(reason="PaddleOCRVL does not support beam search.")
+ def test_beam_search_generate(self):
+ pass
+
+ @pytest.mark.generate
+ @unittest.skip(reason="PaddleOCRVL does not support beam search.")
+ def test_beam_search_generate_dict_output(self):
+ pass
+
+ @pytest.mark.generate
+ @unittest.skip(reason="PaddleOCRVL does not support beam search.")
+ def test_beam_search_generate_dict_outputs_use_cache(self):
+ pass
+
+ @pytest.mark.generate
+ @unittest.skip(reason="PaddleOCRVL does not support beam search.")
+ def test_beam_sample_generate_dict_output(self):
+ pass
+
+ @unittest.skip(reason="PaddleOCRVL needs to apply weight conversions.")
+ def test_can_load_from_already_mapped_keys(self):
+ pass
+
+ @pytest.mark.generate
+ @unittest.skip(reason="PaddleOCRVL does not support beam search.")
+ def test_generate_from_inputs_embeds_1_beam_search(self, _, num_beams):
+ pass
+
+ @parameterized.expand([("random",), ("same",)])
+ @pytest.mark.generate
+ @unittest.skip(reason="PaddleOCRVL does not support assisted decoding.")
+ def test_assisted_decoding_matches_greedy_search(self, assistant_type):
+ pass
+
+ @pytest.mark.generate
+ @unittest.skip(reason="PaddleOCRVL does not support assisted decoding.")
+ def test_assisted_decoding_sample(self):
+ pass
+
+ @unittest.skip("PaddleOCRVL does not support this test.")
+ def test_model_is_small(self):
+ pass
+
+
+@require_torch
+@slow
+class PaddleOCRVLIntegrationTest(unittest.TestCase):
+ def setUp(self):
+ self.processor = AutoProcessor.from_pretrained("PaddlePaddle/PaddleOCR-VL")
+ self.messages = [
+ {
+ "role": "user",
+ "content": [
+ {
+ "type": "image",
+ "url": "https://paddle-model-ecology.bj.bcebos.com/paddlex/imgs/demo_image/ocr_demo2.jpg",
+ },
+ {"type": "text", "text": "OCR:"},
+ ],
+ }
+ ]
+
+ def tearDown(self):
+ gc.collect()
+ backend_empty_cache(torch_device)
+
+ def test_small_model_integration_test(self):
+ model = (
+ PaddleOCRVLForConditionalGeneration.from_pretrained(
+ "PaddlePaddle/PaddleOCR-VL",
+ dtype="bfloat16",
+ )
+ .to(torch_device)
+ .eval()
+ )
+
+ inputs = self.processor.apply_chat_template(
+ self.messages,
+ add_generation_prompt=True,
+ tokenize=True,
+ return_dict=True,
+ return_tensors="pt",
+ )
+
+ expected_input_ids_length = 211
+ assert expected_input_ids_length == len(inputs.input_ids[0])
+
+ expected_input_ids = [100273, 2969, 93963, 93919, 101305, 100295, 100295, 100295, 100295, 100295] # fmt: skip
+ assert expected_input_ids == inputs.input_ids[0].tolist()[:10]
+
+ expected_pixel_slice = torch.tensor(
+ [
+ [1.0000, 1.0000, 1.0000],
+ [1.0000, 1.0000, 1.0000],
+ [0.9922, 0.9922, 0.9922],
+ [1.0000, 1.0000, 1.0000],
+ [1.0000, 1.0000, 1.0000],
+ ],
+ dtype=torch.float32,
+ device="cpu",
+ )
+
+ assert torch.allclose(expected_pixel_slice, inputs.pixel_values[:5, :, 0, 0], atol=3e-3)
+
+ # verify generation
+ inputs = inputs.to(torch_device)
+ output = model.generate(**inputs, max_new_tokens=30)
+ result = self.processor.decode(output[0][inputs["input_ids"].shape[-1] : -1])
+
+ EXPECTED_DECODED_TEXT = "ηηθ"
+
+ self.assertEqual(
+ result,
+ EXPECTED_DECODED_TEXT,
+ )
+
+ def test_small_model_integration_test_batch(self):
+ model = (
+ PaddleOCRVLForConditionalGeneration.from_pretrained("PaddlePaddle/PaddleOCR-VL", dtype="bfloat16")
+ .to(torch_device)
+ .eval()
+ )
+
+ inputs = self.processor.apply_chat_template(
+ [self.messages, self.messages],
+ add_generation_prompt=True,
+ tokenize=True,
+ return_dict=True,
+ return_tensors="pt",
+ padding=True,
+ padding_side="left",
+ ).to(torch_device)
+
+ # it should not matter whether two images are the same size or not
+ output = model.generate(**inputs, max_new_tokens=30)
+ generated_ids_trimmed = [out_ids[len(in_ids) :] for in_ids, out_ids in zip(inputs.input_ids, output)]
+ result = self.processor.batch_decode(
+ generated_ids_trimmed, skip_special_tokens=True, clean_up_tokenization_spaces=False
+ )
+
+ EXPECTED_DECODED_TEXT = ["ηηθ", "ηηθ"]
+
+ self.assertEqual(
+ result,
+ EXPECTED_DECODED_TEXT,
+ )
+
+ @require_flash_attn
+ @require_torch_accelerator
+ @pytest.mark.flash_attn_test
+ def test_small_model_integration_test_flashatt2(self):
+ model = (
+ PaddleOCRVLForConditionalGeneration.from_pretrained(
+ "PaddlePaddle/PaddleOCR-VL", dtype="bfloat16", attn_implementation="flash_attention_2"
+ )
+ .to(torch_device)
+ .eval()
+ )
+
+ inputs = self.processor.apply_chat_template(
+ self.messages,
+ add_generation_prompt=True,
+ tokenize=True,
+ return_dict=True,
+ return_tensors="pt",
+ )
+
+ expected_input_ids_length = 211
+ assert expected_input_ids_length == len(inputs.input_ids[0])
+
+ expected_input_ids = [100273, 2969, 93963, 93919, 101305, 100295, 100295, 100295, 100295, 100295] # fmt: skip
+ assert expected_input_ids == inputs.input_ids[0].tolist()[:10]
+
+ expected_pixel_slice = torch.tensor(
+ [
+ [1.0000, 1.0000, 1.0000],
+ [1.0000, 1.0000, 1.0000],
+ [0.9922, 0.9922, 0.9922],
+ [1.0000, 1.0000, 1.0000],
+ [1.0000, 1.0000, 1.0000],
+ ],
+ dtype=torch.float32,
+ device="cpu",
+ )
+ assert torch.allclose(expected_pixel_slice, inputs.pixel_values[:5, :, 0, 0], atol=3e-3)
+
+ # verify generation
+ inputs = inputs.to(torch_device)
+ output = model.generate(**inputs, max_new_tokens=30)
+ result = self.processor.decode(output[0][inputs["input_ids"].shape[-1] : -1])
+
+ EXPECTED_DECODED_TEXT = "ηηθ"
+
+ self.assertEqual(
+ result,
+ EXPECTED_DECODED_TEXT,
+ )
+
+ @require_flash_attn
+ @require_torch_accelerator
+ @pytest.mark.flash_attn_test
+ def test_small_model_integration_test_batch_flashatt2(self):
+ model = (
+ PaddleOCRVLForConditionalGeneration.from_pretrained(
+ "PaddlePaddle/PaddleOCR-VL", dtype="bfloat16", attn_implementation="flash_attention_2"
+ )
+ .to(torch_device)
+ .eval()
+ )
+
+ inputs = self.processor.apply_chat_template(
+ [self.messages, self.messages],
+ add_generation_prompt=True,
+ tokenize=True,
+ return_dict=True,
+ return_tensors="pt",
+ padding=True,
+ padding_side="left",
+ ).to(torch_device)
+
+ # it should not matter whether two images are the same size or not
+ output = model.generate(**inputs, max_new_tokens=30)
+ generated_ids_trimmed = [out_ids[len(in_ids) :] for in_ids, out_ids in zip(inputs.input_ids, output)]
+ result = self.processor.batch_decode(
+ generated_ids_trimmed, skip_special_tokens=True, clean_up_tokenization_spaces=False
+ )
+
+ EXPECTED_DECODED_TEXT = ["ηηθ", "ηηθ"]
+
+ self.assertEqual(
+ result,
+ EXPECTED_DECODED_TEXT,
+ )
diff --git a/utils/check_config_attributes.py b/utils/check_config_attributes.py
index ac379f618823..41606b1e1b6b 100644
--- a/utils/check_config_attributes.py
+++ b/utils/check_config_attributes.py
@@ -58,6 +58,7 @@
"expert_layer_offset",
"expert_layer_period",
],
+ "PaddleOCRTextConfig": ["tie_word_embeddings"],
"Qwen2Config": ["use_sliding_window", "max_window_layers"],
"Qwen2MoeConfig": ["use_sliding_window", "max_window_layers"],
"Qwen2VLTextConfig": ["use_sliding_window", "max_window_layers"],
diff --git a/utils/check_repo.py b/utils/check_repo.py
index 80d6a3f3223f..651e6726ec44 100644
--- a/utils/check_repo.py
+++ b/utils/check_repo.py
@@ -153,6 +153,10 @@
"SeamlessM4TCodeHifiGan", # Building part of bigger (tested) model.
"SeamlessM4TTextToUnitForConditionalGeneration", # Building part of bigger (tested) model.
"ChameleonVQVAE", # VQVAE here is used only for encoding (discretizing) and is tested as part of bigger model
+ "PaddleOCRVLModel", # Building part of bigger (tested) model. Tested implicitly through PaddleOCRVLForConditionalGeneration.
+ "PaddleOCRVisionModel", # Building part of bigger (tested) model. Tested implicitly through PaddleOCRVLForConditionalGeneration.
+ "PaddleOCRVisionTransformer", # Building part of bigger (tested) model. Tested implicitly through PaddleOCRVLForConditionalGeneration.
+ "PaddleOCRTextModel", # Building part of bigger (tested) model. Tested implicitly through PaddleOCRVLForConditionalGeneration.
"Qwen2VLModel", # Building part of bigger (tested) model. Tested implicitly through Qwen2VLForConditionalGeneration.
"Qwen2_5_VLModel", # Building part of bigger (tested) model. Tested implicitly through Qwen2_5_VLForConditionalGeneration.
"Qwen3VLModel", # Building part of bigger (tested) model. Tested implicitly through Qwen3VLForConditionalGeneration.
@@ -382,6 +386,10 @@
"Emu3TextModel", # Building part of bigger (tested) model
"JanusVQVAE", # no autoclass for VQ-VAE models
"JanusVisionModel", # Building part of bigger (tested) model
+ "PaddleOCRVLModel", # Building part of bigger (tested) model
+ "PaddleOCRVisionModel", # Building part of bigger (tested) model
+ "PaddleOCRVisionTransformer", # Building part of bigger (tested) model
+ "PaddleOCRTextModel", # Building part of bigger (tested) model
"Qwen2_5OmniTalkerForConditionalGeneration", # Building part of a bigger model
"Qwen2_5OmniTalkerModel", # Building part of a bigger model
"Qwen2_5OmniThinkerForConditionalGeneration", # Building part of a bigger model