From 6ed646b975d705ff417f0758d458b834d13103fe Mon Sep 17 00:00:00 2001 From: zhangyue66 Date: Mon, 20 Apr 2026 20:40:06 +0800 Subject: [PATCH 1/6] init --- docs/source/en/_toctree.yml | 2 + src/transformers/models/__init__.py | 1 + src/transformers/models/auto/auto_mappings.py | 2 + src/transformers/models/auto/modeling_auto.py | 1 + src/transformers/models/slanet/__init__.py | 28 + .../models/slanet/configuration_slanet.py | 77 +++ .../models/slanet/modeling_slanet.py | 480 ++++++++++++++++++ .../models/slanet/modular_slanet.py | 398 +++++++++++++++ tests/models/slanet/__init__.py | 0 tests/models/slanet/test_modeling_slanet.py | 246 +++++++++ utils/check_repo.py | 4 + 11 files changed, 1239 insertions(+) create mode 100644 src/transformers/models/slanet/__init__.py create mode 100644 src/transformers/models/slanet/configuration_slanet.py create mode 100644 src/transformers/models/slanet/modeling_slanet.py create mode 100644 src/transformers/models/slanet/modular_slanet.py create mode 100644 tests/models/slanet/__init__.py create mode 100644 tests/models/slanet/test_modeling_slanet.py diff --git a/docs/source/en/_toctree.yml b/docs/source/en/_toctree.yml index aec6b14839cb..01a1dd2089b0 100644 --- a/docs/source/en/_toctree.yml +++ b/docs/source/en/_toctree.yml @@ -1365,6 +1365,8 @@ title: SigLIP - local: model_doc/siglip2 title: SigLIP2 + - local: model_doc/slanet + title: SLANet - local: model_doc/slanext title: SLANeXt - local: model_doc/smollm3 diff --git a/src/transformers/models/__init__.py b/src/transformers/models/__init__.py index 8aad0af6c303..b2b991bf7dde 100644 --- a/src/transformers/models/__init__.py +++ b/src/transformers/models/__init__.py @@ -383,6 +383,7 @@ from .shieldgemma2 import * from .siglip import * from .siglip2 import * + from .slanet import * from .slanext import * from .smollm3 import * from .smolvlm import * diff --git a/src/transformers/models/auto/auto_mappings.py b/src/transformers/models/auto/auto_mappings.py index 
10e376b65956..cf04cb2b91d7 100644 --- a/src/transformers/models/auto/auto_mappings.py +++ b/src/transformers/models/auto/auto_mappings.py @@ -529,6 +529,7 @@ ("siglip2_vision_model", "Siglip2VisionConfig"), ("siglip_text_model", "SiglipTextConfig"), ("siglip_vision_model", "SiglipVisionConfig"), + ("slanet", "SLANetConfig"), ("slanext", "SLANeXtConfig"), ("smollm3", "SmolLM3Config"), ("smolvlm", "SmolVLMConfig"), @@ -933,6 +934,7 @@ ("seggpt", {"pil": "SegGptImageProcessorPil", "torchvision": "SegGptImageProcessor"}), ("siglip", {"pil": "SiglipImageProcessorPil", "torchvision": "SiglipImageProcessor"}), ("siglip2", {"pil": "Siglip2ImageProcessorPil", "torchvision": "Siglip2ImageProcessor"}), + ("slanet", {"torchvision": "SLANeXtImageProcessor"}), ("slanext", {"torchvision": "SLANeXtImageProcessor"}), ("smolvlm", {"pil": "SmolVLMImageProcessorPil", "torchvision": "SmolVLMImageProcessor"}), ("superglue", {"pil": "SuperGlueImageProcessorPil", "torchvision": "SuperGlueImageProcessor"}), diff --git a/src/transformers/models/auto/modeling_auto.py b/src/transformers/models/auto/modeling_auto.py index deb1153d335e..e7a3fb0b1599 100644 --- a/src/transformers/models/auto/modeling_auto.py +++ b/src/transformers/models/auto/modeling_auto.py @@ -1157,6 +1157,7 @@ class _BaseModelWithGenerate(PreTrainedModel, GenerationMixin): MODEL_FOR_TABLE_RECOGNITION_MAPPING_NAMES = OrderedDict( [ + ("slanet", "SLANetForTableRecognition"), ("slanext", "SLANeXtForTableRecognition"), ] ) diff --git a/src/transformers/models/slanet/__init__.py b/src/transformers/models/slanet/__init__.py new file mode 100644 index 000000000000..96d521f4972f --- /dev/null +++ b/src/transformers/models/slanet/__init__.py @@ -0,0 +1,28 @@ +# Copyright 2026 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import TYPE_CHECKING

from ...utils import _LazyModule
from ...utils.import_utils import define_import_structure


# Standard Transformers lazy-import shim: under static type checking the
# submodules are imported eagerly so symbols resolve; at runtime the module is
# replaced by a _LazyModule that defers the actual imports until first access.
if TYPE_CHECKING:
    from .configuration_slanet import *
    from .modeling_slanet import *
else:
    import sys

    _file = globals()["__file__"]
    sys.modules[__name__] = _LazyModule(__name__, _file, define_import_structure(_file), module_spec=__spec__)
# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
# This file was automatically generated from src/transformers/models/slanet/modular_slanet.py.
# Do NOT edit this file manually as any edits will be overwritten by the generation of
# the file from the modular. If any change should be done, please apply the change to the
# modular_slanet.py file directly. One of our CI enforces this.
# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
# Copyright 2026 The PaddlePaddle Team and The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


from huggingface_hub.dataclasses import strict

from ...backbone_utils import consolidate_backbone_kwargs_to_config
from ...configuration_utils import PreTrainedConfig
from ...utils import auto_docstring
from ..auto import AutoConfig


@auto_docstring(checkpoint="PaddlePaddle/SLANet_plus_safetensors")
@strict
class SLANetConfig(PreTrainedConfig):
    r"""
    post_conv_out_channels (`int`, *optional*, defaults to 96):
        Number of output channels for the post-encoder convolution layer.
    out_channels (`int`, *optional*, defaults to 50):
        Vocabulary size for the table structure token prediction head, i.e., the number of distinct structure
        tokens the model can predict.
    hidden_size (`int`, *optional*, defaults to 256):
        Dimensionality of the hidden states in the attention GRU cell and the structure/location prediction heads.
    max_text_length (`int`, *optional*, defaults to 500):
        Maximum number of autoregressive decoding steps (tokens) for the structure and location decoder.
    backbone_config (`dict` or `PreTrainedConfig`, *optional*):
        Configuration of the vision backbone. When not provided, defaults to a PP-LCNet backbone
        (see `__post_init__` for the exact defaults).
    hidden_act (`str`, *optional*, defaults to `"hardswish"`):
        Activation function used by the convolutional layers of the CSP-PAN neck.
    csp_kernel_size (`int`, *optional*, defaults to 5):
        Kernel size used by the bottleneck and downsampling convolutions of each CSP layer in the neck.
    csp_blocks_num (`int`, *optional*, defaults to 1):
        Number of bottleneck blocks stacked inside each CSP layer.
    """

    model_type = "slanet"

    sub_configs = {"backbone_config": AutoConfig}
    post_conv_out_channels: int = 96
    out_channels: int = 50
    hidden_size: int = 256
    max_text_length: int = 500
    backbone_config: dict | PreTrainedConfig | None = None

    hidden_act: str = "hardswish"
    csp_kernel_size: int = 5
    csp_blocks_num: int = 1

    def __post_init__(self, **kwargs):
        # Resolve `backbone_config` (dict, config object, or None) into a concrete
        # backbone configuration; None falls back to a PP-LCNet backbone whose four
        # output stages feed the SLANet neck.
        self.backbone_config, kwargs = consolidate_backbone_kwargs_to_config(
            backbone_config=self.backbone_config,
            default_config_type="pp_lcnet",
            default_config_kwargs={
                "scale": 1,
                "out_features": ["stage2", "stage3", "stage4", "stage5"],
                "out_indices": [2, 3, 4, 5],
                "divisor": 16,
            },
            **kwargs,
        )
        super().__post_init__(**kwargs)


__all__ = ["SLANetConfig"]
# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
# This file was automatically generated from src/transformers/models/slanet/modular_slanet.py.
# Do NOT edit this file manually as any edits will be overwritten by the generation of
# the file from the modular. If any change should be done, please apply the change to the
# modular_slanet.py file directly. One of our CI enforces this.
# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
# Copyright 2026 The PaddlePaddle Team and The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


import math
from dataclasses import dataclass

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import Tensor

from ... import initialization as init
from ...activations import ACT2CLS, ACT2FN
from ...backbone_utils import filter_output_hidden_states, load_backbone
from ...modeling_outputs import BaseModelOutput, BaseModelOutputWithNoAttention
from ...modeling_utils import PreTrainedModel
from ...processing_utils import Unpack
from ...utils import TransformersKwargs, auto_docstring, can_return_tuple
from ...utils.generic import merge_with_config_defaults
from ...utils.output_capturing import capture_outputs
from .configuration_slanet import SLANetConfig


class SLANetPreTrainedModel(PreTrainedModel):
    """Base class for all SLANet modules; owns weight-initialization rules."""

    config: SLANetConfig
    base_model_prefix = "slanet"
    main_input_name = "pixel_values"
    input_modalities = ("image",)
    supports_gradient_checkpointing = False
    _keep_in_fp32_modules_strict = []

    @torch.no_grad()
    def _init_weights(self, module):
        """Initialize the weights"""
        super()._init_weights(module)

        # Initialize GRUCell (replicates PyTorch default reset_parameters):
        # uniform in ±1/sqrt(hidden_size) for all four parameter tensors.
        if isinstance(module, nn.GRUCell):
            std = 1.0 / math.sqrt(module.hidden_size) if module.hidden_size > 0 else 0
            init.uniform_(module.weight_ih, -std, std)
            init.uniform_(module.weight_hh, -std, std)
            if module.bias_ih is not None:
                init.uniform_(module.bias_ih, -std, std)
            if module.bias_hh is not None:
                init.uniform_(module.bias_hh, -std, std)

        # Initialize the linear layers of the SLA head's structure generator with
        # the same ±1/sqrt(hidden_size) uniform scheme as the GRU cell.
        if isinstance(module, SLANetSLAHead):
            std = 1.0 / math.sqrt(self.config.hidden_size * 1.0)
            # NOTE(review): only `structure_generator` is initialized here; a
            # `loc_generator` (bounding-box head) is mentioned in the config docs
            # but not present in this model — confirm against later patches.
            for generator in (module.structure_generator,):
                for layer in generator.children():
                    if isinstance(layer, nn.Linear):
                        init.uniform_(layer.weight, -std, std)
                        if layer.bias is not None:
                            init.uniform_(layer.bias, -std, std)


@dataclass
@auto_docstring
class SLANetForTableRecognitionOutput(BaseModelOutput):
    r"""
    head_hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
        Hidden-states of the SLANetSLAHead at each prediction step, varies up to max `self.config.max_text_length` states (depending on early exits).
    head_attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
        Attentions of the SLANetSLAHead at each prediction step, varies up to max `self.config.max_text_length` attentions (depending on early exits).
    """

    head_hidden_states: torch.FloatTensor | None = None
    head_attentions: torch.FloatTensor | None = None


class SLANetAttentionGRUCell(nn.Module):
    """One decoding step: additive attention over encoder features feeding a GRU cell.

    The attended context vector is concatenated with the one-hot embedding of the
    previously predicted token and fed to an `nn.GRUCell` to produce the next
    hidden state.
    """

    def __init__(self, input_size, hidden_size, num_embeddings):
        super().__init__()

        # Additive (Bahdanau-style) attention projections.
        self.input_to_hidden = nn.Linear(input_size, hidden_size, bias=False)
        self.hidden_to_hidden = nn.Linear(hidden_size, hidden_size)
        self.score = nn.Linear(hidden_size, 1, bias=False)

        # GRU input is [context ; one-hot token], hence input_size + num_embeddings.
        self.rnn = nn.GRUCell(input_size + num_embeddings, hidden_size)

    def forward(
        self,
        prev_hidden: torch.FloatTensor,
        batch_hidden: torch.FloatTensor,
        char_onehots: torch.FloatTensor,
        **kwargs: Unpack[TransformersKwargs],
    ):
        batch_hidden_proj = self.input_to_hidden(batch_hidden)
        prev_hidden_proj = self.hidden_to_hidden(prev_hidden).unsqueeze(1)

        # score(tanh(W_enc·h_enc + W_dec·h_dec)) — classic additive attention.
        attention_scores = batch_hidden_proj + prev_hidden_proj
        attention_scores = torch.tanh(attention_scores)
        attention_scores = self.score(attention_scores)

        # Softmax in float32 for numerical stability, then cast back.
        attn_weights = F.softmax(attention_scores, dim=1, dtype=torch.float32).to(attention_scores.dtype)
        attn_weights = attn_weights.transpose(1, 2)
        context = torch.matmul(attn_weights, batch_hidden).squeeze(1)
        concat_context = torch.cat([context, char_onehots], 1)
        hidden_states = self.rnn(concat_context, prev_hidden)

        return hidden_states, attn_weights


