[Model] Add PP-DocLayoutV2 Model Support #43018

Merged
vasqu merged 31 commits into huggingface:main from zhang-prog:feat/pp_doclayout_v2 on Feb 27, 2026

Conversation

@zhang-prog
Contributor

@zhang-prog zhang-prog commented Dec 23, 2025

What does this PR do?

This PR adds the PP-DocLayoutV2 model from PaddleOCR to Hugging Face Transformers.

Relevant Links:

PaddleOCR
https://huggingface.co/PaddlePaddle/PP-DocLayoutV2_safetensors

Usage

Use a pipeline

import requests
from PIL import Image
from transformers import pipeline

image = Image.open(requests.get("https://paddle-model-ecology.bj.bcebos.com/paddlex/imgs/demo_image/layout_demo.jpg", stream=True).raw)
layout_detector = pipeline("object-detection", model="PaddlePaddle/PP-DocLayoutV2_safetensors")
result = layout_detector(image)
print(result)

Load model directly

import requests
import torch
from PIL import Image
from transformers import AutoImageProcessor, AutoModelForObjectDetection

model_path = "PaddlePaddle/PP-DocLayoutV2_safetensors"
model = AutoModelForObjectDetection.from_pretrained(model_path)
image_processor = AutoImageProcessor.from_pretrained(model_path)

image = Image.open(requests.get("https://paddle-model-ecology.bj.bcebos.com/paddlex/imgs/demo_image/layout_demo.jpg", stream=True).raw)
inputs = image_processor(images=image, return_tensors="pt")

outputs = model(**inputs)
results = image_processor.post_process_object_detection(outputs, target_sizes=torch.tensor([image.size[::-1]]))
for result in results:
    for idx, (score, label_id, box) in enumerate(zip(result["scores"], result["labels"], result["boxes"])):
        score, label = score.item(), label_id.item()
        box = [round(i, 2) for i in box.tolist()]
        print(f"Order {idx + 1}: {model.config.id2label[label]}: {score:.2f} {box}")

@zhang-prog zhang-prog changed the title from "init" to "[Model] Add PP-DocLayoutV2 Model Support" on Dec 23, 2025
@ArthurZucker
Collaborator

cc @molbap if you have time!

@molbap
Contributor

molbap commented Jan 6, 2026

Reviewing!

Contributor

@molbap molbap left a comment

Thanks for the addition! Left a few comments to start cleaning up/aligning with library standards. Let me know if you have any question and I'll re-review once addressed 🤗

logits = outputs.logits
order_logits = outputs.order_logits

order_seqs = get_order(order_logits)

Contributor

more explicit naming for get_order please

Contributor Author

Renamed.

Comment on lines +72 to +155
def _default_id2label() -> dict[int, str]:
    return {
        0: "abstract",
        1: "algorithm",
        2: "aside_text",
        3: "chart",
        4: "content",
        5: "formula",
        6: "doc_title",
        7: "figure_title",
        8: "footer",
        9: "footer",
        10: "footnote",
        11: "formula_number",
        12: "header",
        13: "header",
        14: "image",
        15: "formula",
        16: "number",
        17: "paragraph_title",
        18: "reference",
        19: "reference_content",
        20: "seal",
        21: "table",
        22: "text",
        23: "text",
        24: "vision_footnote",
    }


def _default_threshold_mapping() -> dict[str, float]:
    return {
        "abstract": 0.50,
        "algorithm": 0.50,
        "aside_text": 0.50,
        "chart": 0.50,
        "content": 0.50,
        "formula": 0.40,
        "doc_title": 0.40,
        "figure_title": 0.50,
        "footer": 0.50,
        "footnote": 0.50,
        "formula_number": 0.50,
        "header": 0.50,
        "image": 0.50,
        "number": 0.50,
        "paragraph_title": 0.40,
        "reference": 0.50,
        "reference_content": 0.50,
        "seal": 0.45,
        "table": 0.50,
        "text": 0.40,
        "vision_footnote": 0.50,
    }


def _default_order_map() -> dict[str, int]:
    return {
        "abstract": 4,
        "algorithm": 2,
        "aside_text": 14,
        "chart": 1,
        "content": 5,
        "display_formula": 7,
        "doc_title": 8,
        "figure_title": 6,
        "footer": 11,
        "footer_image": 11,
        "footnote": 9,
        "formula_number": 13,
        "header": 10,
        "header_image": 10,
        "image": 1,
        "inline_formula": 2,
        "number": 3,
        "paragraph_title": 0,
        "reference": 2,
        "reference_content": 2,
        "seal": 12,
        "table": 1,
        "text": 2,
        "vertical_text": 15,
        "vision_footnote": 6,
    }

Contributor

I think we can remove these helpers and simply use the configuration directly, augmenting it with threshold values for instance

Contributor Author

Removed.

`list[Dict]`: A list of dictionaries, each dictionary containing the scores, labels and boxes for an image
in the batch as predicted by the model.
"""
return postprocess(outputs=outputs, threshold=threshold, target_sizes=target_sizes)

Contributor

maybe would be clearer to have this postprocess method as a class method, no? or at least closer for readability

Contributor Author

right, done.

self.dense = nn.Linear(config.hidden_size, self.heads * 2 * self.head_size)

def forward(self, inputs, attn_mask_1d):
    B, N, _ = inputs.shape

Contributor

in general, let's avoid single-letter variables please!

Contributor Author

Renamed.

Comment on lines +878 to +882
if self.tril_mask:
    lower = torch.tril(torch.ones([N, N], dtype=torch.float32, device=logits.device))
    lower = lower.bool().unsqueeze(0).unsqueeze(0)
    logits = logits - lower.to(logits.dtype) * 1e4
    pair_mask = torch.logical_or(pair_mask.bool(), lower)

Contributor

if I understand correctly, tril_mask is always true, so the attribute can be removed and this branch too
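
If the attribute goes away, the masking could simply be applied unconditionally, roughly like this (a sketch, not the final PR code; `N` is the sequence length from the original snippet):

lower = torch.tril(torch.ones((N, N), dtype=torch.bool, device=logits.device))
lower = lower.unsqueeze(0).unsqueeze(0)  # broadcast over batch and heads
logits = logits - lower.to(logits.dtype) * 1e4  # large negative bias on masked positions
pair_mask = torch.logical_or(pair_mask.bool(), lower)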

def box_rel_encoding(src_boxes: torch.Tensor, tgt_boxes: torch.Tensor = None, eps: float = 1e-5):
    if tgt_boxes is None:
        tgt_boxes = src_boxes
    assert src_boxes.shape[-1] == 4 and tgt_boxes.shape[-1] == 4

Contributor

no asserts, in general
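
For illustration, the usual replacement is an explicit exception instead of an assert (a sketch; the error message text is just an example):

if src_boxes.shape[-1] != 4 or tgt_boxes.shape[-1] != 4:
    raise ValueError("Boxes must have 4 coordinates in their last dimension.")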

Contributor Author

Removed.



def get_sine_pos_embed(
    x: torch.Tensor, num_pos_feats: int, temperature: float = 10000.0, scale: float = 100.0, exchange_xy: bool = False

Contributor

exchange_xy is always False here. Also, for x, same comment about single-letter variables.

Contributor Author

Done.

Comment on lines +1270 to +1272
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,

Contributor

These three arguments are deprecated now. The first two need the decorator @check_model_inputs, and return_dict now uses @can_return_tuple. Then you'll need to specify in the model class attributes the modules that can record outputs, with something close to this, I suppose:

      _can_record_outputs = {
          "hidden_states": OutputRecorder(PPDocLayoutV2DecoderLayer, index=0),
          "attentions": [
              OutputRecorder(PPDocLayoutV2MultiheadAttention, index=1),
              OutputRecorder(PPDocLayoutV2MultiscaleDeformableAttention, index=1),
          ],
      }

Contributor Author

@zhang-prog zhang-prog Jan 7, 2026

Okay, I’ve made the change, but it’s causing a test failure.

I duplicated PPDocLayoutV2HybridEncoder from RTDetrHybridEncoder, but RTDetrHybridEncoder uses a deprecated method that leaves encoder_hidden_states and encoder_attentions as None. This breaks the test_attention_outputs and check_hidden_states_output tests.

Any suggestions for a fix?

[screenshots of the test failure and test logs omitted]

Contributor

ah indeed the base model uses a deprecated method as well. let me check and get back to you soon

Contributor Author

Any ideas on how we can fix this?

Contributor

Apart from modifying RT-DETR and updating it to standards, no unfortunately. doing another review today

Contributor Author

Got it, thanks.
So, is there a plan to implement this fix in the near future? I’m asking because this issue affects multiple models.
Alternatively, I’m happy to make the RT-DETR changes myself for now if that would be helpful.
Please let me know the best way to proceed.

Comment thread src/transformers/models/pp_doclayout_v2/modular_pp_doclayout_v2.py

# custom
if rel_2d_pos is not None:
    attention_scores += rel_2d_pos

Contributor

I see, this will not work with fa2/etc though. LayoutLMv3 doesn't use the recent attention_interface, so this will need to be revamped afterwards, I'd prefer to update it now and implement a proper eager_attention_forward

Contributor Author

Got it. So what’s my next step here? Is there a good reference model I can look at for this update?

Contributor

Bumping this, but yea something along bert

def eager_attention_forward(

FA and flex won't work, we could make SDPA work by integrating the rel bias into the mask directly. Check out t5 #42453 (at least at the point of my last commit :D)
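
For reference, a minimal sketch of a BERT-style eager attention forward that folds the relative 2D bias into the additive attention mask (so SDPA could later reuse the same mask). This is illustrative only, not the final PR implementation:

import torch
from torch import nn

def eager_attention_forward(module, query, key, value, attention_mask, scaling, dropout=0.0, **kwargs):
    # attention_mask is assumed to already contain the additive rel_2d_pos bias
    attn_weights = torch.matmul(query, key.transpose(-1, -2)) * scaling
    if attention_mask is not None:
        attn_weights = attn_weights + attention_mask
    attn_weights = nn.functional.softmax(attn_weights, dim=-1)
    attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training)
    attn_output = torch.matmul(attn_weights, value)
    return attn_output, attn_weights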

@zhang-prog
Contributor Author

@molbap
PTAL.
There are still two issues that need to be discussed. Please review my response.
Thanks for your efforts! 🤗

@github-actions
Contributor

github-actions Bot commented Jan 7, 2026

View the CircleCI Test Summary for this PR:

https://huggingface.co/spaces/transformers-community/circle-ci-viz?pr=43018&sha=eae5dd

@zhang-prog zhang-prog requested a review from molbap January 7, 2026 08:42
@zhang-prog
Contributor Author

@molbap
I have made some modifications according to the V3 review. Maybe it can be merged soon.🤗
PTAL.

Contributor

@molbap molbap left a comment

Thanks! I definitely think we should pull from #43098 or the other way around, would simplify this greatly!

Comment on lines +618 to +620
self.image_processor = (
    PPDocLayoutV2ImageProcessor.from_pretrained(model_path) if is_vision_available() else None
)

Contributor

vision is required for this test to run so can be simplified

return result


class LayoutLMv3TextEmbeddingsCustom(LayoutLMv3TextEmbeddings):

Contributor

For this, this should not be prefixed by LayoutLMv3, but should be for this specific model. All modules should share the common prefix of the current model, here PPDocLayoutV2. Same for ReadingOrder, it should rather be something like PPDocLayoutV2ReadingOrder


# Normalize the attention scores to probabilities.
# Use the trick of the CogView paper to stabilize training
attention_probs = self.cogview_attention(attention_scores)

Contributor

so you confirm you are using the cogview attention as well? else, we can drop it in the eager path perhaps and use the new attention interface?

Contributor

Bumping, same question - otherwise sdpa will become impossible for now I think
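
For context, the CogView trick in LayoutLMv3 is roughly the following PB-Relax-style rescaling before the softmax (a sketch based on the LayoutLMv3 implementation; `nn` is `torch.nn`):

def cogview_attention(self, attention_scores, alpha=32):
    # Rescale, subtract the per-row max, then scale back before the softmax to avoid overflow.
    scaled_attention_scores = attention_scores / alpha
    max_value = scaled_attention_scores.amax(dim=-1).unsqueeze(-1)
    new_attention_scores = (scaled_attention_scores - max_value) * alpha
    return nn.Softmax(dim=-1)(new_attention_scores)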

new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
context_layer = context_layer.view(*new_context_layer_shape)

outputs = (context_layer, attention_probs) if output_attentions else (context_layer,)

Contributor

this has to be removed (we don't support output_attentions, output_hidden_states, return_dict, all are handled through decorators now)

return out


class LayoutLMv3SelfAttentionCustom(LayoutLMv3SelfAttention):

Contributor

needs to be renamed with proper prefixes PPDocLayoutV2SelfAttention (valid across the file, all model classes need to be prefixed properly)


qw_t = qw.transpose(1, 2)
kw_t = kw.transpose(1, 2)
logits = torch.einsum("bhmd,bhnd->bhmn", qw_t, kw_t) / (self.head_size**0.5)

Contributor

no einsum/einops unless unavoidable. It seems similar to v3 again, can we modularize more?
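
For reference, that einsum is equivalent to a plain batched matmul (sketch):

# "bhmd,bhnd->bhmn" is a batched matmul of qw_t with the transpose of kw_t
logits = torch.matmul(qw_t, kw_t.transpose(-1, -2)) / (self.head_size ** 0.5)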



def box_rel_encoding(src_boxes: torch.Tensor, tgt_boxes: torch.Tensor = None, eps: float = 1e-5):
    if tgt_boxes is None:

Contributor

target instead of tgt, etc

return out


class PositionRelationEmbedding(nn.Module):

Contributor

should be prefixed as well

Comment on lines +669 to +681
class LayoutLMv3SelfOutputCustom(LayoutLMv3SelfOutput):
    pass


class LayoutLMv3IntermediateCustom(LayoutLMv3Intermediate):
    pass


class LayoutLMv3OutputCustom(LayoutLMv3Output):
    pass


class LayoutLMv3AttentionCustom(LayoutLMv3Attention):

Contributor

should all be prefixed as well, and instead of custom, we can write e.g. class PPDocLayoutv2Attention(LayoutLMv3Attention)

encoder_output = encoder_output.last_hidden_state
tok = encoder_output[:, 1 : 1 + seq_len, :]
attn_1d = torch.arange(seq_len, device=device)[None, :] < num_pred[:, None]
logits_bh, _ = self.relative_head(tok, attn_1d)

Contributor

abbreviations to expand

@molbap
Contributor

molbap commented Feb 4, 2026

Hello @zhang-prog ! Let me know if you want some help for this PR 🤗 happy to re-review if needed

@zhang-prog
Contributor Author

@molbap Hi, Pablo! I’m currently refactoring the code based on the latest RT-DETR and PP-DocLayoutV3. I will be submitting a new commit this week. 🤗

@zhang-prog
Contributor Author

@molbap I submitted my updates, but I found that LayoutLMv3SelfAttention and LayoutLMv3Encoder still depend on passing output_attentions, output_hidden_states, and return_dict. Because we are reusing this code modularly and can’t make changes on our end (reminiscent of the RT-DETR case, -.-), do you have any suggestions on how to handle this?

@molbap
Contributor

molbap commented Feb 6, 2026

OK! I am indeed working on removing all of these old patterns here https://github.com/huggingface/transformers/pull/43590/changes#diff-418eaafaa5103cea9eb92c3b93c0b1d79aa420ea9c354764bd3e6d900657a9b5 but I'll make a smaller PR with just layoutlmv3 changes. apart from that, no other issues? re-reviewing then 🤗

Comment on lines +268 to +274
# backbone
backbone_config=None,
backbone=None,
use_pretrained_backbone=False,
use_timm_backbone=False,
freeze_backbone_batch_norms=True,
backbone_kwargs=None,

Member

let's get rid of args here except for the backbone_config. In the model config, we need to add the correct config type with model_type

No need to call consolidate_backbone_kwargs imo, we don't want users to keep passing extra backbone-related args in the future

Contributor Author

Done, args removed. It seems we still need to call consolidate_backbone_kwargs_to_config to instantiate the backbone.

Member

oh, you mean we need to make sure the backbone_config is indeed a config obj and not dict?

Contributor Author

yeah, it will raise an error if backbone_config is not a config obj.
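
For illustration, the usual pattern for turning a dict backbone_config into a config object looks roughly like this (a sketch; the default backbone model_type shown is an assumption, not necessarily what the PR uses):

from transformers.models.auto import CONFIG_MAPPING

if backbone_config is None:
    backbone_config = CONFIG_MAPPING["hgnet_v2"]()  # assumed default; use the model's actual backbone model_type
elif isinstance(backbone_config, dict):
    # rebuild the proper config class from the stored model_type
    backbone_config = CONFIG_MAPPING[backbone_config["model_type"]].from_dict(backbone_config)
self.backbone_config = backbone_config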

Comment on lines +833 to +834
@can_return_tuple
@check_model_inputs

Member

I don't think we should be stacking these two decorators together

Contributor Author

check_model_inputs removed.

Comment on lines +908 to +909
labels=labels,
)

Member

kwargs not passed here or used later, so we don't need any of output_xxx stuff here? In that case, prob we keep only can_return_tuple

Contributor Author

pass kwargs

Contributor

I think the main issue here is, do you envision a point where the model would need to track hidden states/attentions? IMO I think not, so passing kwargs might not be needed. However if you choose to pass them, would be better to type them as TransformersKwargs in the signature

Contributor

I'd rather keep kwargs and type it, we don't know how and whether it will be refactored. It certainly helps along modular inheritances either way
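
For illustration, the suggested typing would look roughly like this (the import locations are assumptions and may differ across transformers versions):

from typing_extensions import Unpack
from transformers.utils.generic import TransformersKwargs

def forward(self, pixel_values, labels=None, **kwargs: Unpack[TransformersKwargs]):
    ...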

@zhang-prog
Contributor Author

@molbap Pablo, thank you for your efforts! And I just wanted to ask: when can the LayoutLMv3 changes be merged? I'd like to get this PR merged as soon as possible, ideally before Feb 13. 🤗

@zhang-prog
Contributor Author

@vasqu fixed. please run slow tests again.🤗

@vasqu
Contributor

vasqu commented Feb 26, 2026

run-slow: pp_doclayout_v2

@github-actions
Contributor

Workflow Run ⚙️

This comment contains run-slow, running the specified jobs:

models: ["models/pp_doclayout_v2"]
quantizations: []

@github-actions
Contributor

CI Results

Workflow Run ⚙️

Commit Info

Context Commit Description
RUN 81c5f216 workflow commit (merge commit)
PR cdace81e branch commit (from PR)
main d4cb8416 base commit (on main)

✅ No failing test specific to this PR 🎉 👏 !

@vasqu vasqu enabled auto-merge (squash) February 26, 2026 13:19
@vasqu
Contributor

vasqu commented Feb 26, 2026

Building docs, then merging. Thanks a lot for iterating and sticking through 🤗 big work

@vasqu
Contributor

vasqu commented Feb 26, 2026

Oh, docs are not building. Will try to check later, if you have some time now @zhang-prog

@vasqu vasqu disabled auto-merge February 26, 2026 17:46
@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

@vasqu
Contributor

vasqu commented Feb 26, 2026

Ok I finally found the issue and it is complicated so I'd rather not expand on it 😅

I have a commit here ece7fca which fixes the issue, but it still needs the blanks filled in; would be nice if you could do that, then I'd merge

@vasqu
Contributor

vasqu commented Feb 26, 2026

run-slow: pp_doclayout_v2

@github-actions
Contributor

Workflow Run ⚙️

This comment contains run-slow, running the specified jobs:

models: ["models/pp_doclayout_v2"]
quantizations: []

@github-actions
Contributor

CI Results

Workflow Run ⚙️

Commit Info

Context Commit Description
RUN 864bb487 workflow commit (merge commit)
PR 48363ffe branch commit (from PR)
main b812aa91 base commit (on main)

✅ No failing test specific to this PR 🎉 👏 !

@github-actions
Contributor

[For maintainers] Suggested jobs to run (before merge)

run-slow: auto, pp_doclayout_v2


@zhang-prog
Contributor Author

@vasqu Thanks! I have filled out all the blanks, maybe the docs can be built successfully now?

@vasqu vasqu enabled auto-merge (squash) February 27, 2026 09:07
@vasqu
Contributor

vasqu commented Feb 27, 2026

Thanks a lot! Yea, should work now. The commit I added resolved it already, but the contents were TODOs :D With your fill-ins, we are good to go!

@vasqu vasqu merged commit 8e8b861 into huggingface:main Feb 27, 2026
25 checks passed
@zhang-prog
Contributor Author

@vasqu Cool! Thanks a lot for your help.🤗

zvik pushed a commit to zvik/transformers that referenced this pull request Mar 1, 2026
* init

* add model_doc

* fix

* update

* update

* update

* update

* update

* update

* update

* update

* update

* update

* update

* update

* update

* update

* let's try this

* try

* fixup docs, expecting autodocstring to fail

* is it this

* fix

* update docstring

* update date

---------

Co-authored-by: vasqu <antonprogamer@gmail.com>
Co-authored-by: Anton Vlasjuk <73884904+vasqu@users.noreply.github.com>