
[Model] Add PP-FormulaNet Model Support #45626

Open
zhang-prog wants to merge 12 commits into huggingface:main from zhang-prog:feat/pp_formulanet

Conversation

@zhang-prog (Contributor)

No description provided.

@vasqu (Contributor) left a comment

Heya, first round 🤗 You weren't lying when you said it was more complicated :D I've made fewer comments to focus on the core first: restructure to a VLM and use existing patterns with our normal generate pipeline.

Comment thread docs/source/en/model_doc/pp_formulanet.md Outdated
Comment thread docs/source/en/model_doc/pp_formulanet.md Outdated
from PIL import Image
from transformers import AutoProcessor, AutoModelForTextRecognition

model_path = "PaddlePaddle/PP-FormulaNet_plus-L_safetensors"
Contributor

Suggested change
model_path = "PaddlePaddle/PP-FormulaNet_plus-L_safetensors"
model_path = "PaddlePaddle/PP-FormulaNet_plus-L_safetensors" # or "PaddlePaddle/PP-FormulaNet-L_safetensors"

Not sure, but 2 checkpoints have been mentioned in the docs.

Contributor

Rebump

Comment thread src/transformers/models/pp_formulanet/configuration_pp_formulanet.py Outdated
Comment thread src/transformers/models/pp_formulanet/modular_pp_formulanet.py Outdated
Comment thread src/transformers/models/pp_formulanet/modular_pp_formulanet.py Outdated
Comment on lines +199 to +202
s = news
news = re.sub(r"(?!\\ )(%s)\s+?(%s)" % (noletter, noletter), r"\1\2", s)
news = re.sub(r"(?!\\ )(%s)\s+?(%s)" % (noletter, letter), r"\1\2", news)
news = re.sub(r"(%s)\s+?(%s)" % (letter, noletter), r"\1\2", news)
Contributor

We should compile the regex outside the loops, probably similar above
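As an illustration of the suggestion (with hypothetical `letter`/`noletter` character classes standing in for the PR's actual patterns), the compiled form looks roughly like:

```python
import re

# Hypothetical character classes; the PR's actual `letter`/`noletter` patterns differ.
letter = r"[a-zA-Z]"
noletter = r"[0-9+\-*/=]"

# Compiled once at module scope instead of re-parsing the pattern on every re.sub call.
RULE_NOLETTER_NOLETTER = re.compile(r"(?!\\ )(%s)\s+?(%s)" % (noletter, noletter))
RULE_NOLETTER_LETTER = re.compile(r"(?!\\ )(%s)\s+?(%s)" % (noletter, letter))
RULE_LETTER_NOLETTER = re.compile(r"(%s)\s+?(%s)" % (letter, noletter))


def collapse_spaces(text: str) -> str:
    """Remove whitespace between the tracked token classes, as in the PR's loop."""
    previous = None
    while previous != text:  # iterate to a fixed point, reusing the compiled rules
        previous = text
        text = RULE_NOLETTER_NOLETTER.sub(r"\1\2", text)
        text = RULE_NOLETTER_LETTER.sub(r"\1\2", text)
        text = RULE_LETTER_NOLETTER.sub(r"\1\2", text)
    return text
```

`re.compile` parses the pattern once; the per-iteration `re.sub(pattern, ...)` form has to look the pattern up in the regex cache on every call.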

Contributor

Rebump

Comment thread src/transformers/models/pp_formulanet/modular_pp_formulanet.py Outdated
"""
text = self.remove_chinese_text_wrapping(text)
try:
from ftfy import fix_text
Contributor

Not a fan of an extra dependency tbh, but I guess it is too complicated/long to adopt here.

import torch


class PPFormulaNetModelTester:
Contributor

We can then use our VLM tester instead

class VLMModelTester:

Contributor

At this point, because we have a special encoder-decoder which does not fit the standard VLM style, I'm not sure if it really fits - maybe a classic encoder-decoder approach would be better.

@zhang-prog (Contributor Author)

@vasqu I’ve restructured PPFormulaNet into a VLM. Some unit tests are still failing and I’m fixing them, but that shouldn’t block you from reviewing the latest model structure code. PTAL.

@vasqu (Contributor) left a comment

Much better. I focused further on the model structure - I think the core is good now; what remains is details and making it fit within our style.

from PIL import Image
from transformers import AutoProcessor, AutoModelForTextRecognition

model_path = "PaddlePaddle/PP-FormulaNet_plus-L_safetensors"
Contributor

Rebump

Comment thread src/transformers/models/pp_formulanet/modular_pp_formulanet.py

@auto_docstring(checkpoint="PaddlePaddle/PPFormulaNet_plus-L_safetensors")
@strict
class PPFormulaNetTextConfig(PreTrainedConfig):
Contributor

Suggested change
class PPFormulaNetTextConfig(PreTrainedConfig):
class PPFormulaNetTextConfig(MBartConfig):

We should inherit from MBart directly; that way we don't have to think too much about what is actually needed.

Comment on lines +78 to +79
max_length (`int`, *optional*, defaults to 1537):
Controls the maximum length to use by one of the truncation/padding parameters.
Contributor

You might be searching for max_position_embeddings instead; at the very least it should not be part of the model but of the tokenizer. Probably left over from the old model pattern where you manually called generate.


@auto_docstring(
checkpoint="PaddlePaddle/PPFormulaNet_plus-L_safetensors"
) # or "PaddlePaddle/PP-FormulaNet-L_safetensors"
Contributor

Suggested change
) # or "PaddlePaddle/PP-FormulaNet-L_safetensors"
)

tbh, I would mention it in the model docs (model_doc/pp_formulanet.md) but not here, because the default values are valid for that checkpoint - we only show one example here.

Comment on lines +402 to +411
decoder_outputs = self.language_model.decoder(
input_ids=decoder_input_ids,
attention_mask=decoder_attention_mask,
encoder_hidden_states=image_features,
encoder_attention_mask=attention_mask,
past_key_values=past_key_values,
inputs_embeds=decoder_inputs_embeds,
use_cache=use_cache,
**kwargs,
)
Contributor

Suggested change
decoder_outputs = self.language_model.decoder(
input_ids=decoder_input_ids,
attention_mask=decoder_attention_mask,
encoder_hidden_states=image_features,
encoder_attention_mask=attention_mask,
past_key_values=past_key_values,
inputs_embeds=decoder_inputs_embeds,
use_cache=use_cache,
**kwargs,
)
decoder_outputs = self.language_model(
input_ids=decoder_input_ids,
attention_mask=decoder_attention_mask,
encoder_hidden_states=image_features,
encoder_attention_mask=attention_mask,
past_key_values=past_key_values,
inputs_embeds=decoder_inputs_embeds,
use_cache=use_cache,
**kwargs,
)

As mentioned before, I would like to move away from the ForCausalLM model and use the decoder directly.

Comment on lines +436 to +437
def _prepare_encoder_decoder_kwargs_for_generation(self, *args, **kwargs):
return GenerationMixin._prepare_encoder_decoder_kwargs_for_generation(*args, **kwargs)
Contributor

Suggested change
def _prepare_encoder_decoder_kwargs_for_generation(self, *args, **kwargs):
return GenerationMixin._prepare_encoder_decoder_kwargs_for_generation(*args, **kwargs)
def _prepare_encoder_decoder_kwargs_for_generation(self, *args, **kwargs):
raise AttributeError()

I think you just don't want to inherit? Raising an AttributeError tells modular not to.
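A toy sketch of the pattern (Parent/Child are illustrative, not the actual transformers classes): overriding the method with a bare `raise AttributeError` marks it as intentionally absent, which the modular converter reads as "do not inherit this from the parent".

```python
class Parent:
    # Stand-in for GenerationMixin's helper (illustrative only).
    def _prepare_encoder_decoder_kwargs_for_generation(self, *args, **kwargs):
        return "parent implementation"


class Child(Parent):
    # Overriding with a bare AttributeError marks the method as intentionally
    # absent; transformers' modular converter treats this as "do not inherit".
    def _prepare_encoder_decoder_kwargs_for_generation(self, *args, **kwargs):
        raise AttributeError("Not supported for this model")
```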

Comment on lines +439 to +440
def get_encoder(self):
return self.model.vision_tower
Contributor

Suggested change
def get_encoder(self):
return self.model.vision_tower

encoder_last_hidden_state=encoder_outputs.last_hidden_state,
encoder_hidden_states=encoder_outputs.hidden_states,
encoder_attentions=encoder_outputs.attentions,
image_hidden_states=image_features if pixel_values is not None else None,
Contributor

Wouldn't this fit image_last_hidden_state better? You want the last (pooled) feature, not the set of hidden states across all of this.

Imo, we can even leave this out completely, as the encoder is everything image-related. The output class should be new and explain that the encoder == vision encoder, hence the different expected shapes and all.

Contributor Author

removed

Comment on lines +376 to +386
input_ids: torch.LongTensor | None = None,
pixel_values: torch.FloatTensor | None = None,
attention_mask: torch.Tensor | None = None,
decoder_input_ids: torch.LongTensor | None = None,
decoder_attention_mask: torch.LongTensor | None = None,
decoder_inputs_embeds: torch.FloatTensor | None = None,
encoder_outputs: list[torch.FloatTensor] | None = None,
past_key_values: Cache | None = None,
inputs_embeds: torch.FloatTensor | None = None,
use_cache: bool | None = None,
**kwargs,
Contributor

Suggested change
input_ids: torch.LongTensor | None = None,
pixel_values: torch.FloatTensor | None = None,
attention_mask: torch.Tensor | None = None,
decoder_input_ids: torch.LongTensor | None = None,
decoder_attention_mask: torch.LongTensor | None = None,
decoder_inputs_embeds: torch.FloatTensor | None = None,
encoder_outputs: list[torch.FloatTensor] | None = None,
past_key_values: Cache | None = None,
inputs_embeds: torch.FloatTensor | None = None,
use_cache: bool | None = None,
**kwargs,
pixel_values: torch.FloatTensor | None = None,
attention_mask: torch.Tensor | None = None, # TODO check if this is really used, likely to be removed as well
decoder_input_ids: torch.LongTensor | None = None,
decoder_attention_mask: torch.LongTensor | None = None,
decoder_inputs_embeds: torch.FloatTensor | None = None,
encoder_outputs: list[torch.FloatTensor] | None = None,
past_key_values: Cache | None = None,
use_cache: bool | None = None,
**kwargs,

Noticing that we don't need those - we have pure images, no associated text in the encoder, so we can remove them.

Contributor

The main input name should be pixel_values - not sure if that is already the case within the pretrained model :D

Contributor Author

The parameters still need to remain in the argument list; otherwise, when calling self.language_model(..., **kwargs), it will raise errors like:

got multiple values for keyword argument 'attention_mask'
got multiple values for keyword argument 'input_ids'
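The error zhang-prog quotes is plain Python behavior rather than anything transformers-specific; a toy `forward` (hypothetical) reproduces it whenever the same keyword arrives both explicitly and through **kwargs, which is why keeping the parameters in the signature (so they are consumed before forwarding) avoids it:

```python
def forward(input_ids=None, attention_mask=None, **kwargs):
    # Toy stand-in for a model's forward; keeping the parameters in the
    # signature means they are consumed here instead of landing in **kwargs.
    return input_ids, attention_mask


kwargs = {"attention_mask": "from_kwargs", "use_cache": True}
try:
    # The same keyword arrives explicitly *and* via **kwargs:
    forward(input_ids=[1, 2], attention_mask="explicit", **kwargs)
except TypeError as exc:
    print(exc)  # forward() got multiple values for keyword argument 'attention_mask'
```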

- `'tf'`: Return TensorFlow `tf.constant` objects.
- `'pt'`: Return PyTorch `torch.Tensor` objects.
- `'np'`: Return NumPy `np.ndarray` objects.
- `'jax'`: Return JAX `jnp.ndarray` objects.
Member

@auto_docstring pls

Comment on lines +425 to +426
def get_encoder(self):
return self.vision_tower
Member

Same deletion here: get_encoder accepts a modality arg and is defined in the parent class.

Comment thread tests/models/pp_formulanet/test_modeling_pp_formulanet.py
@vasqu (Contributor) left a comment

Thanks a lot, already looking good! Left a few comments on some less critical parts, but they would still be nice to fix/change 🤗


import httpx
from PIL import Image
from transformers import AutoProcessor, PPFormulaNetForConditionalGeneration
Contributor

auto model please

Contributor Author

Done

image_inputs = self.image_processor(images=images, **output_kwargs["images_kwargs"])
return BatchFeature({**image_inputs})

def normalize(self, s: str) -> str:
Contributor

nit: let's avoid the short single-letter name and just use text or similar

Contributor Author

Done

Comment on lines +198 to +200
rule_noletter_noletter = re.compile(r"(?!\\ )(%s)\s+?(%s)" % (noletter, noletter))
rule_noletter_letter = re.compile(r"(?!\\ )(%s)\s+?(%s)" % (noletter, letter))
rule_letter_noletter = re.compile(r"(%s)\s+?(%s)" % (letter, noletter))
Contributor

On second thought, would it make sense to be more extreme and compile those regexes once at init time? Same below.
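A sketch of that init-time variant, with a toy stand-in class (FormulaNormalizer and the `letter`/`noletter` values are hypothetical, not the PR's actual processor): the patterns are built once in `__init__` and reused as attributes on every call.

```python
import re


class FormulaNormalizer:
    """Toy stand-in for the processor's normalize step (illustrative only)."""

    def __init__(self):
        letter = r"[a-zA-Z]"
        noletter = r"[0-9+\-*/=]"  # hypothetical; the PR's actual classes differ
        # Compiled once per instance at init time, not on every normalize() call.
        self.rule_noletter_noletter = re.compile(r"(?!\\ )(%s)\s+?(%s)" % (noletter, noletter))
        self.rule_noletter_letter = re.compile(r"(?!\\ )(%s)\s+?(%s)" % (noletter, letter))
        self.rule_letter_noletter = re.compile(r"(%s)\s+?(%s)" % (letter, noletter))

    def normalize(self, text: str) -> str:
        text = self.rule_noletter_noletter.sub(r"\1\2", text)
        text = self.rule_noletter_letter.sub(r"\1\2", text)
        text = self.rule_letter_noletter.sub(r"\1\2", text)
        return text
```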

Contributor Author

Done

Comment on lines +362 to +363
input_ids: torch.LongTensor | None = None,
attention_mask: torch.Tensor | None = None,
Contributor

Can we mention with a small comment that we only keep this in the signature for generate compatibility?

Contributor Author

Done

if encoder_outputs is None:
encoder_outputs = self.get_image_features(pixel_values, **kwargs)
elif encoder_outputs.pooler_output is None:
encoder_outputs.pooler_output = self.multi_modal_projector(encoder_outputs.last_hidden_state)
Contributor

Imo, we shouldn't need this. Maybe we should either

  1. Move the projector into the encoder as well
  2. Adjust the generation pipeline where we prepare the encoder outputs to instead call get image features

Contributor Author

Done

Comment on lines +133 to +134
# test_torch_exportable = False
# model_split_percents = [0.5, 0.9]
Contributor

Suggested change
# test_torch_exportable = False
# model_split_percents = [0.5, 0.9]

Contributor Author

Done

Comment thread tests/models/pp_formulanet/test_modeling_pp_formulanet.py
Comment on lines +255 to +257
@unittest.skip(reason="PPFormulaNet does not small")
def test_model_is_small(self):
pass
Contributor

Could we try? :D

Contributor Author

Done, passed

Comment on lines +284 to +287
@pytest.mark.generate
@unittest.skip(reason="PPFormulaNet does not support beam search.")
def test_beam_sample_generate(self):
pass
Contributor

Would be nice to fix but also not that big of a deal

Contributor Author

Done, the beam search tests all pass.

@unittest.skip(
reason="GenerationMixin._expand_inputs_for_generation() got multiple values for keyword argument 'input_ids'"
)
def test_generate_continue_from_past_key_values(self):
Contributor

Hmm, this should be fixed imo if possible - maybe by overriding the test or something else.

Contributor Author

I did try that, but it failed :(
I think it may be related to the model’s special architecture, so for now I kept it skipped.

[image: attached screenshot of the failing test output]

Contributor

Looks like the rtol/atol is maybe too low but yea no worries we can keep it skipped, not a high prio imo

@github-actions (Contributor)

[For maintainers] Suggested jobs to run (before merge)

run-slow: auto, pp_formulanet

@zhang-prog zhang-prog requested a review from vasqu April 29, 2026 09:32
@vasqu (Contributor) left a comment

Carefully approving because it's only small stuff now 🤗 I will check in with run-slow in a sec as well, just as a sanity check.

def __init__(self, config):
super().__init__(config)

config.vision_config.decoder_hidden_size = config.text_config.hidden_size
Contributor

This shouldn't be necessary; I'd rather adjust the values in the config from the get-go.

Comment on lines +404 to +405
if encoder_outputs is None:
encoder_outputs = self.get_image_features(pixel_values, **kwargs)
Contributor

Since we now follow the full encoder-decoder structure, it would be nicer to stay closer to them e.g.

if encoder_outputs is None:
encoder_outputs: BaseModelOutput = self.encoder(
input_ids=input_ids,
attention_mask=attention_mask,
inputs_embeds=inputs_embeds,
**kwargs,
)
elif not isinstance(encoder_outputs, BaseModelOutput):
encoder_outputs = BaseModelOutput(
last_hidden_state=encoder_outputs[0],
hidden_states=encoder_outputs[1] if len(encoder_outputs) > 1 else None,
attentions=encoder_outputs[2] if len(encoder_outputs) > 2 else None,
)

We can still keep get_image_features; it just acts as a nice utility then, not as a core part of the forward.

if encoder_outputs is None:
encoder_outputs = self.get_image_features(pixel_values, **kwargs)

image_features = encoder_outputs.pooler_output.to(self.decoder.device, self.decoder.dtype)
Contributor

Is it actually needed?

Contributor

Rebump, maybe missed to commit it :D


@vasqu (Contributor)

vasqu commented Apr 29, 2026

run-slow: pp_formulanet

@github-actions (Contributor)

Workflow Run ⚙️

This comment contains run-slow, running the specified jobs:

models: ["models/pp_formulanet"]
quantizations: []

@github-actions (Contributor)

CI Results

Workflow Run ⚙️

Commit Info

Context  Commit    Description
RUN      af86d363  workflow commit (merge commit)
PR       74240ac5  branch commit (from PR)
main     a374d990  base commit (on main)

✅ No failing test specific to this PR 🎉 👏 !



3 participants