class SLANetMLP(nn.Module):
    """Two-layer projection head: fc1 -> fc2 -> activation (identity by default).

    NOTE(review): there is no non-linearity between fc1 and fc2 and the optional
    activation is applied *after* the final layer — confirm this matches the
    original PaddleOCR SLANet head.
    """

    def __init__(self, hidden_size, out_channels, activation=None):
        super().__init__()
        self.fc1 = nn.Linear(hidden_size, hidden_size)
        self.fc2 = nn.Linear(hidden_size, out_channels)
        self.act_fn = nn.Identity() if activation is None else ACT2CLS[activation]()

    def forward(self, hidden_states):
        hidden_states = self.fc1(hidden_states)
        hidden_states = self.fc2(hidden_states)
        hidden_states = self.act_fn(hidden_states)
        return hidden_states


class SLANetSLAHead(SLANetPreTrainedModel):
    """Autoregressive structure-token decoder (SLA head).

    At each step, an attention GRU cell attends over the flattened neck features
    and the structure generator predicts the next structure-token logits; the
    argmax token is fed back (as a one-hot vector) into the next step.
    """

    _can_record_outputs = {
        "attentions": SLANetAttentionGRUCell,
    }

    def __init__(
        self,
        # NOTE(review): annotated as `dict` but attribute access below requires a
        # SLANetConfig — annotation inherited from SLANeXtSLAHead; fix upstream.
        config: dict | None = None,
        **kwargs,
    ):
        super().__init__(config)

        self.structure_attention_cell = SLANetAttentionGRUCell(
            config.post_conv_out_channels, config.hidden_size, config.out_channels
        )
        self.structure_generator = SLANetMLP(config.hidden_size, config.out_channels)

        self.post_init()

    @merge_with_config_defaults
    @capture_outputs
    @filter_output_hidden_states
    def forward(
        self,
        hidden_states: torch.FloatTensor,
        # NOTE(review): `targets` is currently unused (inference-only greedy
        # decoding) — confirm teacher forcing is deliberately omitted.
        targets: torch.Tensor | None = None,
        **kwargs: Unpack[TransformersKwargs],
    ):
        # Decoder state and "previous token" both start at zero.
        features = torch.zeros(
            (hidden_states.shape[0], self.config.hidden_size), dtype=torch.float32, device=hidden_states.device
        )
        predicted_chars = torch.zeros(size=[hidden_states.shape[0]], dtype=torch.long, device=hidden_states.device)

        structure_preds_list = []
        structure_ids_list = []
        for _ in range(self.config.max_text_length + 1):
            embedding_feature = F.one_hot(predicted_chars, self.config.out_channels).float()
            features, _ = self.structure_attention_cell(features, hidden_states.float(), embedding_feature)
            structure_step = self.structure_generator(features)
            # Greedy decoding: feed back the argmax token at the next step.
            predicted_chars = structure_step.argmax(dim=1)

            structure_preds_list.append(structure_step)
            structure_ids_list.append(predicted_chars)
            # Early exit once every batch element has emitted token id
            # `out_channels - 1` at some step (presumably the end-of-structure
            # token — TODO confirm against the tokenizer/vocab).
            if torch.stack(structure_ids_list, dim=1).eq(self.config.out_channels - 1).any(-1).all():
                break
        # Per-step logits are softmaxed (in float32) into per-token probabilities.
        structure_preds = F.softmax(torch.stack(structure_preds_list, dim=1), dim=-1, dtype=torch.float32).to(
            hidden_states.dtype
        )

        return BaseModelOutput(last_hidden_state=structure_preds, hidden_states=structure_preds_list)


class SLANetConvLayer(nn.Module):
    """Conv2d -> BatchNorm2d -> activation, with `same`-style padding (kernel_size // 2)."""

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        kernel_size: int = 3,
        stride: int = 1,
        activation: str = "hardswish",
        groups: int = 1,
    ):
        super().__init__()
        self.convolution = nn.Conv2d(
            in_channels,
            out_channels,
            kernel_size=kernel_size,
            stride=stride,
            padding=kernel_size // 2,
            bias=False,  # no conv bias — BatchNorm supplies the affine shift
            groups=groups,
        )
        self.normalization = nn.BatchNorm2d(out_channels)
        self.activation = ACT2FN[activation] if activation is not None else nn.Identity()

    def forward(self, input: Tensor) -> Tensor:
        hidden_state = self.convolution(input)
        hidden_state = self.normalization(hidden_state)
        hidden_state = self.activation(hidden_state)
        return hidden_state


class SLANetDepthwiseSeparableConvLayer(nn.Module):
    """
    Depthwise Separable Convolution Layer: Depthwise Conv -> Pointwise Conv
    Core component of lightweight models (e.g., MobileNet, PP-LCNet) that significantly reduces
    the number of parameters and computational cost.
    """

    def __init__(
        self,
        in_channels,
        out_channels,
        stride,
        kernel_size,
        activation,
    ):
        super().__init__()
        # groups == in_channels makes this a per-channel (depthwise) convolution.
        self.depthwise_convolution = SLANetConvLayer(
            in_channels=in_channels,
            out_channels=in_channels,
            kernel_size=kernel_size,
            stride=stride,
            groups=in_channels,
            activation=activation,
        )
        # 1x1 convolution mixes channels and sets the output width.
        self.pointwise_convolution = SLANetConvLayer(
            in_channels=in_channels,
            kernel_size=1,
            out_channels=out_channels,
            stride=1,
            activation=activation,
        )

    def forward(self, hidden_states):
        hidden_states = self.depthwise_convolution(hidden_states)
        hidden_states = self.pointwise_convolution(hidden_states)

        return hidden_states


class SLANetBottleneck(nn.Module):
    """1x1 conv followed by a depthwise-separable conv; building block of the CSP layer."""

    def __init__(
        self,
        in_channels,
        out_channels,
        kernel_size,
        activation,
    ):
        super().__init__()
        self.conv1 = SLANetConvLayer(
            in_channels=in_channels, out_channels=out_channels, kernel_size=1, activation=activation
        )
        self.conv2 = SLANetDepthwiseSeparableConvLayer(
            in_channels=out_channels,
            out_channels=out_channels,
            kernel_size=kernel_size,
            stride=1,
            activation=activation,
        )

    def forward(self, hidden_states):
        hidden_states = self.conv1(hidden_states)
        hidden_states = self.conv2(hidden_states)

        return hidden_states


class SLANetCSPLayer(nn.Module):
    """
    Cross Stage Partial (CSP) network layer.

    The input is split into a "main" branch (processed by `num_blocks`
    bottlenecks) and a "short" branch (a single 1x1 conv); both are concatenated
    and fused by a final 1x1 conv.
    """

    def __init__(
        self,
        in_channels,
        out_channels,
        kernel_size=3,
        expand_ratio=0.5,
        num_blocks=1,
        activation="hardswish",
    ):
        super().__init__()
        # Each branch carries out_channels * expand_ratio channels.
        mid_channels = int(out_channels * expand_ratio)
        self.main_conv = SLANetConvLayer(in_channels, mid_channels, 1, activation=activation)
        self.short_conv = SLANetConvLayer(in_channels, mid_channels, 1, activation=activation)
        self.final_conv = SLANetConvLayer(2 * mid_channels, out_channels, 1, activation=activation)

        self.blocks = nn.ModuleList()
        for _ in range(num_blocks):
            self.blocks.append(
                SLANetBottleneck(
                    mid_channels,
                    mid_channels,
                    kernel_size,
                    activation,
                )
            )

    def forward(self, hidden_states):
        hidden_states_short = self.short_conv(hidden_states)

        hidden_states_main = self.main_conv(hidden_states)
        for block in self.blocks:
            hidden_states_main = block(hidden_states_main)

        hidden_states = torch.cat((hidden_states_main, hidden_states_short), dim=1)
        hidden_states = self.final_conv(hidden_states)

        return hidden_states


class SLANetChannelProjector(nn.Module):
    """Projects each pyramid level to a common channel width via 1x1 convs (one per level)."""

    def __init__(self, in_channel_list, out_channels, activation):
        super().__init__()
        self.layers = nn.ModuleList()
        for i in range(len(in_channel_list)):
            self.layers.append(
                SLANetConvLayer(
                    in_channels=in_channel_list[i], out_channels=out_channels, kernel_size=1, activation=activation
                )
            )

    def forward(self, hidden_states):
        projected_features = []
        for idx in range(len(self.layers)):
            projected_features.append(self.layers[idx](hidden_states[idx]))
        return projected_features


class SLANetCSPPAN(nn.Module):
    """
    CSP-PAN: Path Aggregation Network with CSP layers

    Top-down pass fuses high-level features into lower levels (via upsampling),
    then a bottom-up pass re-aggregates (via strided depthwise-separable convs).
    Only the final, coarsest fused map is returned, flattened to a sequence.
    """

    def __init__(
        self,
        in_channel_list,
        config,
    ):
        super().__init__()
        out_channels = config.post_conv_out_channels
        activation = config.hidden_act
        kernel_size = config.csp_kernel_size
        csp_blocks_num = config.csp_blocks_num

        self.channel_projector = SLANetChannelProjector(in_channel_list, out_channels, activation)

        # build top-down blocks
        # NOTE(review): `self.upsample` is never used in forward (F.interpolate
        # with an explicit size is used instead); kept for checkpoint parity.
        self.upsample = nn.Upsample(scale_factor=2, mode="nearest")
        self.top_down_blocks = nn.ModuleList()
        for _ in range(len(in_channel_list) - 1, 0, -1):
            self.top_down_blocks.append(
                SLANetCSPLayer(
                    out_channels * 2,  # concat of upsampled high-level + low-level features
                    out_channels,
                    kernel_size=kernel_size,
                    num_blocks=csp_blocks_num,
                    activation=activation,
                )
            )

        # build bottom-up blocks
        self.downsamples = nn.ModuleList()
        self.bottom_up_blocks = nn.ModuleList()
        for _ in range(len(in_channel_list) - 1):
            self.downsamples.append(
                SLANetDepthwiseSeparableConvLayer(
                    out_channels,
                    out_channels,
                    kernel_size=kernel_size,
                    stride=2,  # halves spatial resolution on the way back up
                    activation=activation,
                )
            )
            self.bottom_up_blocks.append(
                SLANetCSPLayer(
                    out_channels * 2,
                    out_channels,
                    kernel_size=kernel_size,
                    num_blocks=csp_blocks_num,
                    activation=activation,
                )
            )

    def forward(self, hidden_states):
        projected_features = self.channel_projector(hidden_states)

        # Top-down path: start from the deepest level, repeatedly upsample and
        # fuse with the next-shallower level.
        top_down_features = [projected_features[-1]]
        for top_down_block, low_level_feature in zip(self.top_down_blocks, reversed(projected_features[:-1])):
            high_level_feature = top_down_features[-1]
            # Interpolate to the exact low-level spatial size (robust to odd dims).
            upsampled_feature = F.interpolate(
                high_level_feature,
                size=low_level_feature.shape[-2:],
                mode="nearest",
            )
            fused_feature = top_down_block(torch.cat([upsampled_feature, low_level_feature], dim=1))
            top_down_features.append(fused_feature)

        # Bottom-up path: downsample and fuse back toward the deepest level.
        pyramid_features = list(reversed(top_down_features))
        output_feature = pyramid_features[0]
        for downsample_layer, bottom_up_block, high_level_feature in zip(
            self.downsamples, self.bottom_up_blocks, pyramid_features[1:]
        ):
            downsampled_feature = downsample_layer(output_feature)
            output_feature = bottom_up_block(torch.cat([downsampled_feature, high_level_feature], dim=1))

        # Flatten (batch, C, H, W) -> (batch, H*W, C) for the sequence decoder.
        hidden_states = output_feature.flatten(2).transpose(1, 2)
        return hidden_states


