
[Model] Add PP-OCRv5_server_rec and PP-OCRv5_mobile_rec models Support #44808

Merged
vasqu merged 19 commits into huggingface:main from zhang-prog:feat/pp_ocrv5_rec_models
Mar 18, 2026

Conversation

@zhang-prog
Contributor

No description provided.

Contributor

@vasqu vasqu left a comment

I think we have the core down now, now it's about the last details! Great work overall 🤗

logging,
requires_backends,
)
from ...utils.constants import ( # noqa: F401
Contributor

Suggested change
from ...utils.constants import ( # noqa: F401
from ...utils.constants import (

Really unsure, but I don't think we need the noqa.

Contributor Author

Done

pad_size = {"height": 48, "width": 320}
do_resize = True
do_rescale = True
do_convert_rgb = True
Contributor

Can we retroactively change that for previous models (e.g. server/mobile det) re RGB?

Probably in a different PR

Contributor Author

Yeah, I will make a new PR to solve this problem.



@auto_docstring
@requires(backends=("torch",))
Contributor

Suggested change
@requires(backends=("torch",))

Not 100% sure, but it seems that the base processing does not use anything torch-specific? If so, we don't need the requires-backends decorator within the post-processing function.

Contributor Author

Done

Comment on lines +441 to +446
logits = self.head(outputs.last_hidden_state, **kwargs)

return BaseModelOutputWithNoAttention(
last_hidden_state=logits,
hidden_states=outputs.hidden_states,
)
Contributor

Suggested change
logits = self.head(outputs.last_hidden_state, **kwargs)
return BaseModelOutputWithNoAttention(
last_hidden_state=logits,
hidden_states=outputs.hidden_states,
)
head_outputs = self.head(outputs.last_hidden_state, **kwargs)
return YourNewOutputClass(
last_hidden_state=head_outputs.last_hidden_state,
hidden_states=outputs.hidden_states,
head_hidden_states=head_outputs.hidden_states,
)

Just as a rough idea what I had in mind --> allow hidden states of both since the head model is quite sophisticated, I think it makes sense to have them as well
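To make the rough idea above concrete, here is a minimal sketch of such an output class, using a plain dataclass as a stand-in for transformers' ModelOutput; the class and field names are hypothetical, following the suggestion:

```python
from dataclasses import dataclass
from typing import Any, Optional, Tuple

# Hypothetical sketch: an output carrying hidden states from both the backbone
# and the head, mimicking the ModelOutput pattern with a plain dataclass.
@dataclass
class TextRecognitionOutputSketch:
    last_hidden_state: Any = None
    hidden_states: Optional[Tuple] = None       # backbone hidden states
    head_hidden_states: Optional[Tuple] = None  # head hidden states

out = TextRecognitionOutputSketch(
    last_hidden_state="logits",
    hidden_states=("backbone_h0", "backbone_h1"),
    head_hidden_states=("head_h0",),
)
print(out.head_hidden_states)  # → ('head_h0',)
```

This way callers can inspect both sets of hidden states without the head's being silently dropped.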

Contributor Author

Well, I see.

Comment thread src/transformers/models/hgnet_v2/modular_hgnet_v2.py
batch_size=3,
image_size=[48, 320],
num_channels=3,
is_training=False,
Contributor

Just as always, for my interest: any support for training planned? 👀

Comment on lines +151 to +161
@unittest.skip("PPOCRV5ServerRec has no attribute `hf_device_map`")
def test_cpu_offload(self):
pass

@unittest.skip("PPOCRV5ServerRec has no attribute `hf_device_map`")
def test_disk_offload_bin(self):
pass

@unittest.skip("PPOCRV5ServerRec has no attribute `hf_device_map`")
def test_disk_offload_safetensors(self):
pass
Contributor

We should change model_split_percents (an attribute within the mixin); it likely needs higher splits, e.g. model_split_percents = [0.5, 0.7, 0.8] # [0.5, 0.8]

Contributor Author

[0.5, 0.7, 0.8] doesn’t work, but [0.5, 0.8] works.

However, the test_model_parallelism test is still failing.
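The attribute override discussed above can be sketched as follows; the mixin class and its default value here are stand-ins, not the real transformers ModelTesterMixin:

```python
# Hypothetical sketch: model test classes override a class attribute inherited
# from a shared test mixin to tune where device_map splits the model.
class ModelTesterMixinSketch:
    model_split_percents = [0.5, 0.7, 0.9]  # placeholder default, not the real value

class PPOCRV5ServerRecModelTestSketch(ModelTesterMixinSketch):
    # Per the discussion above: [0.5, 0.7, 0.8] failed, [0.5, 0.8] works.
    model_split_percents = [0.5, 0.8]

print(PPOCRV5ServerRecModelTestSketch.model_split_percents)  # → [0.5, 0.8]
```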

Contributor

Not super important imo, can be skipped for now

@vasqu
Contributor

vasqu commented Mar 18, 2026

Btw main should be stable again, just need to merge/rebase with main

@zhang-prog zhang-prog requested a review from vasqu March 18, 2026 14:04
Contributor

@vasqu vasqu left a comment

Mostly good with the current state, see my last comments. And sorry but gotta be strict about adding tests :/

Careful approval

def forward(
self,
hidden_states: torch.Tensor,
attention_mask: torch.Tensor | None = None,
Contributor

Suggested change
attention_mask: torch.Tensor | None = None,
attention_mask: torch.Tensor | None = None, # Not used but kept for signature matching in downstream modules
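The pattern behind this suggestion, keeping an unused parameter purely for signature compatibility, can be sketched in isolation (function name and body are hypothetical):

```python
# Hypothetical sketch: `attention_mask` is accepted but never consumed, so this
# forward stays call-compatible with sibling modules that do use a mask.
def forward_sketch(hidden_states, attention_mask=None, **kwargs):
    # Only hidden_states is actually used.
    return [h * 2 for h in hidden_states]

print(forward_sketch([1, 2, 3], attention_mask=None))  # → [2, 4, 6]
```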

Comment on lines +278 to +280
# NOTE:
# Prevents TypeError from duplicate attention_mask arguments (passed both directly and in **kwargs).
# This parameter is a placeholder for compatibility and is not actually consumed by the function.
Contributor

Suggested change
# NOTE:
# Prevents TypeError from duplicate attention_mask arguments (passed both directly and in **kwargs).
# This parameter is a placeholder for compatibility and is not actually consumed by the function.

Just a nit: I don't think we need to be too verbose.

main_input_name = "pixel_values"
input_modalities = ("image",)
_can_record_outputs = {
"hidden_states": [PPOCRV5ServerRecConvLayer, PPOCRV5ServerRecBlock],
Contributor

Suggested change
"hidden_states": [PPOCRV5ServerRecConvLayer, PPOCRV5ServerRecBlock],
"hidden_states": PPOCRV5ServerRecBlock,

Imo, I think we only want these because they are the ones described via config.depth.

head_outputs = self.head(outputs.last_hidden_state, **kwargs)

return PPOCRV5ServerRecForTextRecognitionOutput(
last_hidden_state=head_outputs,
Contributor

Suggested change
last_hidden_state=head_outputs,
last_hidden_state=head_outputs.last_hidden_state,

config,
):
super().__init__(config)
# Use noqa to bypass the `unused in modular` check.
Contributor

Suggested change
# Use noqa to bypass the `unused in modular` check.

No need for the comment, don't worry.

Contributor

Ok, I hate to be stern, but I would definitely add modeling tests with integration tests - for the simple reason that the backbone is different and we have a slightly different model, albeit by very little.

Contributor

Reopening to add tests please

with torch.no_grad():
_ = model(**self._prepare_for_class(inputs_dict, model_class))

def test_hidden_states_output(self):
Contributor

Can we add the head hidden states to check as well?

@vasqu
Contributor

vasqu commented Mar 18, 2026

Hmm, maybe I was wrong on the gradient checkpointing:
FAILED tests/models/pp_ocrv5_server_rec/test_modeling_pp_ocrv5_server_rec.py::PPOCRV5ServerRecModelTest::test_gradient_checkpointing_backward_compatibility - ValueError: PPOCRV5ServerRecModel is not compatible with gradient checkpointing. Make sure all the architecture support it by setting a boolean attribute gradient_checkpointing to modules of the model that uses checkpointing.

@zhang-prog
Contributor Author

Yes, setting self.gradient_checkpointing = False is necessary to fix it, at least for now. :)
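The fix the author describes can be sketched as follows; per the error message above, the compatibility check expects a boolean `gradient_checkpointing` attribute on the relevant modules, and the class name here is a stand-in:

```python
# Hypothetical sketch: declaring the module as not using gradient checkpointing
# by setting the boolean attribute the compatibility check looks for.
class PPOCRV5ServerRecModelSketch:
    def __init__(self):
        # Opt out of gradient checkpointing explicitly, at least for now.
        self.gradient_checkpointing = False

model = PPOCRV5ServerRecModelSketch()
print(model.gradient_checkpointing)  # → False
```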

Contributor

@vasqu vasqu left a comment

Don't have much to add except for a few small nits + let's add tests for the mobile version as well please

Other than that, good to go!

Contributor

Reopening to add tests please

@vasqu
Contributor

vasqu commented Mar 18, 2026

run-slow: hgnet_v2, pp_ocrv5_mobile_rec, pp_ocrv5_server_det, pp_ocrv5_server_rec

@github-actions
Contributor

Workflow Run ⚙️

This comment contains run-slow, running the specified jobs:

models: ["models/hgnet_v2", "models/pp_ocrv5_mobile_rec", "models/pp_ocrv5_server_det", "models/pp_ocrv5_server_rec"]
quantizations: []

@github-actions
Contributor

CI Results

Workflow Run ⚙️

Commit Info

Context  Commit    Description
RUN      4cae9ac9  workflow commit (merge commit)
PR       6e5aaeff  branch commit (from PR)
main     4ec84a02  base commit (on main)

✅ No failing test specific to this PR 🎉 👏 !

@vasqu vasqu enabled auto-merge March 18, 2026 16:43
@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

@vasqu
Contributor

vasqu commented Mar 18, 2026

@zhang-prog Will merge tomorrow probably, CI is struggling at the moment - nothing to do on your side 🤗

@vasqu vasqu disabled auto-merge March 18, 2026 17:30
@vasqu vasqu enabled auto-merge March 18, 2026 17:45
@vasqu vasqu added this pull request to the merge queue Mar 18, 2026
@vasqu vasqu removed this pull request from the merge queue due to a manual request Mar 18, 2026
@vasqu
Contributor

vasqu commented Mar 18, 2026

run-slow: hgnet_v2, pp_ocrv5_mobile_rec, pp_ocrv5_server_det, pp_ocrv5_server_rec

@github-actions
Contributor

Workflow Run ⚙️

This comment contains run-slow, running the specified jobs:

models: ["models/hgnet_v2", "models/pp_ocrv5_mobile_rec", "models/pp_ocrv5_server_det", "models/pp_ocrv5_server_rec"]
quantizations: []

@github-actions
Contributor

[For maintainers] Suggested jobs to run (before merge)

run-slow: auto, hgnet_v2, pp_ocrv5_mobile_rec, pp_ocrv5_server_det, pp_ocrv5_server_rec

@vasqu vasqu enabled auto-merge March 18, 2026 20:02
@github-actions
Contributor

CI Results

Workflow Run ⚙️

Commit Info

Context  Commit    Description
RUN      ee1f6921  workflow commit (merge commit)
PR       d0e841d6  branch commit (from PR)
main     21950930  base commit (on main)

✅ No failing test specific to this PR 🎉 👏 !

@vasqu vasqu added this pull request to the merge queue Mar 18, 2026
Merged via the queue into huggingface:main with commit c55f650 Mar 18, 2026
29 checks passed