class SLANetModel(SLANetPreTrainedModel):
    """Backbone + CSP-PAN neck; produces the feature sequence consumed by the SLA head."""

    def __init__(self, config: SLANetConfig):
        super().__init__(config)
        self.backbone = load_backbone(config)
        # The neck consumes the deeper backbone stages only (num_features[2:]).
        self.neck = SLANetCSPPAN(self.backbone.num_features[2:], config)

        self.post_init()

    @merge_with_config_defaults
    @capture_outputs
    def forward(
        self,
        # NOTE(review): this argument actually receives pixel values (see
        # SLANetForTableRecognition.forward); name kept for generated-file parity.
        hidden_states: torch.FloatTensor,
        **kwargs: Unpack[TransformersKwargs],
    ) -> tuple[torch.FloatTensor] | SLANetForTableRecognitionOutput:
        outputs = self.backbone(hidden_states, **kwargs)
        hidden_states = self.neck(outputs.feature_maps)
        return BaseModelOutputWithNoAttention(
            last_hidden_state=hidden_states,
            hidden_states=outputs.hidden_states,
        )


@auto_docstring(
    custom_intro="""
    SLANet Table Recognition model for table recognition tasks. Wraps the core SLANetPreTrainedModel
    and returns outputs compatible with the Transformers table recognition API.
    """
)
class SLANetForTableRecognition(SLANetPreTrainedModel):
    # BatchNorm running-stat counters may legitimately be absent from checkpoints.
    _keys_to_ignore_on_load_missing = ["num_batches_tracked"]

    def __init__(self, config: SLANetConfig):
        super().__init__(config)
        self.model = SLANetModel(config=config)
        self.head = SLANetSLAHead(config=config)
        self.post_init()

    @can_return_tuple
    @auto_docstring
    def forward(
        self, pixel_values: torch.FloatTensor, **kwargs: Unpack[TransformersKwargs]
    ) -> tuple[torch.FloatTensor] | SLANetForTableRecognitionOutput:
        outputs = self.model(pixel_values, **kwargs)
        head_outputs = self.head(outputs.last_hidden_state, **kwargs)
        return SLANetForTableRecognitionOutput(
            last_hidden_state=head_outputs.last_hidden_state,
            hidden_states=outputs.hidden_states,
            head_hidden_states=head_outputs.hidden_states,
            head_attentions=head_outputs.attentions,
        )


__all__ = ["SLANetForTableRecognition", "SLANetPreTrainedModel", "SLANetSLAHead", "SLANetModel"]
# Copyright 2026 The PaddlePaddle Team and The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


import math

import torch
import torch.nn as nn
import torch.nn.functional as F
from huggingface_hub.dataclasses import strict

from ... import initialization as init
from ...backbone_utils import consolidate_backbone_kwargs_to_config, load_backbone
from ...configuration_utils import PreTrainedConfig
from ...modeling_outputs import BaseModelOutputWithNoAttention
from ...modeling_utils import PreTrainedModel
from ...processing_utils import Unpack
from ...utils import TransformersKwargs, auto_docstring, can_return_tuple, logging
from ...utils.generic import merge_with_config_defaults
from ...utils.output_capturing import capture_outputs
from ..auto import AutoConfig
from ..pp_lcnet.modeling_pp_lcnet import PPLCNetConvLayer
from ..slanext.configuration_slanext import SLANeXtConfig
from ..slanext.modeling_slanext import (
    SLANeXtForTableRecognitionOutput,
    SLANeXtPreTrainedModel,
    SLANeXtSLAHead,
)


logger = logging.get_logger(__name__)


@auto_docstring(checkpoint="PaddlePaddle/SLANet_plus_safetensors")
@strict
class SLANetConfig(SLANeXtConfig):
    r"""
    post_conv_out_channels (`int`, *optional*, defaults to 96):
        Number of output channels for the post-encoder convolution layer.
    out_channels (`int`, *optional*, defaults to 50):
        Vocabulary size for the table structure token prediction head, i.e., the number of distinct structure
        tokens the model can predict.
    hidden_size (`int`, *optional*, defaults to 256):
        Dimensionality of the hidden states in the attention GRU cell and the structure/location prediction heads.
    max_text_length (`int`, *optional*, defaults to 500):
        Maximum number of autoregressive decoding steps (tokens) for the structure and location decoder.
    backbone_config (`dict` or `PreTrainedConfig`, *optional*):
        Configuration of the vision backbone. When not provided, defaults to a PP-LCNet backbone
        (see `__post_init__` for the exact defaults).
    hidden_act (`str`, *optional*, defaults to `"hardswish"`):
        Activation function used by the convolutional layers of the CSP-PAN neck.
    csp_kernel_size (`int`, *optional*, defaults to 5):
        Kernel size used by the bottleneck and downsampling convolutions of each CSP layer in the neck.
    csp_blocks_num (`int`, *optional*, defaults to 1):
        Number of bottleneck blocks stacked inside each CSP layer.
    """

    sub_configs = {"backbone_config": AutoConfig}

    # Modular deletion idiom: SLANet replaces SLANeXt's `vision_config` with a
    # pluggable `backbone_config`, and has no fixed `post_conv_in_channels`.
    vision_config = AttributeError()
    backbone_config: dict | PreTrainedConfig | None = None

    post_conv_in_channels = AttributeError()
    post_conv_out_channels: int = 96
    out_channels: int = 50
    hidden_size: int = 256
    max_text_length: int = 500

    hidden_act: str = "hardswish"
    csp_kernel_size: int = 5
    csp_blocks_num: int = 1

    def __post_init__(self, **kwargs):
        # Resolve `backbone_config` (dict, config object, or None) into a concrete
        # backbone configuration; None falls back to a PP-LCNet backbone whose four
        # output stages feed the SLANet neck.
        self.backbone_config, kwargs = consolidate_backbone_kwargs_to_config(
            backbone_config=self.backbone_config,
            default_config_type="pp_lcnet",
            default_config_kwargs={
                "scale": 1,
                "out_features": ["stage2", "stage3", "stage4", "stage5"],
                "out_indices": [2, 3, 4, 5],
                "divisor": 16,
            },
            **kwargs,
        )
        # Bug fix: calling the grandparent's __post_init__ through the class is an
        # unbound call, so the instance must be passed explicitly — without `self`
        # this raises TypeError at runtime. (The modular converter rewrites this
        # to `super().__post_init__(**kwargs)` in the generated file.)
        PreTrainedConfig.__post_init__(self, **kwargs)


class SLANetPreTrainedModel(SLANeXtPreTrainedModel):
    """Base class for all SLANet modules; owns weight-initialization rules."""

    base_model_prefix = "slanet"
    supports_gradient_checkpointing = False
    _keep_in_fp32_modules_strict = []

    @torch.no_grad()
    def _init_weights(self, module):
        """Initialize the weights"""
        # Bug fix: unbound call through the class requires the instance explicitly
        # — `PreTrainedModel._init_weights(module)` raises TypeError at runtime.
        # (The modular converter rewrites this to `super()._init_weights(module)`.)
        PreTrainedModel._init_weights(self, module)

        # Initialize GRUCell (replicates PyTorch default reset_parameters):
        # uniform in ±1/sqrt(hidden_size) for all four parameter tensors.
        if isinstance(module, nn.GRUCell):
            std = 1.0 / math.sqrt(module.hidden_size) if module.hidden_size > 0 else 0
            init.uniform_(module.weight_ih, -std, std)
            init.uniform_(module.weight_hh, -std, std)
            if module.bias_ih is not None:
                init.uniform_(module.bias_ih, -std, std)
            if module.bias_hh is not None:
                init.uniform_(module.bias_hh, -std, std)

        # Initialize the linear layers of the SLA head's structure generator with
        # the same ±1/sqrt(hidden_size) uniform scheme as the GRU cell.
        if isinstance(module, SLANetSLAHead):
            std = 1.0 / math.sqrt(self.config.hidden_size * 1.0)
            # NOTE(review): only `structure_generator` is initialized; no
            # `loc_generator` exists in this model — confirm against later patches.
            for generator in (module.structure_generator,):
                for layer in generator.children():
                    if isinstance(layer, nn.Linear):
                        init.uniform_(layer.weight, -std, std)
                        if layer.bias is not None:
                            init.uniform_(layer.bias, -std, std)


class SLANetForTableRecognitionOutput(SLANeXtForTableRecognitionOutput):
    pass


class SLANetSLAHead(SLANeXtSLAHead):
    pass


class SLANetConvLayer(PPLCNetConvLayer):
    pass


class SLANetDepthwiseSeparableConvLayer(nn.Module):
    """
    Depthwise Separable Convolution Layer: Depthwise Conv -> Pointwise Conv
    Core component of lightweight models (e.g., MobileNet, PP-LCNet) that significantly reduces
    the number of parameters and computational cost.
    """

    def __init__(
        self,
        in_channels,
        out_channels,
        stride,
        kernel_size,
        activation,
    ):
        super().__init__()
        # groups == in_channels makes this a per-channel (depthwise) convolution.
        self.depthwise_convolution = SLANetConvLayer(
            in_channels=in_channels,
            out_channels=in_channels,
            kernel_size=kernel_size,
            stride=stride,
            groups=in_channels,
            activation=activation,
        )
        # 1x1 convolution mixes channels and sets the output width.
        self.pointwise_convolution = SLANetConvLayer(
            in_channels=in_channels,
            kernel_size=1,
            out_channels=out_channels,
            stride=1,
            activation=activation,
        )

    def forward(self, hidden_states):
        hidden_states = self.depthwise_convolution(hidden_states)
        hidden_states = self.pointwise_convolution(hidden_states)

        return hidden_states


class SLANetBottleneck(nn.Module):
    """1x1 conv followed by a depthwise-separable conv; building block of the CSP layer."""

    def __init__(
        self,
        in_channels,
        out_channels,
        kernel_size,
        activation,
    ):
        super().__init__()
        self.conv1 = SLANetConvLayer(
            in_channels=in_channels, out_channels=out_channels, kernel_size=1, activation=activation
        )
        self.conv2 = SLANetDepthwiseSeparableConvLayer(
            in_channels=out_channels,
            out_channels=out_channels,
            kernel_size=kernel_size,
            stride=1,
            activation=activation,
        )

    def forward(self, hidden_states):
        hidden_states = self.conv1(hidden_states)
        hidden_states = self.conv2(hidden_states)

        return hidden_states


class SLANetCSPLayer(nn.Module):
    """
    Cross Stage Partial (CSP) network layer.

    The input is split into a "main" branch (processed by `num_blocks`
    bottlenecks) and a "short" branch (a single 1x1 conv); both are concatenated
    and fused by a final 1x1 conv.
    """

    def __init__(
        self,
        in_channels,
        out_channels,
        kernel_size=3,
        expand_ratio=0.5,
        num_blocks=1,
        activation="hardswish",
    ):
        super().__init__()
        # Each branch carries out_channels * expand_ratio channels.
        mid_channels = int(out_channels * expand_ratio)
        self.main_conv = SLANetConvLayer(in_channels, mid_channels, 1, activation=activation)
        self.short_conv = SLANetConvLayer(in_channels, mid_channels, 1, activation=activation)
        self.final_conv = SLANetConvLayer(2 * mid_channels, out_channels, 1, activation=activation)

        self.blocks = nn.ModuleList()
        for _ in range(num_blocks):
            self.blocks.append(
                SLANetBottleneck(
                    mid_channels,
                    mid_channels,
                    kernel_size,
                    activation,
                )
            )

    def forward(self, hidden_states):
        hidden_states_short = self.short_conv(hidden_states)

        hidden_states_main = self.main_conv(hidden_states)
        for block in self.blocks:
            hidden_states_main = block(hidden_states_main)

        hidden_states = torch.cat((hidden_states_main, hidden_states_short), dim=1)
        hidden_states = self.final_conv(hidden_states)

        return hidden_states


class SLANetChannelProjector(nn.Module):
    """Projects each pyramid level to a common channel width via 1x1 convs (one per level)."""

    def __init__(self, in_channel_list, out_channels, activation):
        super().__init__()
        self.layers = nn.ModuleList()
        for i in range(len(in_channel_list)):
            self.layers.append(
                SLANetConvLayer(
                    in_channels=in_channel_list[i], out_channels=out_channels, kernel_size=1, activation=activation
                )
            )

    def forward(self, hidden_states):
        projected_features = []
        for idx in range(len(self.layers)):
            projected_features.append(self.layers[idx](hidden_states[idx]))
        return projected_features


class SLANetCSPPAN(nn.Module):
    """
    CSP-PAN: Path Aggregation Network with CSP layers

    Top-down pass fuses high-level features into lower levels (via upsampling),
    then a bottom-up pass re-aggregates (via strided depthwise-separable convs).
    Only the final, coarsest fused map is returned, flattened to a sequence.
    """

    def __init__(
        self,
        in_channel_list,
        config,
    ):
        super().__init__()
        out_channels = config.post_conv_out_channels
        activation = config.hidden_act
        kernel_size = config.csp_kernel_size
        csp_blocks_num = config.csp_blocks_num

        self.channel_projector = SLANetChannelProjector(in_channel_list, out_channels, activation)

        # build top-down blocks
        # NOTE(review): `self.upsample` is never used in forward (F.interpolate
        # with an explicit size is used instead); kept for checkpoint parity.
        self.upsample = nn.Upsample(scale_factor=2, mode="nearest")
        self.top_down_blocks = nn.ModuleList()
        for _ in range(len(in_channel_list) - 1, 0, -1):
            self.top_down_blocks.append(
                SLANetCSPLayer(
                    out_channels * 2,  # concat of upsampled high-level + low-level features
                    out_channels,
                    kernel_size=kernel_size,
                    num_blocks=csp_blocks_num,
                    activation=activation,
                )
            )

        # build bottom-up blocks
        self.downsamples = nn.ModuleList()
        self.bottom_up_blocks = nn.ModuleList()
        for _ in range(len(in_channel_list) - 1):
            self.downsamples.append(
                SLANetDepthwiseSeparableConvLayer(
                    out_channels,
                    out_channels,
                    kernel_size=kernel_size,
                    stride=2,  # halves spatial resolution on the way back up
                    activation=activation,
                )
            )
            self.bottom_up_blocks.append(
                SLANetCSPLayer(
                    out_channels * 2,
                    out_channels,
                    kernel_size=kernel_size,
                    num_blocks=csp_blocks_num,
                    activation=activation,
                )
            )

    def forward(self, hidden_states):
        projected_features = self.channel_projector(hidden_states)

        # Top-down path: start from the deepest level, repeatedly upsample and
        # fuse with the next-shallower level.
        top_down_features = [projected_features[-1]]
        for top_down_block, low_level_feature in zip(self.top_down_blocks, reversed(projected_features[:-1])):
            high_level_feature = top_down_features[-1]
            # Interpolate to the exact low-level spatial size (robust to odd dims).
            upsampled_feature = F.interpolate(
                high_level_feature,
                size=low_level_feature.shape[-2:],
                mode="nearest",
            )
            fused_feature = top_down_block(torch.cat([upsampled_feature, low_level_feature], dim=1))
            top_down_features.append(fused_feature)

        # Bottom-up path: downsample and fuse back toward the deepest level.
        pyramid_features = list(reversed(top_down_features))
        output_feature = pyramid_features[0]
        for downsample_layer, bottom_up_block, high_level_feature in zip(
            self.downsamples, self.bottom_up_blocks, pyramid_features[1:]
        ):
            downsampled_feature = downsample_layer(output_feature)
            output_feature = bottom_up_block(torch.cat([downsampled_feature, high_level_feature], dim=1))

        # Flatten (batch, C, H, W) -> (batch, H*W, C) for the sequence decoder.
        hidden_states = output_feature.flatten(2).transpose(1, 2)
        return hidden_states


class SLANetModel(SLANetPreTrainedModel):
    """Backbone + CSP-PAN neck; produces the feature sequence consumed by the SLA head."""

    def __init__(self, config: SLANetConfig):
        super().__init__(config)
        self.backbone = load_backbone(config)
        # The neck consumes the deeper backbone stages only (num_features[2:]).
        self.neck = SLANetCSPPAN(self.backbone.num_features[2:], config)

        self.post_init()

    @merge_with_config_defaults
    @capture_outputs
    def forward(
        self,
        # NOTE(review): this argument actually receives pixel values (see
        # SLANetForTableRecognition.forward); consider renaming to `pixel_values`.
        hidden_states: torch.FloatTensor,
        **kwargs: Unpack[TransformersKwargs],
    ) -> tuple[torch.FloatTensor] | SLANetForTableRecognitionOutput:
        outputs = self.backbone(hidden_states, **kwargs)
        hidden_states = self.neck(outputs.feature_maps)
        return BaseModelOutputWithNoAttention(
            last_hidden_state=hidden_states,
            hidden_states=outputs.hidden_states,
        )


@auto_docstring(
    custom_intro="""
    SLANet Table Recognition model for table recognition tasks. Wraps the core SLANetPreTrainedModel
    and returns outputs compatible with the Transformers table recognition API.
    """
)
class SLANetForTableRecognition(SLANetPreTrainedModel):
    # BatchNorm running-stat counters may legitimately be absent from checkpoints.
    _keys_to_ignore_on_load_missing = ["num_batches_tracked"]

    def __init__(self, config: SLANetConfig):
        super().__init__(config)
        self.model = SLANetModel(config=config)
        self.head = SLANetSLAHead(config=config)
        self.post_init()

    @can_return_tuple
    @auto_docstring
    def forward(
        self, pixel_values: torch.FloatTensor, **kwargs: Unpack[TransformersKwargs]
    ) -> tuple[torch.FloatTensor] | SLANetForTableRecognitionOutput:
        outputs = self.model(pixel_values, **kwargs)
        head_outputs = self.head(outputs.last_hidden_state, **kwargs)
        return SLANetForTableRecognitionOutput(
            last_hidden_state=head_outputs.last_hidden_state,
            hidden_states=outputs.hidden_states,
            head_hidden_states=head_outputs.hidden_states,
            head_attentions=head_outputs.attentions,
        )


__all__ = ["SLANetConfig", "SLANetForTableRecognition", "SLANetPreTrainedModel", "SLANetSLAHead", "SLANetModel"]
new file mode 100644 index 000000000000..cbb7c2eb3ac6 --- /dev/null +++ b/tests/models/slanet/test_modeling_slanet.py @@ -0,0 +1,246 @@ +# coding = utf-8 +# Copyright 2026 The PaddlePaddle Team and The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Testing suite for the SLANet model.""" + +import inspect +import unittest + +import requests + +from transformers import ( + AutoImageProcessor, + AutoModelForTableRecognition, + SLANetConfig, + SLANetForTableRecognition, + is_torch_available, + is_vision_available, +) +from transformers.testing_utils import ( + require_torch, + require_vision, + slow, + torch_device, +) + +from ...test_configuration_common import ConfigTester +from ...test_modeling_common import ModelTesterMixin, floats_tensor +from ...test_pipeline_mixin import PipelineTesterMixin + + +if is_torch_available(): + import torch + +if is_vision_available(): + from PIL import Image + + +class SLANetModelTester: + def __init__( + self, + parent, + batch_size=2, + image_size=488, + num_channels=3, + post_conv_out_channels=16, + out_channels=1, + hidden_size=16, + max_text_length=1, + num_stages=5, + is_training=False, + ): + self.parent = parent + self.batch_size = batch_size + self.num_channels = num_channels + self.image_size = image_size + self.post_conv_out_channels = post_conv_out_channels + self.out_channels = out_channels + self.hidden_size = hidden_size + self.max_text_length = max_text_length + 
self.num_stages = num_stages + self.is_training = is_training + + def prepare_config_and_inputs_for_common(self): + config, pixel_values = self.prepare_config_and_inputs() + inputs_dict = {"pixel_values": pixel_values} + return config, inputs_dict + + def prepare_config_and_inputs(self): + pixel_values = floats_tensor([self.batch_size, self.num_channels, self.image_size, self.image_size]) + config = self.get_config() + + return config, pixel_values + + def get_config(self) -> SLANetConfig: + backbone_config = { + "model_type": "pp_lcnet", + "scale": 1, + "out_features": ["stage2", "stage3", "stage4", "stage5"], + "out_indices": [2, 3, 4, 5], + "block_configs": [ + [[3, 16, 16, 1, False]], + [[3, 16, 16, 2, False], [3, 16, 16, 1, False]], + [[3, 16, 16, 2, False], [3, 16, 16, 1, False]], + [ + [3, 16, 16, 2, False], + [5, 16, 16, 1, False], + [5, 16, 16, 1, False], + [5, 16, 16, 1, False], + [5, 16, 16, 1, False], + [5, 16, 16, 1, False], + ], + [[5, 16, 16, 2, True], [5, 16, 16, 1, True]], + ], + } + config = SLANetConfig( + backbone_config=backbone_config, + out_channels=self.out_channels, + hidden_size=self.hidden_size, + max_text_length=self.max_text_length, + post_conv_out_channels=self.post_conv_out_channels, + ) + + return config + + +@require_torch +class SLANetModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase): + all_model_classes = (SLANetForTableRecognition,) if is_torch_available() else () + pipeline_model_mapping = {"image-feature-extraction": SLANetForTableRecognition} if is_torch_available() else {} + + has_attentions = False + test_resize_embeddings = False + test_torch_exportable = False + + def setUp(self): + self.model_tester = SLANetModelTester( + self, + batch_size=1, + image_size=488, + ) + self.config_tester = ConfigTester( + self, + config_class=SLANetConfig, + has_text_modality=False, + common_properties=[], + ) + + def test_config(self): + self.config_tester.run_common_tests() + + @unittest.skip(reason="SLANet does not use 
inputs_embeds") + def test_enable_input_require_grads(self): + pass + + @unittest.skip(reason="SLANet does not use inputs_embeds") + def test_inputs_embeds(self): + pass + + @unittest.skip(reason="SLANet does not use test_inputs_embeds_matches_input_ids") + def test_inputs_embeds_matches_input_ids(self): + pass + + @unittest.skip(reason="SLANet does not support input and output embeddings") + def test_model_get_set_embeddings(self): + pass + + def test_forward_signature(self): + config, _ = self.model_tester.prepare_config_and_inputs_for_common() + + for model_class in self.all_model_classes: + model = model_class(config) + signature = inspect.signature(model.forward) + arg_names = [*signature.parameters.keys()] + expected_arg_names = ["pixel_values"] + self.assertListEqual(arg_names[:1], expected_arg_names) + + def test_hidden_states_output(self): + def check_hidden_states_output(inputs_dict, config, model_class): + model = model_class(config) + model.to(torch_device) + model.eval() + + with torch.no_grad(): + outputs = model(**self._prepare_for_class(inputs_dict, model_class)) + + hidden_states = outputs.hidden_states + expected_num_stages = self.model_tester.num_stages + + self.assertEqual(len(hidden_states), expected_num_stages + 1) + + config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() + for model_class in self.all_model_classes: + inputs_dict["output_hidden_states"] = True + check_hidden_states_output(inputs_dict.copy(), config, model_class) + + # Check that output_hidden_states also works via config (including backbone subconfig) + del inputs_dict["output_hidden_states"] + config.output_hidden_states = True + if config.backbone_config is not None: + config.backbone_config.output_hidden_states = True + check_hidden_states_output(inputs_dict.copy(), config, model_class) + + +@require_torch +@require_vision +@slow +class SLANetModelIntegrationTest(unittest.TestCase): + def setUp(self): + model_path = 
"PaddlePaddle/SLANet_plus_safetensors" + self.model = AutoModelForTableRecognition.from_pretrained(model_path, dtype=torch.float32).to(torch_device) + self.image_processor = AutoImageProcessor.from_pretrained(model_path) + url = "https://paddle-model-ecology.bj.bcebos.com/paddlex/imgs/demo_image/table_recognition.jpg" + self.image = Image.open(requests.get(url, stream=True).raw) + + def test_inference_table_recognition_head(self): + inputs = self.image_processor(images=self.image, return_tensors="pt").to(torch_device) + + with torch.no_grad(): + outputs = self.model(**inputs) + + pred_table_structure = self.image_processor.post_process_table_recognition(outputs)["structure"] + expected_table_structure = [ + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "
", + "", + "", + ] + + self.assertEqual(pred_table_structure, expected_table_structure) diff --git a/utils/check_repo.py b/utils/check_repo.py index 0816e834c64b..89b1e9e2dc10 100644 --- a/utils/check_repo.py +++ b/utils/check_repo.py @@ -194,6 +194,8 @@ "SeamlessM4TCodeHifiGan", # Building part of bigger (tested) model. "SeamlessM4TTextToUnitForConditionalGeneration", # Building part of bigger (tested) model. "ChameleonVQVAE", # VQVAE here is used only for encoding (discretizing) and is tested as part of bigger model + "SLANetSLAHead", # Buildinere is used only for encoding (discretizing) and is tested as part of bigger model + "SLANetModel", # Buildinere is used only for encoding (discretizing) and is tested as part of bigger model "SLANeXtSLAHead", # Buildinere is used only for encoding (discretizing) and is tested as part of bigger model "SLANeXtBackbone", # Building part of bigger (tested) model. Tested implicitly through SLANeXtForTableRecognition. "PPOCRV5MobileDetModel", # Building part of bigger (tested) model. Tested implicitly through PPOCRV5MobileDetForObjectDetection. 
@@ -460,6 +462,8 @@ "Emu3TextModel", # Building part of bigger (tested) model "JanusVQVAE", # no autoclass for VQ-VAE models "JanusVisionModel", # Building part of bigger (tested) model + "SLANetSLAHead", # Building part of bigger (tested) model + "SLANetModel", # Building part of bigger (tested) model "SLANeXtSLAHead", # Building part of bigger (tested) model "SLANeXtBackbone", # Building part of bigger (tested) model "PPOCRV5MobileDetModel", # Building part of bigger (tested) model From e9ca4e84364508188236e8f3ac05b82341b4ff26 Mon Sep 17 00:00:00 2001 From: zhangyue66 Date: Mon, 20 Apr 2026 21:12:34 +0800 Subject: [PATCH 2/6] add model_doc --- docs/source/en/model_doc/slanet.md | 78 ++++++++++++++++++++++++++++++ 1 file changed, 78 insertions(+) create mode 100644 docs/source/en/model_doc/slanet.md diff --git a/docs/source/en/model_doc/slanet.md b/docs/source/en/model_doc/slanet.md new file mode 100644 index 000000000000..90a8e99845e5 --- /dev/null +++ b/docs/source/en/model_doc/slanet.md @@ -0,0 +1,78 @@ + +*This model was released on 2025-03-07 and added to Hugging Face Transformers on 2026-04-20.* + +# SLANet + +
+PyTorch +
+ +## Overview + +**SLANet** and **SLANet_plus** are part of a series of dedicated lightweight models for table structure recognition, focusing on accurately recognizing table structures in documents and natural scenes. For more details about the SLANet series model, please refer to the [official documentation](https://www.paddleocr.ai/latest/en/version3.x/module_usage/table_structure_recognition.html). + +## Model Architecture + +SLANet is a table structure recognition model developed by Baidu PaddlePaddle Vision Team. The model significantly improves the accuracy and inference speed of table structure recognition by adopting a CPU-friendly lightweight backbone network PP-LCNet, a high-low-level feature fusion module CSP-PAN, and a feature decoding module SLA Head that aligns structural and positional information. + +## Usage + +### Single input inference + +The example below demonstrates how to detect text with PP-OCRV5_Mobile_Det using the [`AutoModel`]. + + + + +```py +import requests +from PIL import Image +from transformers import AutoImageProcessor, AutoModelForTableRecognition + +model_path="PaddlePaddle/SLANet_plus_safetensors" +model = AutoModelForTableRecognition.from_pretrained(model_path, device_map="auto") +image_processor = AutoImageProcessor.from_pretrained(model_path) + +image = Image.open(requests.get("https://paddle-model-ecology.bj.bcebos.com/paddlex/imgs/demo_image/table_recognition.jpg", stream=True).raw) +inputs = image_processor(images=image, return_tensors="pt").to(model.device) +outputs = model(**inputs) + +results = image_processor.post_process_table_recognition(outputs) + +print(results['structure']) +print(results['structure_score']) +``` + + + + +## SLANetConfig + +[[autodoc]] SLANetConfig + +## SLANetForTableRecognition + +[[autodoc]] SLANetForTableRecognition + +## SLANetModel + +[[autodoc]] SLANetModel + +## SLANetSLAHead + +[[autodoc]] SLANetSLAHead + From 9b8a9fd927cd2bec5ac4d6c34f5cceae66fa0414 Mon Sep 17 00:00:00 2001 From: 
zhangyue66 Date: Tue, 21 Apr 2026 14:27:46 +0800 Subject: [PATCH 3/6] fix --- docs/source/en/model_doc/slanet.md | 10 +- src/transformers/models/auto/auto_mappings.py | 1 - .../models/slanet/configuration_slanet.py | 8 +- .../models/slanet/modeling_slanet.py | 158 ++++++++------- .../models/slanet/modular_slanet.py | 182 +++++++----------- tests/models/slanet/test_modeling_slanet.py | 15 +- utils/fetch_hub_objects_for_ci.py | 1 + 7 files changed, 170 insertions(+), 205 deletions(-) diff --git a/docs/source/en/model_doc/slanet.md b/docs/source/en/model_doc/slanet.md index 90a8e99845e5..6ea542dd1412 100644 --- a/docs/source/en/model_doc/slanet.md +++ b/docs/source/en/model_doc/slanet.md @@ -13,7 +13,7 @@ specific language governing permissions and limitations under the License. rendered properly in your Markdown viewer. --> -*This model was released on 2025-03-07 and added to Hugging Face Transformers on 2026-04-20.* +*This model was released on 2025-03-07 and added to Hugging Face Transformers on 2026-04-21.* # SLANet @@ -33,13 +33,15 @@ SLANet is a table structure recognition model developed by Baidu PaddlePaddle Vi ### Single input inference -The example below demonstrates how to detect text with PP-OCRV5_Mobile_Det using the [`AutoModel`]. +The example below demonstrates how to recognize table structures with SLANet using the [`AutoModel`]. 
```py -import requests +from io import BytesIO + +import httpx from PIL import Image from transformers import AutoImageProcessor, AutoModelForTableRecognition @@ -47,7 +49,7 @@ model_path="PaddlePaddle/SLANet_plus_safetensors" model = AutoModelForTableRecognition.from_pretrained(model_path, device_map="auto") image_processor = AutoImageProcessor.from_pretrained(model_path) -image = Image.open(requests.get("https://paddle-model-ecology.bj.bcebos.com/paddlex/imgs/demo_image/table_recognition.jpg", stream=True).raw) +image = Image.open(BytesIO(httpx.get("https://paddle-model-ecology.bj.bcebos.com/paddlex/imgs/demo_image/table_recognition.jpg").content)) inputs = image_processor(images=image, return_tensors="pt").to(model.device) outputs = model(**inputs) diff --git a/src/transformers/models/auto/auto_mappings.py b/src/transformers/models/auto/auto_mappings.py index cf04cb2b91d7..d3d3a501567f 100644 --- a/src/transformers/models/auto/auto_mappings.py +++ b/src/transformers/models/auto/auto_mappings.py @@ -934,7 +934,6 @@ ("seggpt", {"pil": "SegGptImageProcessorPil", "torchvision": "SegGptImageProcessor"}), ("siglip", {"pil": "SiglipImageProcessorPil", "torchvision": "SiglipImageProcessor"}), ("siglip2", {"pil": "Siglip2ImageProcessorPil", "torchvision": "Siglip2ImageProcessor"}), - ("slanet", {"torchvision": "SLANeXtImageProcessor"}), ("slanext", {"torchvision": "SLANeXtImageProcessor"}), ("smolvlm", {"pil": "SmolVLMImageProcessorPil", "torchvision": "SmolVLMImageProcessor"}), ("superglue", {"pil": "SuperGlueImageProcessorPil", "torchvision": "SuperGlueImageProcessor"}), diff --git a/src/transformers/models/slanet/configuration_slanet.py b/src/transformers/models/slanet/configuration_slanet.py index 815615684947..45c78c25022a 100644 --- a/src/transformers/models/slanet/configuration_slanet.py +++ b/src/transformers/models/slanet/configuration_slanet.py @@ -41,9 +41,9 @@ class SLANetConfig(PreTrainedConfig): max_text_length (`int`, *optional*, defaults to 500): Maximum number of autoregressive decoding steps (tokens) for the structure and location
decoder. csp_kernel_size (`int`, *optional*, defaults to 5): - The kernel size of the CSP layer. - csp_blocks_num (`int`, *optional*, defaults to 1): - Number of the CSP layer. + The kernel size of the Cross Stage Partial (CSP) layer. + csp_num_blocks (`int`, *optional*, defaults to 1): + Number of the Cross Stage Partial (CSP) layer. """ model_type = "slanet" @@ -57,7 +57,7 @@ class SLANetConfig(PreTrainedConfig): hidden_act: str = "hardswish" csp_kernel_size: int = 5 - csp_blocks_num: int = 1 + csp_num_blocks: int = 1 def __post_init__(self, **kwargs): self.backbone_config, kwargs = consolidate_backbone_kwargs_to_config( diff --git a/src/transformers/models/slanet/modeling_slanet.py b/src/transformers/models/slanet/modeling_slanet.py index a02978ffbf0d..156c9cd0294c 100644 --- a/src/transformers/models/slanet/modeling_slanet.py +++ b/src/transformers/models/slanet/modeling_slanet.py @@ -30,6 +30,7 @@ from ... import initialization as init from ...activations import ACT2CLS, ACT2FN from ...backbone_utils import filter_output_hidden_states, load_backbone +from ...modeling_layers import GradientCheckpointingLayer from ...modeling_outputs import BaseModelOutput, BaseModelOutputWithNoAttention from ...modeling_utils import PreTrainedModel from ...processing_utils import Unpack @@ -41,10 +42,10 @@ class SLANetPreTrainedModel(PreTrainedModel): config: SLANetConfig - base_model_prefix = "slanet" + base_model_prefix = "backbone" main_input_name = "pixel_values" input_modalities = ("image",) - supports_gradient_checkpointing = False + supports_gradient_checkpointing = True _keep_in_fp32_modules_strict = [] @torch.no_grad() @@ -217,7 +218,7 @@ def forward(self, input: Tensor) -> Tensor: return hidden_state -class SLANetDepthwiseSeparableConvLayer(nn.Module): +class SLANetDepthwiseSeparableConvLayer(GradientCheckpointingLayer): """ Depthwise Separable Convolution Layer: Depthwise Conv -> Pointwise Conv Core component of lightweight models (e.g., MobileNet, PP-LCNet) that 
significantly reduces @@ -230,7 +231,7 @@ def __init__( out_channels, stride, kernel_size, - activation, + config, ): super().__init__() self.depthwise_convolution = SLANetConvLayer( @@ -239,21 +240,23 @@ def __init__( kernel_size=kernel_size, stride=stride, groups=in_channels, - activation=activation, + activation=config.hidden_act, ) + self.squeeze_excitation_module = nn.Identity() self.pointwise_convolution = SLANetConvLayer( in_channels=in_channels, kernel_size=1, out_channels=out_channels, stride=1, - activation=activation, + activation=config.hidden_act, ) - def forward(self, hidden_states): - hidden_states = self.depthwise_convolution(hidden_states) - hidden_states = self.pointwise_convolution(hidden_states) + def forward(self, hidden_state): + hidden_state = self.depthwise_convolution(hidden_state) + hidden_state = self.squeeze_excitation_module(hidden_state) + hidden_state = self.pointwise_convolution(hidden_state) - return hidden_states + return hidden_state class SLANetBottleneck(nn.Module): @@ -263,6 +266,7 @@ def __init__( out_channels, kernel_size, activation, + config, ): super().__init__() self.conv1 = SLANetConvLayer( @@ -273,10 +277,10 @@ def __init__( out_channels=out_channels, kernel_size=kernel_size, stride=1, - activation=activation, + config=config, ) - def forward(self, hidden_states): + def forward(self, hidden_states: torch.FloatTensor) -> torch.FloatTensor: hidden_states = self.conv1(hidden_states) hidden_states = self.conv2(hidden_states) @@ -285,66 +289,44 @@ def forward(self, hidden_states): class SLANetCSPLayer(nn.Module): """ - Cross Stage Partial (CSP) network layer. + Cross Stage Partial (CSP) network layer. Similar in structure to DFineCSPRepLayer, but with a different forward computation. 
""" def __init__( self, + config, in_channels, out_channels, kernel_size=3, - expand_ratio=0.5, + expansion=0.5, num_blocks=1, activation="hardswish", ): super().__init__() - mid_channels = int(out_channels * expand_ratio) - self.main_conv = SLANetConvLayer(in_channels, mid_channels, 1, activation=activation) - self.short_conv = SLANetConvLayer(in_channels, mid_channels, 1, activation=activation) - self.final_conv = SLANetConvLayer(2 * mid_channels, out_channels, 1, activation=activation) - - self.blocks = nn.ModuleList() - for _ in range(num_blocks): - self.blocks.append( - SLANetBottleneck( - mid_channels, - mid_channels, - kernel_size, - activation, - ) - ) + hidden_channels = int(out_channels * expansion) + self.conv1 = SLANetConvLayer(in_channels, hidden_channels, 1, activation=activation) + self.conv2 = SLANetConvLayer(in_channels, hidden_channels, 1, activation=activation) + self.conv3 = SLANetConvLayer(2 * hidden_channels, out_channels, 1, activation=activation) + self.bottlenecks = nn.ModuleList( + [ + SLANetBottleneck(hidden_channels, hidden_channels, kernel_size, activation, config) + for _ in range(num_blocks) + ] + ) - def forward(self, hidden_states): - hidden_states_short = self.short_conv(hidden_states) + def forward(self, hidden_states: torch.FloatTensor) -> torch.FloatTensor: + residual = self.conv1(hidden_states) - hidden_states_main = self.main_conv(hidden_states) - for block in self.blocks: - hidden_states_main = block(hidden_states_main) + hidden_states = self.conv2(hidden_states) + for bottleneck in self.bottlenecks: + hidden_states = bottleneck(hidden_states) - hidden_states = torch.cat((hidden_states_main, hidden_states_short), dim=1) - hidden_states = self.final_conv(hidden_states) + hidden_states = torch.cat((hidden_states, residual), dim=1) + hidden_states = self.conv3(hidden_states) return hidden_states -class SLANetChannelProjector(nn.Module): - def __init__(self, in_channel_list, out_channels, activation): - super().__init__() - 
self.layers = nn.ModuleList() - for i in range(len(in_channel_list)): - self.layers.append( - SLANetConvLayer( - in_channels=in_channel_list[i], out_channels=out_channels, kernel_size=1, activation=activation - ) - ) - - def forward(self, hidden_states): - projected_features = [] - for idx in range(len(self.layers)): - projected_features.append(self.layers[idx](hidden_states[idx])) - return projected_features - - class SLANetCSPPAN(nn.Module): """ CSP-PAN: Path Aggregation Network with CSP layers @@ -359,49 +341,64 @@ def __init__( out_channels = config.post_conv_out_channels activation = config.hidden_act kernel_size = config.csp_kernel_size - csp_blocks_num = config.csp_blocks_num + csp_num_blocks = config.csp_num_blocks - self.channel_projector = SLANetChannelProjector(in_channel_list, out_channels, activation) + self.channel_projector = nn.ModuleList( + [ + SLANetConvLayer( + in_channels=in_channel_list[i], out_channels=out_channels, kernel_size=1, activation=activation + ) + for i in range(len(in_channel_list)) + ] + ) # build top-down blocks self.upsample = nn.Upsample(scale_factor=2, mode="nearest") - self.top_down_blocks = nn.ModuleList() - for _ in range(len(in_channel_list) - 1, 0, -1): - self.top_down_blocks.append( + self.top_down_blocks = nn.ModuleList( + [ SLANetCSPLayer( + config, out_channels * 2, out_channels, kernel_size=kernel_size, - num_blocks=csp_blocks_num, + num_blocks=csp_num_blocks, activation=activation, ) - ) + for _ in range(len(in_channel_list) - 1, 0, -1) + ] + ) # build bottom-up blocks - self.downsamples = nn.ModuleList() - self.bottom_up_blocks = nn.ModuleList() - for _ in range(len(in_channel_list) - 1): - self.downsamples.append( + self.downsamples = nn.ModuleList( + [ SLANetDepthwiseSeparableConvLayer( out_channels, out_channels, kernel_size=kernel_size, stride=2, - activation=activation, + config=config, ) - ) - self.bottom_up_blocks.append( + for _ in range(len(in_channel_list) - 1) + ] + ) + self.bottom_up_blocks = 
nn.ModuleList( + [ SLANetCSPLayer( + config, out_channels * 2, out_channels, kernel_size=kernel_size, - num_blocks=csp_blocks_num, + num_blocks=csp_num_blocks, activation=activation, ) - ) + for _ in range(len(in_channel_list) - 1) + ] + ) - def forward(self, hidden_states): - projected_features = self.channel_projector(hidden_states) + def forward(self, hidden_states: torch.FloatTensor) -> torch.FloatTensor: + projected_features = [] + for idx in range(len(self.channel_projector)): + projected_features.append(self.channel_projector[idx](hidden_states[idx])) top_down_features = [projected_features[-1]] for top_down_block, low_level_feature in zip(self.top_down_blocks, reversed(projected_features[:-1])): @@ -426,11 +423,11 @@ def forward(self, hidden_states): return hidden_states -class SLANetModel(SLANetPreTrainedModel): +class SLANetBackbone(SLANetPreTrainedModel): def __init__(self, config: SLANetConfig): super().__init__(config) - self.backbone = load_backbone(config) - self.neck = SLANetCSPPAN(self.backbone.num_features[2:], config) + self.vision_backbone = load_backbone(config) + self.post_csp_pan = SLANetCSPPAN(self.vision_backbone.num_features[2:], config) self.post_init() @@ -439,8 +436,8 @@ def __init__(self, config: SLANetConfig): def forward( self, hidden_states: torch.FloatTensor, **kwargs: Unpack[TransformersKwargs] ) -> tuple[torch.FloatTensor] | SLANetForTableRecognitionOutput: - outputs = self.backbone(hidden_states, **kwargs) - hidden_states = self.neck(outputs.feature_maps) + outputs = self.vision_backbone(hidden_states, **kwargs) + hidden_states = self.post_csp_pan(outputs.feature_maps) return BaseModelOutputWithNoAttention( last_hidden_state=hidden_states, hidden_states=outputs.hidden_states, @@ -458,7 +455,7 @@ class SLANetForTableRecognition(SLANetPreTrainedModel): def __init__(self, config: SLANetConfig): super().__init__(config) - self.model = SLANetModel(config=config) + self.backbone = SLANetBackbone(config=config) self.head = 
SLANetSLAHead(config=config) self.post_init() @@ -467,8 +464,9 @@ def __init__(self, config: SLANetConfig): def forward( self, pixel_values: torch.FloatTensor, **kwargs: Unpack[TransformersKwargs] ) -> tuple[torch.FloatTensor] | SLANetForTableRecognitionOutput: - outputs = self.model(pixel_values, **kwargs) + outputs = self.backbone(pixel_values, **kwargs) head_outputs = self.head(outputs.last_hidden_state, **kwargs) + # Key difference: no attentions in its vision model return SLANetForTableRecognitionOutput( last_hidden_state=head_outputs.last_hidden_state, hidden_states=outputs.hidden_states, @@ -477,4 +475,4 @@ def forward( ) -__all__ = ["SLANetForTableRecognition", "SLANetPreTrainedModel", "SLANetSLAHead", "SLANetModel"] +__all__ = ["SLANetForTableRecognition", "SLANetPreTrainedModel", "SLANetSLAHead", "SLANetBackbone"] diff --git a/src/transformers/models/slanet/modular_slanet.py b/src/transformers/models/slanet/modular_slanet.py index 6e0c2767fad3..d99ece4d7c6f 100644 --- a/src/transformers/models/slanet/modular_slanet.py +++ b/src/transformers/models/slanet/modular_slanet.py @@ -30,9 +30,10 @@ from ...utils.generic import merge_with_config_defaults from ...utils.output_capturing import capture_outputs from ..auto import AutoConfig -from ..pp_lcnet.modeling_pp_lcnet import PPLCNetConvLayer +from ..pp_lcnet.modeling_pp_lcnet import PPLCNetConvLayer, PPLCNetDepthwiseSeparableConvLayer from ..slanext.configuration_slanext import SLANeXtConfig from ..slanext.modeling_slanext import ( + SLANeXtForTableRecognition, SLANeXtForTableRecognitionOutput, SLANeXtPreTrainedModel, SLANeXtSLAHead, @@ -56,9 +57,9 @@ class SLANetConfig(SLANeXtConfig): max_text_length (`int`, *optional*, defaults to 500): Maximum number of autoregressive decoding steps (tokens) for the structure and location decoder. csp_kernel_size (`int`, *optional*, defaults to 5): - The kernel size of the CSP layer. - csp_blocks_num (`int`, *optional*, defaults to 1): - Number of the CSP layer. 
+ The kernel size of the Cross Stage Partial (CSP) layer. + csp_num_blocks (`int`, *optional*, defaults to 1): + Number of the Cross Stage Partial (CSP) layer. """ sub_configs = {"backbone_config": AutoConfig} @@ -68,13 +69,11 @@ class SLANetConfig(SLANeXtConfig): post_conv_in_channels = AttributeError() post_conv_out_channels: int = 96 - out_channels: int = 50 hidden_size: int = 256 - max_text_length: int = 500 hidden_act: str = "hardswish" csp_kernel_size: int = 5 - csp_blocks_num: int = 1 + csp_num_blocks: int = 1 def __post_init__(self, **kwargs): self.backbone_config, kwargs = consolidate_backbone_kwargs_to_config( @@ -92,8 +91,6 @@ def __post_init__(self, **kwargs): class SLANetPreTrainedModel(SLANeXtPreTrainedModel): - base_model_prefix = "slanet" - supports_gradient_checkpointing = False _keep_in_fp32_modules_strict = [] @torch.no_grad() @@ -135,7 +132,7 @@ class SLANetConvLayer(PPLCNetConvLayer): pass -class SLANetDepthwiseSeparableConvLayer(nn.Module): +class SLANetDepthwiseSeparableConvLayer(PPLCNetDepthwiseSeparableConvLayer): """ Depthwise Separable Convolution Layer: Depthwise Conv -> Pointwise Conv Core component of lightweight models (e.g., MobileNet, PP-LCNet) that significantly reduces @@ -148,30 +145,10 @@ def __init__( out_channels, stride, kernel_size, - activation, + config, ): super().__init__() - self.depthwise_convolution = SLANetConvLayer( - in_channels=in_channels, - out_channels=in_channels, - kernel_size=kernel_size, - stride=stride, - groups=in_channels, - activation=activation, - ) - self.pointwise_convolution = SLANetConvLayer( - in_channels=in_channels, - kernel_size=1, - out_channels=out_channels, - stride=1, - activation=activation, - ) - - def forward(self, hidden_states): - hidden_states = self.depthwise_convolution(hidden_states) - hidden_states = self.pointwise_convolution(hidden_states) - - return hidden_states + self.squeeze_excitation_module = nn.Identity() class SLANetBottleneck(nn.Module): @@ -181,6 +158,7 @@ def 
__init__( out_channels, kernel_size, activation, + config, ): super().__init__() self.conv1 = SLANetConvLayer( @@ -191,10 +169,10 @@ def __init__( out_channels=out_channels, kernel_size=kernel_size, stride=1, - activation=activation, + config=config, ) - def forward(self, hidden_states): + def forward(self, hidden_states: torch.FloatTensor) -> torch.FloatTensor: hidden_states = self.conv1(hidden_states) hidden_states = self.conv2(hidden_states) @@ -203,66 +181,44 @@ def forward(self, hidden_states): class SLANetCSPLayer(nn.Module): """ - Cross Stage Partial (CSP) network layer. + Cross Stage Partial (CSP) network layer. Similar in structure to DFineCSPRepLayer, but with a different forward computation. """ def __init__( self, + config, in_channels, out_channels, kernel_size=3, - expand_ratio=0.5, + expansion=0.5, num_blocks=1, activation="hardswish", ): super().__init__() - mid_channels = int(out_channels * expand_ratio) - self.main_conv = SLANetConvLayer(in_channels, mid_channels, 1, activation=activation) - self.short_conv = SLANetConvLayer(in_channels, mid_channels, 1, activation=activation) - self.final_conv = SLANetConvLayer(2 * mid_channels, out_channels, 1, activation=activation) - - self.blocks = nn.ModuleList() - for _ in range(num_blocks): - self.blocks.append( - SLANetBottleneck( - mid_channels, - mid_channels, - kernel_size, - activation, - ) - ) + hidden_channels = int(out_channels * expansion) + self.conv1 = SLANetConvLayer(in_channels, hidden_channels, 1, activation=activation) + self.conv2 = SLANetConvLayer(in_channels, hidden_channels, 1, activation=activation) + self.conv3 = SLANetConvLayer(2 * hidden_channels, out_channels, 1, activation=activation) + self.bottlenecks = nn.ModuleList( + [ + SLANetBottleneck(hidden_channels, hidden_channels, kernel_size, activation, config) + for _ in range(num_blocks) + ] + ) - def forward(self, hidden_states): - hidden_states_short = self.short_conv(hidden_states) + def forward(self, hidden_states: 
torch.FloatTensor) -> torch.FloatTensor: + residual = self.conv1(hidden_states) - hidden_states_main = self.main_conv(hidden_states) - for block in self.blocks: - hidden_states_main = block(hidden_states_main) + hidden_states = self.conv2(hidden_states) + for bottleneck in self.bottlenecks: + hidden_states = bottleneck(hidden_states) - hidden_states = torch.cat((hidden_states_main, hidden_states_short), dim=1) - hidden_states = self.final_conv(hidden_states) + hidden_states = torch.cat((hidden_states, residual), dim=1) + hidden_states = self.conv3(hidden_states) return hidden_states -class SLANetChannelProjector(nn.Module): - def __init__(self, in_channel_list, out_channels, activation): - super().__init__() - self.layers = nn.ModuleList() - for i in range(len(in_channel_list)): - self.layers.append( - SLANetConvLayer( - in_channels=in_channel_list[i], out_channels=out_channels, kernel_size=1, activation=activation - ) - ) - - def forward(self, hidden_states): - projected_features = [] - for idx in range(len(self.layers)): - projected_features.append(self.layers[idx](hidden_states[idx])) - return projected_features - - class SLANetCSPPAN(nn.Module): """ CSP-PAN: Path Aggregation Network with CSP layers @@ -277,49 +233,64 @@ def __init__( out_channels = config.post_conv_out_channels activation = config.hidden_act kernel_size = config.csp_kernel_size - csp_blocks_num = config.csp_blocks_num + csp_num_blocks = config.csp_num_blocks - self.channel_projector = SLANetChannelProjector(in_channel_list, out_channels, activation) + self.channel_projector = nn.ModuleList( + [ + SLANetConvLayer( + in_channels=in_channel_list[i], out_channels=out_channels, kernel_size=1, activation=activation + ) + for i in range(len(in_channel_list)) + ] + ) # build top-down blocks self.upsample = nn.Upsample(scale_factor=2, mode="nearest") - self.top_down_blocks = nn.ModuleList() - for _ in range(len(in_channel_list) - 1, 0, -1): - self.top_down_blocks.append( + self.top_down_blocks = 
nn.ModuleList( + [ SLANetCSPLayer( + config, out_channels * 2, out_channels, kernel_size=kernel_size, - num_blocks=csp_blocks_num, + num_blocks=csp_num_blocks, activation=activation, ) - ) + for _ in range(len(in_channel_list) - 1, 0, -1) + ] + ) # build bottom-up blocks - self.downsamples = nn.ModuleList() - self.bottom_up_blocks = nn.ModuleList() - for _ in range(len(in_channel_list) - 1): - self.downsamples.append( + self.downsamples = nn.ModuleList( + [ SLANetDepthwiseSeparableConvLayer( out_channels, out_channels, kernel_size=kernel_size, stride=2, - activation=activation, + config=config, ) - ) - self.bottom_up_blocks.append( + for _ in range(len(in_channel_list) - 1) + ] + ) + self.bottom_up_blocks = nn.ModuleList( + [ SLANetCSPLayer( + config, out_channels * 2, out_channels, kernel_size=kernel_size, - num_blocks=csp_blocks_num, + num_blocks=csp_num_blocks, activation=activation, ) - ) + for _ in range(len(in_channel_list) - 1) + ] + ) - def forward(self, hidden_states): - projected_features = self.channel_projector(hidden_states) + def forward(self, hidden_states: torch.FloatTensor) -> torch.FloatTensor: + projected_features = [] + for idx in range(len(self.channel_projector)): + projected_features.append(self.channel_projector[idx](hidden_states[idx])) top_down_features = [projected_features[-1]] for top_down_block, low_level_feature in zip(self.top_down_blocks, reversed(projected_features[:-1])): @@ -344,11 +315,11 @@ def forward(self, hidden_states): return hidden_states -class SLANetModel(SLANetPreTrainedModel): +class SLANetBackbone(SLANetPreTrainedModel): def __init__(self, config: SLANetConfig): super().__init__(config) - self.backbone = load_backbone(config) - self.neck = SLANetCSPPAN(self.backbone.num_features[2:], config) + self.vision_backbone = load_backbone(config) + self.post_csp_pan = SLANetCSPPAN(self.vision_backbone.num_features[2:], config) self.post_init() @@ -357,8 +328,8 @@ def __init__(self, config: SLANetConfig): def forward( self, 
hidden_states: torch.FloatTensor, **kwargs: Unpack[TransformersKwargs] ) -> tuple[torch.FloatTensor] | SLANetForTableRecognitionOutput: - outputs = self.backbone(hidden_states, **kwargs) - hidden_states = self.neck(outputs.feature_maps) + outputs = self.vision_backbone(hidden_states, **kwargs) + hidden_states = self.post_csp_pan(outputs.feature_maps) return BaseModelOutputWithNoAttention( last_hidden_state=hidden_states, hidden_states=outputs.hidden_states, @@ -371,22 +342,17 @@ def forward( and returns outputs compatible with the Transformers table recognition API. """ ) -class SLANetForTableRecognition(SLANetPreTrainedModel): +class SLANetForTableRecognition(SLANeXtForTableRecognition): _keys_to_ignore_on_load_missing = ["num_batches_tracked"] - def __init__(self, config: SLANetConfig): - super().__init__(config) - self.model = SLANetModel(config=config) - self.head = SLANetSLAHead(config=config) - self.post_init() - @can_return_tuple @auto_docstring def forward( self, pixel_values: torch.FloatTensor, **kwargs: Unpack[TransformersKwargs] ) -> tuple[torch.FloatTensor] | SLANetForTableRecognitionOutput: - outputs = self.model(pixel_values, **kwargs) + outputs = self.backbone(pixel_values, **kwargs) head_outputs = self.head(outputs.last_hidden_state, **kwargs) + # Key difference: no attentions in its vision model return SLANetForTableRecognitionOutput( last_hidden_state=head_outputs.last_hidden_state, hidden_states=outputs.hidden_states, @@ -395,4 +361,4 @@ def forward( ) -__all__ = ["SLANetConfig", "SLANetForTableRecognition", "SLANetPreTrainedModel", "SLANetSLAHead", "SLANetModel"] +__all__ = ["SLANetConfig", "SLANetForTableRecognition", "SLANetPreTrainedModel", "SLANetSLAHead", "SLANetBackbone"] diff --git a/tests/models/slanet/test_modeling_slanet.py b/tests/models/slanet/test_modeling_slanet.py index cbb7c2eb3ac6..b92fa56428c1 100644 --- a/tests/models/slanet/test_modeling_slanet.py +++ b/tests/models/slanet/test_modeling_slanet.py @@ -17,16 +17,14 @@ import 
inspect import unittest -import requests - from transformers import ( AutoImageProcessor, AutoModelForTableRecognition, SLANetConfig, SLANetForTableRecognition, is_torch_available, - is_vision_available, ) +from transformers.image_utils import load_image from transformers.testing_utils import ( require_torch, require_vision, @@ -37,14 +35,12 @@ from ...test_configuration_common import ConfigTester from ...test_modeling_common import ModelTesterMixin, floats_tensor from ...test_pipeline_mixin import PipelineTesterMixin +from ...test_processing_common import url_to_local_path if is_torch_available(): import torch -if is_vision_available(): - from PIL import Image - class SLANetModelTester: def __init__( @@ -165,6 +161,7 @@ def test_forward_signature(self): expected_arg_names = ["pixel_values"] self.assertListEqual(arg_names[:1], expected_arg_names) + # SLANet have no seq_length def test_hidden_states_output(self): def check_hidden_states_output(inputs_dict, config, model_class): model = model_class(config) @@ -200,8 +197,10 @@ def setUp(self): model_path = "PaddlePaddle/SLANet_plus_safetensors" self.model = AutoModelForTableRecognition.from_pretrained(model_path, dtype=torch.float32).to(torch_device) self.image_processor = AutoImageProcessor.from_pretrained(model_path) - url = "https://paddle-model-ecology.bj.bcebos.com/paddlex/imgs/demo_image/table_recognition.jpg" - self.image = Image.open(requests.get(url, stream=True).raw) + img_url = url_to_local_path( + "https://paddle-model-ecology.bj.bcebos.com/paddlex/imgs/demo_image/table_recognition.jpg" + ) + self.image = load_image(img_url) def test_inference_table_recognition_head(self): inputs = self.image_processor(images=self.image, return_tensors="pt").to(torch_device) diff --git a/utils/fetch_hub_objects_for_ci.py b/utils/fetch_hub_objects_for_ci.py index 3d229637df70..3b3609d21ce3 100644 --- a/utils/fetch_hub_objects_for_ci.py +++ b/utils/fetch_hub_objects_for_ci.py @@ -39,6 +39,7 @@ URLS_FOR_TESTING_DATA = [ # 
TODO: copy those to our hf-internal-testing dataset and fix all tests using them + "https://paddle-model-ecology.bj.bcebos.com/paddlex/imgs/demo_image/table_recognition.jpg", "http://images.cocodataset.org/val2017/000000000139.jpg", "http://images.cocodataset.org/val2017/000000000285.jpg", "http://images.cocodataset.org/val2017/000000000632.jpg", From 3cd985bf0fa21c72870364c42512ae3bbc1672fc Mon Sep 17 00:00:00 2001 From: zhangyue66 Date: Tue, 21 Apr 2026 14:34:53 +0800 Subject: [PATCH 4/6] fix ci --- docs/source/en/model_doc/slanet.md | 4 ++-- utils/check_repo.py | 4 ++-- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/docs/source/en/model_doc/slanet.md b/docs/source/en/model_doc/slanet.md index 6ea542dd1412..92c01b9fb2ce 100644 --- a/docs/source/en/model_doc/slanet.md +++ b/docs/source/en/model_doc/slanet.md @@ -70,9 +70,9 @@ print(result['structure_score']) [[autodoc]] SLANetForTableRecognition -## SLANetModel +## SLANetBackbone -[[autodoc]] SLANetModel +[[autodoc]] SLANetBackbone ## SLANetSLAHead diff --git a/utils/check_repo.py b/utils/check_repo.py index 89b1e9e2dc10..5a5e4cea1c74 100644 --- a/utils/check_repo.py +++ b/utils/check_repo.py @@ -195,7 +195,7 @@ "SeamlessM4TTextToUnitForConditionalGeneration", # Building part of bigger (tested) model. "ChameleonVQVAE", # VQVAE here is used only for encoding (discretizing) and is tested as part of bigger model "SLANetSLAHead", # Buildinere is used only for encoding (discretizing) and is tested as part of bigger model - "SLANetModel", # Buildinere is used only for encoding (discretizing) and is tested as part of bigger model + "SLANetBackbone", # Building part of bigger (tested) model. Tested implicitly through SLANetForTableRecognition. "SLANeXtSLAHead", # Buildinere is used only for encoding (discretizing) and is tested as part of bigger model "SLANeXtBackbone", # Building part of bigger (tested) model. Tested implicitly through SLANeXtForTableRecognition.
"PPOCRV5MobileDetModel", # Building part of bigger (tested) model. Tested implicitly through PPOCRV5MobileDetForObjectDetection. @@ -463,7 +463,7 @@ "JanusVQVAE", # no autoclass for VQ-VAE models "JanusVisionModel", # Building part of bigger (tested) model "SLANetSLAHead", # Building part of bigger (tested) model - "SLANetModel", # Building part of bigger (tested) model + "SLANetBackbone", # Building part of bigger (tested) model "SLANeXtSLAHead", # Building part of bigger (tested) model "SLANeXtBackbone", # Building part of bigger (tested) model "PPOCRV5MobileDetModel", # Building part of bigger (tested) model From c63bdae5539f9c6a34762ad7740a5cec533fb6b3 Mon Sep 17 00:00:00 2001 From: zhangyue66 Date: Wed, 22 Apr 2026 11:15:04 +0800 Subject: [PATCH 5/6] update --- .../models/auto/image_processing_auto.py | 1 + .../models/slanet/configuration_slanet.py | 2 +- .../models/slanet/modeling_slanet.py | 12 ++++---- .../models/slanet/modular_slanet.py | 29 ++++++++++++------- 4 files changed, 26 insertions(+), 18 deletions(-) diff --git a/src/transformers/models/auto/image_processing_auto.py b/src/transformers/models/auto/image_processing_auto.py index 14db6e495478..0251a8b7c917 100644 --- a/src/transformers/models/auto/image_processing_auto.py +++ b/src/transformers/models/auto/image_processing_auto.py @@ -115,6 +115,7 @@ ("pixio", {"torchvision": "BitImageProcessor", "pil": "BitImageProcessorPil"}), ("pp_ocrv5_mobile_det", {"torchvision": "PPOCRV5ServerDetImageProcessor"}), ("pp_ocrv5_mobile_rec", {"torchvision": "PPOCRV5ServerRecImageProcessor"}), + ("slanet", {"torchvision": "SLANeXtImageProcessor"}), ("pvt_v2", {"torchvision": "PvtImageProcessor", "pil": "PvtImageProcessorPil"}), ("qianfan_ocr", {"torchvision": "GotOcr2ImageProcessor", "pil": "GotOcr2ImageProcessorPil"}), ("qwen2_5_omni", {"torchvision": "Qwen2VLImageProcessor", "pil": "Qwen2VLImageProcessorPil"}), diff --git a/src/transformers/models/slanet/configuration_slanet.py 
b/src/transformers/models/slanet/configuration_slanet.py index 45c78c25022a..74e27df7fbc2 100644 --- a/src/transformers/models/slanet/configuration_slanet.py +++ b/src/transformers/models/slanet/configuration_slanet.py @@ -43,7 +43,7 @@ class SLANetConfig(PreTrainedConfig): csp_kernel_size (`int`, *optional*, defaults to 5): The kernel size of the Cross Stage Partial (CSP) layer. csp_num_blocks (`int`, *optional*, defaults to 1): - Number of the Cross Stage Partial (CSP) layer. + Number of blocks within the Cross Stage Partial (CSP) layer. """ model_type = "slanet" diff --git a/src/transformers/models/slanet/modeling_slanet.py b/src/transformers/models/slanet/modeling_slanet.py index 156c9cd0294c..8ca95ad53d05 100644 --- a/src/transformers/models/slanet/modeling_slanet.py +++ b/src/transformers/models/slanet/modeling_slanet.py @@ -77,7 +77,7 @@ def _init_weights(self, module): @dataclass @auto_docstring -class SLANetForTableRecognitionOutput(BaseModelOutput): +class SLANetForTableRecognitionOutput(BaseModelOutputWithNoAttention): r""" head_hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): Hidden-states of the SLANetSLAHead at each prediction step, varies up to max `self.config.max_text_length` states (depending on early exits). 
@@ -334,8 +334,8 @@ class SLANetCSPPAN(nn.Module): def __init__( self, - in_channel_list, config, + in_channel_list, ): super().__init__() out_channels = config.post_conv_out_channels @@ -427,15 +427,15 @@ class SLANetBackbone(SLANetPreTrainedModel): def __init__(self, config: SLANetConfig): super().__init__(config) self.vision_backbone = load_backbone(config) - self.post_csp_pan = SLANetCSPPAN(self.vision_backbone.num_features[2:], config) + self.post_csp_pan = SLANetCSPPAN(config, self.vision_backbone.num_features[2:]) self.post_init() - @merge_with_config_defaults - @capture_outputs + @can_return_tuple + @auto_docstring def forward( self, hidden_states: torch.FloatTensor, **kwargs: Unpack[TransformersKwargs] - ) -> tuple[torch.FloatTensor] | SLANetForTableRecognitionOutput: + ) -> tuple[torch.FloatTensor] | BaseModelOutputWithNoAttention: outputs = self.vision_backbone(hidden_states, **kwargs) hidden_states = self.post_csp_pan(outputs.feature_maps) return BaseModelOutputWithNoAttention( diff --git a/src/transformers/models/slanet/modular_slanet.py b/src/transformers/models/slanet/modular_slanet.py index d99ece4d7c6f..0e9d32e2524e 100644 --- a/src/transformers/models/slanet/modular_slanet.py +++ b/src/transformers/models/slanet/modular_slanet.py @@ -14,6 +14,7 @@ import math +from dataclasses import dataclass import torch import torch.nn as nn @@ -27,14 +28,11 @@ from ...modeling_utils import PreTrainedModel from ...processing_utils import Unpack from ...utils import TransformersKwargs, auto_docstring, can_return_tuple, logging -from ...utils.generic import merge_with_config_defaults -from ...utils.output_capturing import capture_outputs from ..auto import AutoConfig from ..pp_lcnet.modeling_pp_lcnet import PPLCNetConvLayer, PPLCNetDepthwiseSeparableConvLayer from ..slanext.configuration_slanext import SLANeXtConfig from ..slanext.modeling_slanext import ( SLANeXtForTableRecognition, - SLANeXtForTableRecognitionOutput, SLANeXtPreTrainedModel, SLANeXtSLAHead, ) @@ 
-59,7 +57,7 @@ class SLANetConfig(SLANeXtConfig): csp_kernel_size (`int`, *optional*, defaults to 5): The kernel size of the Cross Stage Partial (CSP) layer. csp_num_blocks (`int`, *optional*, defaults to 1): - Number of the Cross Stage Partial (CSP) layer. + Number of blocks within the Cross Stage Partial (CSP) layer. """ sub_configs = {"backbone_config": AutoConfig} @@ -119,9 +117,18 @@ def _init_weights(self, module): if layer.bias is not None: init.uniform_(layer.bias, -std, std) +@dataclass +@auto_docstring +class SLANetForTableRecognitionOutput(BaseModelOutputWithNoAttention): + r""" + head_hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): + Hidden-states of the SLANetSLAHead at each prediction step, varies up to max `self.config.max_text_length` states (depending on early exits). + head_attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): + Attentions of the SLANetSLAHead at each prediction step, varies up to max `self.config.max_text_length` attentions (depending on early exits). 
+ """ -class SLANetForTableRecognitionOutput(SLANeXtForTableRecognitionOutput): - pass + head_hidden_states: torch.FloatTensor | None = None + head_attentions: torch.FloatTensor | None = None class SLANetSLAHead(SLANeXtSLAHead): @@ -226,8 +233,8 @@ class SLANetCSPPAN(nn.Module): def __init__( self, - in_channel_list, config, + in_channel_list, ): super().__init__() out_channels = config.post_conv_out_channels @@ -319,15 +326,15 @@ class SLANetBackbone(SLANetPreTrainedModel): def __init__(self, config: SLANetConfig): super().__init__(config) self.vision_backbone = load_backbone(config) - self.post_csp_pan = SLANetCSPPAN(self.vision_backbone.num_features[2:], config) + self.post_csp_pan = SLANetCSPPAN(config, self.vision_backbone.num_features[2:]) self.post_init() - @merge_with_config_defaults - @capture_outputs + @can_return_tuple + @auto_docstring def forward( self, hidden_states: torch.FloatTensor, **kwargs: Unpack[TransformersKwargs] - ) -> tuple[torch.FloatTensor] | SLANetForTableRecognitionOutput: + ) -> tuple[torch.FloatTensor] | BaseModelOutputWithNoAttention: outputs = self.vision_backbone(hidden_states, **kwargs) hidden_states = self.post_csp_pan(outputs.feature_maps) return BaseModelOutputWithNoAttention( From 682a5b1640031b5b1a75f8b064c539dab4670950 Mon Sep 17 00:00:00 2001 From: zhangyue66 Date: Wed, 22 Apr 2026 11:22:17 +0800 Subject: [PATCH 6/6] update --- docs/source/en/model_doc/slanet.md | 2 +- src/transformers/models/auto/image_processing_auto.py | 2 +- src/transformers/models/slanet/modular_slanet.py | 1 + 3 files changed, 3 insertions(+), 2 deletions(-) diff --git a/docs/source/en/model_doc/slanet.md b/docs/source/en/model_doc/slanet.md index 92c01b9fb2ce..9f1f684e6f1e 100644 --- a/docs/source/en/model_doc/slanet.md +++ b/docs/source/en/model_doc/slanet.md @@ -13,7 +13,7 @@ specific language governing permissions and limitations under the License. rendered properly in your Markdown viewer. 
--> -*This model was released on 2025-03-07 and added to Hugging Face Transformers on 2026-04-21.* +*This model was released on 2025-03-07 and added to Hugging Face Transformers on 2026-04-22.* # SLANet diff --git a/src/transformers/models/auto/image_processing_auto.py b/src/transformers/models/auto/image_processing_auto.py index 0251a8b7c917..c74ee27519ff 100644 --- a/src/transformers/models/auto/image_processing_auto.py +++ b/src/transformers/models/auto/image_processing_auto.py @@ -115,7 +115,6 @@ ("pixio", {"torchvision": "BitImageProcessor", "pil": "BitImageProcessorPil"}), ("pp_ocrv5_mobile_det", {"torchvision": "PPOCRV5ServerDetImageProcessor"}), ("pp_ocrv5_mobile_rec", {"torchvision": "PPOCRV5ServerRecImageProcessor"}), - ("slanet", {"torchvision": "SLANeXtImageProcessor"}), ("pvt_v2", {"torchvision": "PvtImageProcessor", "pil": "PvtImageProcessorPil"}), ("qianfan_ocr", {"torchvision": "GotOcr2ImageProcessor", "pil": "GotOcr2ImageProcessorPil"}), ("qwen2_5_omni", {"torchvision": "Qwen2VLImageProcessor", "pil": "Qwen2VLImageProcessorPil"}), @@ -133,6 +132,7 @@ ("sam3_video", {"torchvision": "Sam3ImageProcessor"}), ("sam_hq", {"torchvision": "SamImageProcessor", "pil": "SamImageProcessorPil"}), ("shieldgemma2", {"torchvision": "Gemma3ImageProcessor", "pil": "Gemma3ImageProcessorPil"}), + ("slanet", {"torchvision": "SLANeXtImageProcessor"}), ("swiftformer", {"torchvision": "ViTImageProcessor", "pil": "ViTImageProcessorPil"}), ("swin", {"torchvision": "ViTImageProcessor", "pil": "ViTImageProcessorPil"}), ("swinv2", {"torchvision": "ViTImageProcessor", "pil": "ViTImageProcessorPil"}), diff --git a/src/transformers/models/slanet/modular_slanet.py b/src/transformers/models/slanet/modular_slanet.py index 0e9d32e2524e..19dfedb2901d 100644 --- a/src/transformers/models/slanet/modular_slanet.py +++ b/src/transformers/models/slanet/modular_slanet.py @@ -117,6 +117,7 @@ def _init_weights(self, module): if layer.bias is not None: init.uniform_(layer.bias, -std, std) + 
@dataclass @auto_docstring class SLANetForTableRecognitionOutput(BaseModelOutputWithNoAttention):