diff --git a/docs/source/en/_toctree.yml b/docs/source/en/_toctree.yml
index aec6b14839cb..01a1dd2089b0 100644
--- a/docs/source/en/_toctree.yml
+++ b/docs/source/en/_toctree.yml
@@ -1365,6 +1365,8 @@
title: SigLIP
- local: model_doc/siglip2
title: SigLIP2
+ - local: model_doc/slanet
+ title: SLANet
- local: model_doc/slanext
title: SLANeXt
- local: model_doc/smollm3
diff --git a/docs/source/en/model_doc/slanet.md b/docs/source/en/model_doc/slanet.md
new file mode 100644
index 000000000000..9f1f684e6f1e
--- /dev/null
+++ b/docs/source/en/model_doc/slanet.md
@@ -0,0 +1,80 @@
+
+*This model was released on 2025-03-07 and added to Hugging Face Transformers on 2026-04-22.*
+
+# SLANet
+
+
+

+
+
+## Overview
+
+**SLANet** and **SLANet_plus** are part of a series of dedicated lightweight models for table structure recognition, focusing on accurately recognizing table structures in documents and natural scenes. For more details about the SLANet series model, please refer to the [official documentation](https://www.paddleocr.ai/latest/en/version3.x/module_usage/table_structure_recognition.html).
+
+## Model Architecture
+
+SLANet is a table structure recognition model developed by Baidu PaddlePaddle Vision Team. The model significantly improves the accuracy and inference speed of table structure recognition by adopting a CPU-friendly lightweight backbone network PP-LCNet, a high-low-level feature fusion module CSP-PAN, and a feature decoding module SLA Head that aligns structural and positional information.
+
+## Usage
+
+### Single input inference
+
+The example below demonstrates how to recognize table structure with SLANet using the [`AutoModelForTableRecognition`] class.
+
+
+
+
+```py
+from io import BytesIO
+
+import httpx
+from PIL import Image
+from transformers import AutoImageProcessor, AutoModelForTableRecognition
+
+model_path = "PaddlePaddle/SLANet_plus_safetensors"
+model = AutoModelForTableRecognition.from_pretrained(model_path, device_map="auto")
+image_processor = AutoImageProcessor.from_pretrained(model_path)
+
+image = Image.open(BytesIO(httpx.get("https://paddleocr.bj.bcebos.com/doc/table/table.jpg").content))
+inputs = image_processor(images=image, return_tensors="pt").to(model.device)
+outputs = model(**inputs)
+
+results = image_processor.post_process_table_recognition(outputs)
+
+print(results['structure'])
+print(results['structure_score'])
+```
+
+
+
+
+## SLANetConfig
+
+[[autodoc]] SLANetConfig
+
+## SLANetForTableRecognition
+
+[[autodoc]] SLANetForTableRecognition
+
+## SLANetBackbone
+
+[[autodoc]] SLANetBackbone
+
+## SLANetSLAHead
+
+[[autodoc]] SLANetSLAHead
+
diff --git a/src/transformers/models/__init__.py b/src/transformers/models/__init__.py
index 8aad0af6c303..b2b991bf7dde 100644
--- a/src/transformers/models/__init__.py
+++ b/src/transformers/models/__init__.py
@@ -383,6 +383,7 @@
from .shieldgemma2 import *
from .siglip import *
from .siglip2 import *
+ from .slanet import *
from .slanext import *
from .smollm3 import *
from .smolvlm import *
diff --git a/src/transformers/models/auto/auto_mappings.py b/src/transformers/models/auto/auto_mappings.py
index 10e376b65956..d3d3a501567f 100644
--- a/src/transformers/models/auto/auto_mappings.py
+++ b/src/transformers/models/auto/auto_mappings.py
@@ -529,6 +529,7 @@
("siglip2_vision_model", "Siglip2VisionConfig"),
("siglip_text_model", "SiglipTextConfig"),
("siglip_vision_model", "SiglipVisionConfig"),
+ ("slanet", "SLANetConfig"),
("slanext", "SLANeXtConfig"),
("smollm3", "SmolLM3Config"),
("smolvlm", "SmolVLMConfig"),
diff --git a/src/transformers/models/auto/image_processing_auto.py b/src/transformers/models/auto/image_processing_auto.py
index 14db6e495478..c74ee27519ff 100644
--- a/src/transformers/models/auto/image_processing_auto.py
+++ b/src/transformers/models/auto/image_processing_auto.py
@@ -132,6 +132,7 @@
("sam3_video", {"torchvision": "Sam3ImageProcessor"}),
("sam_hq", {"torchvision": "SamImageProcessor", "pil": "SamImageProcessorPil"}),
("shieldgemma2", {"torchvision": "Gemma3ImageProcessor", "pil": "Gemma3ImageProcessorPil"}),
+ ("slanet", {"torchvision": "SLANeXtImageProcessor"}),
("swiftformer", {"torchvision": "ViTImageProcessor", "pil": "ViTImageProcessorPil"}),
("swin", {"torchvision": "ViTImageProcessor", "pil": "ViTImageProcessorPil"}),
("swinv2", {"torchvision": "ViTImageProcessor", "pil": "ViTImageProcessorPil"}),
diff --git a/src/transformers/models/auto/modeling_auto.py b/src/transformers/models/auto/modeling_auto.py
index deb1153d335e..e7a3fb0b1599 100644
--- a/src/transformers/models/auto/modeling_auto.py
+++ b/src/transformers/models/auto/modeling_auto.py
@@ -1157,6 +1157,7 @@ class _BaseModelWithGenerate(PreTrainedModel, GenerationMixin):
MODEL_FOR_TABLE_RECOGNITION_MAPPING_NAMES = OrderedDict(
[
+ ("slanet", "SLANetForTableRecognition"),
("slanext", "SLANeXtForTableRecognition"),
]
)
diff --git a/src/transformers/models/slanet/__init__.py b/src/transformers/models/slanet/__init__.py
new file mode 100644
index 000000000000..96d521f4972f
--- /dev/null
+++ b/src/transformers/models/slanet/__init__.py
@@ -0,0 +1,28 @@
+# Copyright 2026 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from typing import TYPE_CHECKING
+
+from ...utils import _LazyModule
+from ...utils.import_utils import define_import_structure
+
+
+if TYPE_CHECKING:
+ from .configuration_slanet import *
+ from .modeling_slanet import *
+else:
+ import sys
+
+ _file = globals()["__file__"]
+ sys.modules[__name__] = _LazyModule(__name__, _file, define_import_structure(_file), module_spec=__spec__)
diff --git a/src/transformers/models/slanet/configuration_slanet.py b/src/transformers/models/slanet/configuration_slanet.py
new file mode 100644
index 000000000000..74e27df7fbc2
--- /dev/null
+++ b/src/transformers/models/slanet/configuration_slanet.py
@@ -0,0 +1,77 @@
+# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
+# This file was automatically generated from src/transformers/models/slanet/modular_slanet.py.
+# Do NOT edit this file manually as any edits will be overwritten by the generation of
+# the file from the modular. If any change should be done, please apply the change to the
+# modular_slanet.py file directly. One of our CI enforces this.
+# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
+# Copyright 2026 The PaddlePaddle Team and The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+from huggingface_hub.dataclasses import strict
+
+from ...backbone_utils import consolidate_backbone_kwargs_to_config
+from ...configuration_utils import PreTrainedConfig
+from ...utils import auto_docstring
+from ..auto import AutoConfig
+
+
+@auto_docstring(checkpoint="PaddlePaddle/SLANet_plus_safetensors")
+@strict
+class SLANetConfig(PreTrainedConfig):
+ r"""
+ post_conv_out_channels (`int`, *optional*, defaults to 96):
+ Number of output channels for the post-encoder convolution layer.
+ out_channels (`int`, *optional*, defaults to 50):
+ Vocabulary size for the table structure token prediction head, i.e., the number of distinct structure
+ tokens the model can predict.
+ hidden_size (`int`, *optional*, defaults to 256):
+ Dimensionality of the hidden states in the attention GRU cell and the structure/location prediction heads.
+ max_text_length (`int`, *optional*, defaults to 500):
+ Maximum number of autoregressive decoding steps (tokens) for the structure and location decoder.
+ csp_kernel_size (`int`, *optional*, defaults to 5):
+ The kernel size of the Cross Stage Partial (CSP) layer.
+ csp_num_blocks (`int`, *optional*, defaults to 1):
+ Number of blocks within the Cross Stage Partial (CSP) layer.
+ """
+
+ model_type = "slanet"
+
+ sub_configs = {"backbone_config": AutoConfig}
+ post_conv_out_channels: int = 96
+ out_channels: int = 50
+ hidden_size: int = 256
+ max_text_length: int = 500
+ backbone_config: dict | PreTrainedConfig | None = None
+
+ hidden_act: str = "hardswish"
+ csp_kernel_size: int = 5
+ csp_num_blocks: int = 1
+
+ def __post_init__(self, **kwargs):
+ self.backbone_config, kwargs = consolidate_backbone_kwargs_to_config(
+ backbone_config=self.backbone_config,
+ default_config_type="pp_lcnet",
+ default_config_kwargs={
+ "scale": 1,
+ "out_features": ["stage2", "stage3", "stage4", "stage5"],
+ "out_indices": [2, 3, 4, 5],
+ "divisor": 16,
+ },
+ **kwargs,
+ )
+ super().__post_init__(**kwargs)
+
+
+__all__ = ["SLANetConfig"]
diff --git a/src/transformers/models/slanet/modeling_slanet.py b/src/transformers/models/slanet/modeling_slanet.py
new file mode 100644
index 000000000000..8ca95ad53d05
--- /dev/null
+++ b/src/transformers/models/slanet/modeling_slanet.py
@@ -0,0 +1,478 @@
+# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
+# This file was automatically generated from src/transformers/models/slanet/modular_slanet.py.
+# Do NOT edit this file manually as any edits will be overwritten by the generation of
+# the file from the modular. If any change should be done, please apply the change to the
+# modular_slanet.py file directly. One of our CI enforces this.
+# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
+# Copyright 2026 The PaddlePaddle Team and The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+import math
+from dataclasses import dataclass
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from torch import Tensor
+
+from ... import initialization as init
+from ...activations import ACT2CLS, ACT2FN
+from ...backbone_utils import filter_output_hidden_states, load_backbone
+from ...modeling_layers import GradientCheckpointingLayer
+from ...modeling_outputs import BaseModelOutput, BaseModelOutputWithNoAttention
+from ...modeling_utils import PreTrainedModel
+from ...processing_utils import Unpack
+from ...utils import TransformersKwargs, auto_docstring, can_return_tuple
+from ...utils.generic import merge_with_config_defaults
+from ...utils.output_capturing import capture_outputs
+from .configuration_slanet import SLANetConfig
+
+
+class SLANetPreTrainedModel(PreTrainedModel):
+ config: SLANetConfig
+ base_model_prefix = "backbone"
+ main_input_name = "pixel_values"
+ input_modalities = ("image",)
+ supports_gradient_checkpointing = True
+ _keep_in_fp32_modules_strict = []
+
+ @torch.no_grad()
+ def _init_weights(self, module):
+ """Initialize the weights"""
+ super()._init_weights(module)
+
+ # Initialize GRUCell (replicates PyTorch default reset_parameters)
+ if isinstance(module, nn.GRUCell):
+ std = 1.0 / math.sqrt(module.hidden_size) if module.hidden_size > 0 else 0
+ init.uniform_(module.weight_ih, -std, std)
+ init.uniform_(module.weight_hh, -std, std)
+ if module.bias_ih is not None:
+ init.uniform_(module.bias_ih, -std, std)
+ if module.bias_hh is not None:
+ init.uniform_(module.bias_hh, -std, std)
+
+ # Initialize SLAHead layers
+ if isinstance(module, SLANetSLAHead):
+ std = 1.0 / math.sqrt(self.config.hidden_size * 1.0)
+ # Initialize structure_generator and loc_generator layers
+ for generator in (module.structure_generator,):
+ for layer in generator.children():
+ if isinstance(layer, nn.Linear):
+ init.uniform_(layer.weight, -std, std)
+ if layer.bias is not None:
+ init.uniform_(layer.bias, -std, std)
+
+
+@dataclass
+@auto_docstring
+class SLANetForTableRecognitionOutput(BaseModelOutputWithNoAttention):
+ r"""
+ head_hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
+ Hidden-states of the SLANetSLAHead at each prediction step, varies up to max `self.config.max_text_length` states (depending on early exits).
+ head_attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
+ Attentions of the SLANetSLAHead at each prediction step, varies up to max `self.config.max_text_length` attentions (depending on early exits).
+ """
+
+ head_hidden_states: torch.FloatTensor | None = None
+ head_attentions: torch.FloatTensor | None = None
+
+
+class SLANetAttentionGRUCell(nn.Module):
+ def __init__(self, input_size, hidden_size, num_embeddings):
+ super().__init__()
+
+ self.input_to_hidden = nn.Linear(input_size, hidden_size, bias=False)
+ self.hidden_to_hidden = nn.Linear(hidden_size, hidden_size)
+ self.score = nn.Linear(hidden_size, 1, bias=False)
+
+ self.rnn = nn.GRUCell(input_size + num_embeddings, hidden_size)
+
+ def forward(
+ self,
+ prev_hidden: torch.FloatTensor,
+ batch_hidden: torch.FloatTensor,
+ char_onehots: torch.FloatTensor,
+ **kwargs: Unpack[TransformersKwargs],
+ ):
+ batch_hidden_proj = self.input_to_hidden(batch_hidden)
+ prev_hidden_proj = self.hidden_to_hidden(prev_hidden).unsqueeze(1)
+
+ attention_scores = batch_hidden_proj + prev_hidden_proj
+ attention_scores = torch.tanh(attention_scores)
+ attention_scores = self.score(attention_scores)
+
+ attn_weights = F.softmax(attention_scores, dim=1, dtype=torch.float32).to(attention_scores.dtype)
+ attn_weights = attn_weights.transpose(1, 2)
+ context = torch.matmul(attn_weights, batch_hidden).squeeze(1)
+ concat_context = torch.cat([context, char_onehots], 1)
+ hidden_states = self.rnn(concat_context, prev_hidden)
+
+ return hidden_states, attn_weights
+
+
+class SLANetMLP(nn.Module):
+ def __init__(self, hidden_size, out_channels, activation=None):
+ super().__init__()
+ self.fc1 = nn.Linear(hidden_size, hidden_size)
+ self.fc2 = nn.Linear(hidden_size, out_channels)
+ self.act_fn = nn.Identity() if activation is None else ACT2CLS[activation]()
+
+ def forward(self, hidden_states):
+ hidden_states = self.fc1(hidden_states)
+ hidden_states = self.fc2(hidden_states)
+ hidden_states = self.act_fn(hidden_states)
+ return hidden_states
+
+
+class SLANetSLAHead(SLANetPreTrainedModel):
+ _can_record_outputs = {
+ "attentions": SLANetAttentionGRUCell,
+ }
+
+ def __init__(
+ self,
+ config: dict | None = None,
+ **kwargs,
+ ):
+ super().__init__(config)
+
+ self.structure_attention_cell = SLANetAttentionGRUCell(
+ config.post_conv_out_channels, config.hidden_size, config.out_channels
+ )
+ self.structure_generator = SLANetMLP(config.hidden_size, config.out_channels)
+
+ self.post_init()
+
+ @merge_with_config_defaults
+ @capture_outputs
+ @filter_output_hidden_states
+ def forward(
+ self,
+ hidden_states: torch.FloatTensor,
+ targets: torch.Tensor | None = None,
+ **kwargs: Unpack[TransformersKwargs],
+ ):
+ features = torch.zeros(
+ (hidden_states.shape[0], self.config.hidden_size), dtype=torch.float32, device=hidden_states.device
+ )
+ predicted_chars = torch.zeros(size=[hidden_states.shape[0]], dtype=torch.long, device=hidden_states.device)
+
+ structure_preds_list = []
+ structure_ids_list = []
+ for _ in range(self.config.max_text_length + 1):
+ embedding_feature = F.one_hot(predicted_chars, self.config.out_channels).float()
+ features, _ = self.structure_attention_cell(features, hidden_states.float(), embedding_feature)
+ structure_step = self.structure_generator(features)
+ predicted_chars = structure_step.argmax(dim=1)
+
+ structure_preds_list.append(structure_step)
+ structure_ids_list.append(predicted_chars)
+ if torch.stack(structure_ids_list, dim=1).eq(self.config.out_channels - 1).any(-1).all():
+ break
+ structure_preds = F.softmax(torch.stack(structure_preds_list, dim=1), dim=-1, dtype=torch.float32).to(
+ hidden_states.dtype
+ )
+
+ return BaseModelOutput(last_hidden_state=structure_preds, hidden_states=structure_preds_list)
+
+
+class SLANetConvLayer(nn.Module):
+ def __init__(
+ self,
+ in_channels: int,
+ out_channels: int,
+ kernel_size: int = 3,
+ stride: int = 1,
+ activation: str = "hardswish",
+ groups: int = 1,
+ ):
+ super().__init__()
+ self.convolution = nn.Conv2d(
+ in_channels,
+ out_channels,
+ kernel_size=kernel_size,
+ stride=stride,
+ padding=kernel_size // 2,
+ bias=False,
+ groups=groups,
+ )
+ self.normalization = nn.BatchNorm2d(out_channels)
+ self.activation = ACT2FN[activation] if activation is not None else nn.Identity()
+
+ def forward(self, input: Tensor) -> Tensor:
+ hidden_state = self.convolution(input)
+ hidden_state = self.normalization(hidden_state)
+ hidden_state = self.activation(hidden_state)
+ return hidden_state
+
+
+class SLANetDepthwiseSeparableConvLayer(GradientCheckpointingLayer):
+ """
+ Depthwise Separable Convolution Layer: Depthwise Conv -> Pointwise Conv
+ Core component of lightweight models (e.g., MobileNet, PP-LCNet) that significantly reduces
+ the number of parameters and computational cost.
+ """
+
+ def __init__(
+ self,
+ in_channels,
+ out_channels,
+ stride,
+ kernel_size,
+ config,
+ ):
+ super().__init__()
+ self.depthwise_convolution = SLANetConvLayer(
+ in_channels=in_channels,
+ out_channels=in_channels,
+ kernel_size=kernel_size,
+ stride=stride,
+ groups=in_channels,
+ activation=config.hidden_act,
+ )
+ self.squeeze_excitation_module = nn.Identity()
+ self.pointwise_convolution = SLANetConvLayer(
+ in_channels=in_channels,
+ kernel_size=1,
+ out_channels=out_channels,
+ stride=1,
+ activation=config.hidden_act,
+ )
+
+ def forward(self, hidden_state):
+ hidden_state = self.depthwise_convolution(hidden_state)
+ hidden_state = self.squeeze_excitation_module(hidden_state)
+ hidden_state = self.pointwise_convolution(hidden_state)
+
+ return hidden_state
+
+
+class SLANetBottleneck(nn.Module):
+ def __init__(
+ self,
+ in_channels,
+ out_channels,
+ kernel_size,
+ activation,
+ config,
+ ):
+ super().__init__()
+ self.conv1 = SLANetConvLayer(
+ in_channels=in_channels, out_channels=out_channels, kernel_size=1, activation=activation
+ )
+ self.conv2 = SLANetDepthwiseSeparableConvLayer(
+ in_channels=out_channels,
+ out_channels=out_channels,
+ kernel_size=kernel_size,
+ stride=1,
+ config=config,
+ )
+
+ def forward(self, hidden_states: torch.FloatTensor) -> torch.FloatTensor:
+ hidden_states = self.conv1(hidden_states)
+ hidden_states = self.conv2(hidden_states)
+
+ return hidden_states
+
+
+class SLANetCSPLayer(nn.Module):
+ """
+ Cross Stage Partial (CSP) network layer. Similar in structure to DFineCSPRepLayer, but with a different forward computation.
+ """
+
+ def __init__(
+ self,
+ config,
+ in_channels,
+ out_channels,
+ kernel_size=3,
+ expansion=0.5,
+ num_blocks=1,
+ activation="hardswish",
+ ):
+ super().__init__()
+ hidden_channels = int(out_channels * expansion)
+ self.conv1 = SLANetConvLayer(in_channels, hidden_channels, 1, activation=activation)
+ self.conv2 = SLANetConvLayer(in_channels, hidden_channels, 1, activation=activation)
+ self.conv3 = SLANetConvLayer(2 * hidden_channels, out_channels, 1, activation=activation)
+ self.bottlenecks = nn.ModuleList(
+ [
+ SLANetBottleneck(hidden_channels, hidden_channels, kernel_size, activation, config)
+ for _ in range(num_blocks)
+ ]
+ )
+
+ def forward(self, hidden_states: torch.FloatTensor) -> torch.FloatTensor:
+ residual = self.conv1(hidden_states)
+
+ hidden_states = self.conv2(hidden_states)
+ for bottleneck in self.bottlenecks:
+ hidden_states = bottleneck(hidden_states)
+
+ hidden_states = torch.cat((hidden_states, residual), dim=1)
+ hidden_states = self.conv3(hidden_states)
+
+ return hidden_states
+
+
+class SLANetCSPPAN(nn.Module):
+ """
+ CSP-PAN: Path Aggregation Network with CSP layers
+ """
+
+ def __init__(
+ self,
+ config,
+ in_channel_list,
+ ):
+ super().__init__()
+ out_channels = config.post_conv_out_channels
+ activation = config.hidden_act
+ kernel_size = config.csp_kernel_size
+ csp_num_blocks = config.csp_num_blocks
+
+ self.channel_projector = nn.ModuleList(
+ [
+ SLANetConvLayer(
+ in_channels=in_channel_list[i], out_channels=out_channels, kernel_size=1, activation=activation
+ )
+ for i in range(len(in_channel_list))
+ ]
+ )
+
+ # build top-down blocks
+ self.upsample = nn.Upsample(scale_factor=2, mode="nearest")
+ self.top_down_blocks = nn.ModuleList(
+ [
+ SLANetCSPLayer(
+ config,
+ out_channels * 2,
+ out_channels,
+ kernel_size=kernel_size,
+ num_blocks=csp_num_blocks,
+ activation=activation,
+ )
+ for _ in range(len(in_channel_list) - 1, 0, -1)
+ ]
+ )
+
+ # build bottom-up blocks
+ self.downsamples = nn.ModuleList(
+ [
+ SLANetDepthwiseSeparableConvLayer(
+ out_channels,
+ out_channels,
+ kernel_size=kernel_size,
+ stride=2,
+ config=config,
+ )
+ for _ in range(len(in_channel_list) - 1)
+ ]
+ )
+ self.bottom_up_blocks = nn.ModuleList(
+ [
+ SLANetCSPLayer(
+ config,
+ out_channels * 2,
+ out_channels,
+ kernel_size=kernel_size,
+ num_blocks=csp_num_blocks,
+ activation=activation,
+ )
+ for _ in range(len(in_channel_list) - 1)
+ ]
+ )
+
+ def forward(self, hidden_states: torch.FloatTensor) -> torch.FloatTensor:
+ projected_features = []
+ for idx in range(len(self.channel_projector)):
+ projected_features.append(self.channel_projector[idx](hidden_states[idx]))
+
+ top_down_features = [projected_features[-1]]
+ for top_down_block, low_level_feature in zip(self.top_down_blocks, reversed(projected_features[:-1])):
+ high_level_feature = top_down_features[-1]
+ upsampled_feature = F.interpolate(
+ high_level_feature,
+ size=low_level_feature.shape[-2:],
+ mode="nearest",
+ )
+ fused_feature = top_down_block(torch.cat([upsampled_feature, low_level_feature], dim=1))
+ top_down_features.append(fused_feature)
+
+ pyramid_features = list(reversed(top_down_features))
+ output_feature = pyramid_features[0]
+ for downsample_layer, bottom_up_block, high_level_feature in zip(
+ self.downsamples, self.bottom_up_blocks, pyramid_features[1:]
+ ):
+ downsampled_feature = downsample_layer(output_feature)
+ output_feature = bottom_up_block(torch.cat([downsampled_feature, high_level_feature], dim=1))
+
+ hidden_states = output_feature.flatten(2).transpose(1, 2)
+ return hidden_states
+
+
+class SLANetBackbone(SLANetPreTrainedModel):
+ def __init__(self, config: SLANetConfig):
+ super().__init__(config)
+ self.vision_backbone = load_backbone(config)
+ self.post_csp_pan = SLANetCSPPAN(config, self.vision_backbone.num_features[2:])
+
+ self.post_init()
+
+ @can_return_tuple
+ @auto_docstring
+ def forward(
+ self, hidden_states: torch.FloatTensor, **kwargs: Unpack[TransformersKwargs]
+ ) -> tuple[torch.FloatTensor] | BaseModelOutputWithNoAttention:
+ outputs = self.vision_backbone(hidden_states, **kwargs)
+ hidden_states = self.post_csp_pan(outputs.feature_maps)
+ return BaseModelOutputWithNoAttention(
+ last_hidden_state=hidden_states,
+ hidden_states=outputs.hidden_states,
+ )
+
+
+@auto_docstring(
+ custom_intro="""
+ SLANet Table Recognition model for table recognition tasks. Wraps the core SLANetPreTrainedModel
+ and returns outputs compatible with the Transformers table recognition API.
+ """
+)
+class SLANetForTableRecognition(SLANetPreTrainedModel):
+ _keys_to_ignore_on_load_missing = ["num_batches_tracked"]
+
+ def __init__(self, config: SLANetConfig):
+ super().__init__(config)
+ self.backbone = SLANetBackbone(config=config)
+ self.head = SLANetSLAHead(config=config)
+ self.post_init()
+
+ @can_return_tuple
+ @auto_docstring
+ def forward(
+ self, pixel_values: torch.FloatTensor, **kwargs: Unpack[TransformersKwargs]
+ ) -> tuple[torch.FloatTensor] | SLANetForTableRecognitionOutput:
+ outputs = self.backbone(pixel_values, **kwargs)
+ head_outputs = self.head(outputs.last_hidden_state, **kwargs)
+ # Key difference: no attentions in its vision model
+ return SLANetForTableRecognitionOutput(
+ last_hidden_state=head_outputs.last_hidden_state,
+ hidden_states=outputs.hidden_states,
+ head_hidden_states=head_outputs.hidden_states,
+ head_attentions=head_outputs.attentions,
+ )
+
+
+__all__ = ["SLANetForTableRecognition", "SLANetPreTrainedModel", "SLANetSLAHead", "SLANetBackbone"]
diff --git a/src/transformers/models/slanet/modular_slanet.py b/src/transformers/models/slanet/modular_slanet.py
new file mode 100644
index 000000000000..19dfedb2901d
--- /dev/null
+++ b/src/transformers/models/slanet/modular_slanet.py
@@ -0,0 +1,372 @@
+# Copyright 2026 The PaddlePaddle Team and The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+import math
+from dataclasses import dataclass
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from huggingface_hub.dataclasses import strict
+
+from ... import initialization as init
+from ...backbone_utils import consolidate_backbone_kwargs_to_config, load_backbone
+from ...configuration_utils import PreTrainedConfig
+from ...modeling_outputs import BaseModelOutputWithNoAttention
+from ...modeling_utils import PreTrainedModel
+from ...processing_utils import Unpack
+from ...utils import TransformersKwargs, auto_docstring, can_return_tuple, logging
+from ..auto import AutoConfig
+from ..pp_lcnet.modeling_pp_lcnet import PPLCNetConvLayer, PPLCNetDepthwiseSeparableConvLayer
+from ..slanext.configuration_slanext import SLANeXtConfig
+from ..slanext.modeling_slanext import (
+ SLANeXtForTableRecognition,
+ SLANeXtPreTrainedModel,
+ SLANeXtSLAHead,
+)
+
+
+logger = logging.get_logger(__name__)
+
+
+@auto_docstring(checkpoint="PaddlePaddle/SLANet_plus_safetensors")
+@strict
+class SLANetConfig(SLANeXtConfig):
+ r"""
+ post_conv_out_channels (`int`, *optional*, defaults to 96):
+ Number of output channels for the post-encoder convolution layer.
+ out_channels (`int`, *optional*, defaults to 50):
+ Vocabulary size for the table structure token prediction head, i.e., the number of distinct structure
+ tokens the model can predict.
+ hidden_size (`int`, *optional*, defaults to 256):
+ Dimensionality of the hidden states in the attention GRU cell and the structure/location prediction heads.
+ max_text_length (`int`, *optional*, defaults to 500):
+ Maximum number of autoregressive decoding steps (tokens) for the structure and location decoder.
+ csp_kernel_size (`int`, *optional*, defaults to 5):
+ The kernel size of the Cross Stage Partial (CSP) layer.
+ csp_num_blocks (`int`, *optional*, defaults to 1):
+ Number of blocks within the Cross Stage Partial (CSP) layer.
+ """
+
+ sub_configs = {"backbone_config": AutoConfig}
+
+ vision_config = AttributeError()
+ backbone_config: dict | PreTrainedConfig | None = None
+
+ post_conv_in_channels = AttributeError()
+ post_conv_out_channels: int = 96
+ hidden_size: int = 256
+
+ hidden_act: str = "hardswish"
+ csp_kernel_size: int = 5
+ csp_num_blocks: int = 1
+
+ def __post_init__(self, **kwargs):
+ self.backbone_config, kwargs = consolidate_backbone_kwargs_to_config(
+ backbone_config=self.backbone_config,
+ default_config_type="pp_lcnet",
+ default_config_kwargs={
+ "scale": 1,
+ "out_features": ["stage2", "stage3", "stage4", "stage5"],
+ "out_indices": [2, 3, 4, 5],
+ "divisor": 16,
+ },
+ **kwargs,
+ )
+ PreTrainedConfig.__post_init__(**kwargs)
+
+
+class SLANetPreTrainedModel(SLANeXtPreTrainedModel):
+ _keep_in_fp32_modules_strict = []
+
+ @torch.no_grad()
+ def _init_weights(self, module):
+ """Initialize the weights"""
+ PreTrainedModel._init_weights(module)
+
+ # Initialize GRUCell (replicates PyTorch default reset_parameters)
+ if isinstance(module, nn.GRUCell):
+ std = 1.0 / math.sqrt(module.hidden_size) if module.hidden_size > 0 else 0
+ init.uniform_(module.weight_ih, -std, std)
+ init.uniform_(module.weight_hh, -std, std)
+ if module.bias_ih is not None:
+ init.uniform_(module.bias_ih, -std, std)
+ if module.bias_hh is not None:
+ init.uniform_(module.bias_hh, -std, std)
+
+ # Initialize SLAHead layers
+ if isinstance(module, SLANetSLAHead):
+ std = 1.0 / math.sqrt(self.config.hidden_size * 1.0)
+ # Initialize structure_generator and loc_generator layers
+ for generator in (module.structure_generator,):
+ for layer in generator.children():
+ if isinstance(layer, nn.Linear):
+ init.uniform_(layer.weight, -std, std)
+ if layer.bias is not None:
+ init.uniform_(layer.bias, -std, std)
+
+
+@dataclass
+@auto_docstring
+class SLANetForTableRecognitionOutput(BaseModelOutputWithNoAttention):
+ r"""
+ head_hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
+ Hidden-states of the SLANetSLAHead at each prediction step, varies up to max `self.config.max_text_length` states (depending on early exits).
+ head_attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
+ Attentions of the SLANetSLAHead at each prediction step, varies up to max `self.config.max_text_length` attentions (depending on early exits).
+ """
+
+ head_hidden_states: torch.FloatTensor | None = None
+ head_attentions: torch.FloatTensor | None = None
+
+
+class SLANetSLAHead(SLANeXtSLAHead):
+ pass
+
+
+class SLANetConvLayer(PPLCNetConvLayer):
+ pass
+
+
+class SLANetDepthwiseSeparableConvLayer(PPLCNetDepthwiseSeparableConvLayer):
+ """
+ Depthwise Separable Convolution Layer: Depthwise Conv -> Pointwise Conv
+ Core component of lightweight models (e.g., MobileNet, PP-LCNet) that significantly reduces
+ the number of parameters and computational cost.
+ """
+
+ def __init__(
+ self,
+ in_channels,
+ out_channels,
+ stride,
+ kernel_size,
+ config,
+ ):
+ super().__init__()
+ self.squeeze_excitation_module = nn.Identity()
+
+
+class SLANetBottleneck(nn.Module):
+ def __init__(
+ self,
+ in_channels,
+ out_channels,
+ kernel_size,
+ activation,
+ config,
+ ):
+ super().__init__()
+ self.conv1 = SLANetConvLayer(
+ in_channels=in_channels, out_channels=out_channels, kernel_size=1, activation=activation
+ )
+ self.conv2 = SLANetDepthwiseSeparableConvLayer(
+ in_channels=out_channels,
+ out_channels=out_channels,
+ kernel_size=kernel_size,
+ stride=1,
+ config=config,
+ )
+
+ def forward(self, hidden_states: torch.FloatTensor) -> torch.FloatTensor:
+ hidden_states = self.conv1(hidden_states)
+ hidden_states = self.conv2(hidden_states)
+
+ return hidden_states
+
+
class SLANetCSPLayer(nn.Module):
    """
    Cross Stage Partial (CSP) network layer. Similar in structure to DFineCSPRepLayer, but with
    a different forward computation: the input is split into two 1x1-projected branches -- one
    run through a stack of bottleneck blocks, the other kept as a shortcut -- which are then
    concatenated and fused by a final 1x1 convolution.
    """

    def __init__(
        self,
        config,
        in_channels,
        out_channels,
        kernel_size=3,
        expansion=0.5,
        num_blocks=1,
        activation="hardswish",
    ):
        super().__init__()
        hidden_channels = int(out_channels * expansion)
        # Two parallel 1x1 projections into the reduced hidden width.
        self.conv1 = SLANetConvLayer(in_channels, hidden_channels, 1, activation=activation)
        self.conv2 = SLANetConvLayer(in_channels, hidden_channels, 1, activation=activation)
        # Fusion conv maps the concatenated branches back to `out_channels`.
        self.conv3 = SLANetConvLayer(2 * hidden_channels, out_channels, 1, activation=activation)
        self.bottlenecks = nn.ModuleList(
            SLANetBottleneck(hidden_channels, hidden_channels, kernel_size, activation, config)
            for _ in range(num_blocks)
        )

    def forward(self, hidden_states: torch.FloatTensor) -> torch.FloatTensor:
        shortcut = self.conv1(hidden_states)

        main_branch = self.conv2(hidden_states)
        for block in self.bottlenecks:
            main_branch = block(main_branch)

        # Concatenate [main, shortcut] along the channel axis, then fuse with a 1x1 conv.
        return self.conv3(torch.cat((main_branch, shortcut), dim=1))
+
+
class SLANetCSPPAN(nn.Module):
    """
    CSP-PAN: Path Aggregation Network with CSP layers.

    Projects every input pyramid level to a shared channel width, fuses the levels
    top-down (upsampling high-level features into lower levels) and then bottom-up
    (downsampling and re-fusing towards the deepest level), and finally flattens the
    resulting map into a (batch, height*width, channels) sequence for the SLA head.
    """

    def __init__(
        self,
        config,
        in_channel_list,
    ):
        super().__init__()
        out_channels = config.post_conv_out_channels
        activation = config.hidden_act
        kernel_size = config.csp_kernel_size
        csp_num_blocks = config.csp_num_blocks

        # 1x1 convs that bring each pyramid level to the common `out_channels` width.
        self.channel_projector = nn.ModuleList(
            [
                SLANetConvLayer(
                    in_channels=in_channel_list[i], out_channels=out_channels, kernel_size=1, activation=activation
                )
                for i in range(len(in_channel_list))
            ]
        )

        # build top-down blocks
        # NOTE(review): `self.upsample` is never used in `forward` (F.interpolate is called
        # there instead); it is parameter-free, so it does not affect the state dict.
        self.upsample = nn.Upsample(scale_factor=2, mode="nearest")
        # One CSP fusion block per adjacent level pair; each consumes the channel-wise
        # concatenation of an upsampled high-level map and a low-level map.
        self.top_down_blocks = nn.ModuleList(
            [
                SLANetCSPLayer(
                    config,
                    out_channels * 2,
                    out_channels,
                    kernel_size=kernel_size,
                    num_blocks=csp_num_blocks,
                    activation=activation,
                )
                # descending range: iterates len(in_channel_list) - 1 times
                for _ in range(len(in_channel_list) - 1, 0, -1)
            ]
        )

        # build bottom-up blocks
        # Stride-2 depthwise-separable convs halve the spatial size between levels.
        self.downsamples = nn.ModuleList(
            [
                SLANetDepthwiseSeparableConvLayer(
                    out_channels,
                    out_channels,
                    kernel_size=kernel_size,
                    stride=2,
                    config=config,
                )
                for _ in range(len(in_channel_list) - 1)
            ]
        )
        self.bottom_up_blocks = nn.ModuleList(
            [
                SLANetCSPLayer(
                    config,
                    out_channels * 2,
                    out_channels,
                    kernel_size=kernel_size,
                    num_blocks=csp_num_blocks,
                    activation=activation,
                )
                for _ in range(len(in_channel_list) - 1)
            ]
        )

    def forward(self, hidden_states: list[torch.FloatTensor]) -> torch.FloatTensor:
        # Project every pyramid level to the common channel width.
        # NOTE(review): assumes `hidden_states` has (at least) as many feature maps as
        # there are projectors, aligned 1:1 by index -- confirm against the backbone.
        projected_features = []
        for idx in range(len(self.channel_projector)):
            projected_features.append(self.channel_projector[idx](hidden_states[idx]))

        # Top-down pass: upsample the running high-level feature to each lower level's
        # spatial size, concatenate along channels, and fuse with a CSP block.
        top_down_features = [projected_features[-1]]
        for top_down_block, low_level_feature in zip(self.top_down_blocks, reversed(projected_features[:-1])):
            high_level_feature = top_down_features[-1]
            upsampled_feature = F.interpolate(
                high_level_feature,
                size=low_level_feature.shape[-2:],
                mode="nearest",
            )
            fused_feature = top_down_block(torch.cat([upsampled_feature, low_level_feature], dim=1))
            top_down_features.append(fused_feature)

        # Bottom-up pass: start from the shallowest fused map, repeatedly downsample the
        # running feature and fuse it with the next deeper pyramid level.
        pyramid_features = list(reversed(top_down_features))
        output_feature = pyramid_features[0]
        for downsample_layer, bottom_up_block, high_level_feature in zip(
            self.downsamples, self.bottom_up_blocks, pyramid_features[1:]
        ):
            downsampled_feature = downsample_layer(output_feature)
            output_feature = bottom_up_block(torch.cat([downsampled_feature, high_level_feature], dim=1))

        # Flatten (batch, channels, H, W) -> (batch, H*W, channels) for the decoding head.
        hidden_states = output_feature.flatten(2).transpose(1, 2)
        return hidden_states
+
+
class SLANetBackbone(SLANetPreTrainedModel):
    """
    Vision encoder for SLANet: runs the configured backbone and fuses its multi-scale
    feature maps with a CSP-PAN neck into a single flattened token sequence.
    """

    def __init__(self, config: SLANetConfig):
        super().__init__(config)
        # `load_backbone` instantiates whatever backbone `config` describes.
        self.vision_backbone = load_backbone(config)
        # The neck is built from a suffix (index 2 onward) of the backbone's feature widths.
        # NOTE(review): assumes `num_features[2:]` lines up 1:1 with the feature maps the
        # backbone returns at runtime -- confirm against its `out_indices`.
        self.post_csp_pan = SLANetCSPPAN(config, self.vision_backbone.num_features[2:])

        self.post_init()

    @can_return_tuple
    @auto_docstring
    def forward(
        self, hidden_states: torch.FloatTensor, **kwargs: Unpack[TransformersKwargs]
    ) -> tuple[torch.FloatTensor] | BaseModelOutputWithNoAttention:
        # NOTE(review): the parameter is named `hidden_states`, but the caller
        # (SLANetForTableRecognition.forward) passes raw `pixel_values` here.
        outputs = self.vision_backbone(hidden_states, **kwargs)
        # Fuse the pyramid features and flatten them into a (batch, seq_len, channels) sequence.
        hidden_states = self.post_csp_pan(outputs.feature_maps)
        return BaseModelOutputWithNoAttention(
            last_hidden_state=hidden_states,
            hidden_states=outputs.hidden_states,
        )
+
+
@auto_docstring(
    custom_intro="""
    SLANet Table Recognition model for table recognition tasks. Wraps the core SLANetPreTrainedModel
    and returns outputs compatible with the Transformers table recognition API.
    """
)
class SLANetForTableRecognition(SLANeXtForTableRecognition):
    # BatchNorm `num_batches_tracked` buffers are tolerated as missing at load time --
    # presumably absent from the converted PaddlePaddle checkpoints; verify on load.
    _keys_to_ignore_on_load_missing = ["num_batches_tracked"]

    @can_return_tuple
    @auto_docstring
    def forward(
        self, pixel_values: torch.FloatTensor, **kwargs: Unpack[TransformersKwargs]
    ) -> tuple[torch.FloatTensor] | SLANetForTableRecognitionOutput:
        # Backbone + CSP-PAN neck produce the flattened visual token sequence.
        outputs = self.backbone(pixel_values, **kwargs)
        # The SLA head decodes table structure from the visual tokens.
        head_outputs = self.head(outputs.last_hidden_state, **kwargs)
        # Key difference: no attentions in its vision model
        return SLANetForTableRecognitionOutput(
            last_hidden_state=head_outputs.last_hidden_state,
            hidden_states=outputs.hidden_states,
            head_hidden_states=head_outputs.hidden_states,
            head_attentions=head_outputs.attentions,
        )
+
+
# Public API of this module.
__all__ = ["SLANetConfig", "SLANetForTableRecognition", "SLANetPreTrainedModel", "SLANetSLAHead", "SLANetBackbone"]
diff --git a/tests/models/slanet/__init__.py b/tests/models/slanet/__init__.py
new file mode 100644
index 000000000000..e69de29bb2d1
diff --git a/tests/models/slanet/test_modeling_slanet.py b/tests/models/slanet/test_modeling_slanet.py
new file mode 100644
index 000000000000..b92fa56428c1
--- /dev/null
+++ b/tests/models/slanet/test_modeling_slanet.py
@@ -0,0 +1,245 @@
+# coding=utf-8
+# Copyright 2026 The PaddlePaddle Team and The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Testing suite for the SLANet model."""
+
+import inspect
+import unittest
+
+from transformers import (
+ AutoImageProcessor,
+ AutoModelForTableRecognition,
+ SLANetConfig,
+ SLANetForTableRecognition,
+ is_torch_available,
+)
+from transformers.image_utils import load_image
+from transformers.testing_utils import (
+ require_torch,
+ require_vision,
+ slow,
+ torch_device,
+)
+
+from ...test_configuration_common import ConfigTester
+from ...test_modeling_common import ModelTesterMixin, floats_tensor
+from ...test_pipeline_mixin import PipelineTesterMixin
+from ...test_processing_common import url_to_local_path
+
+
+if is_torch_available():
+ import torch
+
+
class SLANetModelTester:
    """Builds a tiny SLANet configuration and random pixel inputs for the common model tests."""

    def __init__(
        self,
        parent,
        batch_size=2,
        image_size=488,
        num_channels=3,
        post_conv_out_channels=16,
        out_channels=1,
        hidden_size=16,
        max_text_length=1,
        num_stages=5,
        is_training=False,
    ):
        self.parent = parent
        self.batch_size = batch_size
        self.num_channels = num_channels
        self.image_size = image_size
        self.post_conv_out_channels = post_conv_out_channels
        self.out_channels = out_channels
        self.hidden_size = hidden_size
        self.max_text_length = max_text_length
        self.num_stages = num_stages
        self.is_training = is_training

    def get_config(self) -> SLANetConfig:
        """Returns a minimal config with a tiny PP-LCNet backbone.

        Each backbone block entry is [kernel_size, in_channels, out_channels, stride, use_se].
        """
        width = 16
        backbone_config = {
            "model_type": "pp_lcnet",
            "scale": 1,
            "out_features": ["stage2", "stage3", "stage4", "stage5"],
            "out_indices": [2, 3, 4, 5],
            "block_configs": [
                [[3, width, width, 1, False]],
                [[3, width, width, 2, False], [3, width, width, 1, False]],
                [[3, width, width, 2, False], [3, width, width, 1, False]],
                # One stride-2 block followed by five stride-1 blocks with a 5x5 kernel.
                [[3, width, width, 2, False]] + [[5, width, width, 1, False] for _ in range(5)],
                # Final stage enables squeeze-and-excitation.
                [[5, width, width, 2, True], [5, width, width, 1, True]],
            ],
        }
        return SLANetConfig(
            backbone_config=backbone_config,
            out_channels=self.out_channels,
            hidden_size=self.hidden_size,
            max_text_length=self.max_text_length,
            post_conv_out_channels=self.post_conv_out_channels,
        )

    def prepare_config_and_inputs(self):
        config = self.get_config()
        pixel_values = floats_tensor([self.batch_size, self.num_channels, self.image_size, self.image_size])
        return config, pixel_values

    def prepare_config_and_inputs_for_common(self):
        config, pixel_values = self.prepare_config_and_inputs()
        return config, {"pixel_values": pixel_values}
+
+
@require_torch
class SLANetModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase):
    """Common-API test suite for SLANetForTableRecognition."""

    all_model_classes = (SLANetForTableRecognition,) if is_torch_available() else ()
    pipeline_model_mapping = {"image-feature-extraction": SLANetForTableRecognition} if is_torch_available() else {}

    # The model exposes no attention outputs and has no token embeddings to resize.
    has_attentions = False
    test_resize_embeddings = False
    test_torch_exportable = False

    def setUp(self):
        self.model_tester = SLANetModelTester(
            self,
            batch_size=1,
            image_size=488,
        )
        self.config_tester = ConfigTester(
            self,
            config_class=SLANetConfig,
            has_text_modality=False,
            common_properties=[],
        )

    def test_config(self):
        self.config_tester.run_common_tests()

    @unittest.skip(reason="SLANet does not use inputs_embeds")
    def test_enable_input_require_grads(self):
        pass

    @unittest.skip(reason="SLANet does not use inputs_embeds")
    def test_inputs_embeds(self):
        pass

    @unittest.skip(reason="SLANet does not use test_inputs_embeds_matches_input_ids")
    def test_inputs_embeds_matches_input_ids(self):
        pass

    @unittest.skip(reason="SLANet does not support input and output embeddings")
    def test_model_get_set_embeddings(self):
        pass

    def test_forward_signature(self):
        """`forward` must take `pixel_values` as its first parameter."""
        config, _ = self.model_tester.prepare_config_and_inputs_for_common()

        for model_class in self.all_model_classes:
            model = model_class(config)
            signature = inspect.signature(model.forward)
            arg_names = [*signature.parameters.keys()]
            expected_arg_names = ["pixel_values"]
            self.assertListEqual(arg_names[:1], expected_arg_names)

    # SLANet have no seq_length
    def test_hidden_states_output(self):
        """Overrides the common test: checks the number of backbone stages, not seq_length."""

        def check_hidden_states_output(inputs_dict, config, model_class):
            model = model_class(config)
            model.to(torch_device)
            model.eval()

            with torch.no_grad():
                outputs = model(**self._prepare_for_class(inputs_dict, model_class))

            hidden_states = outputs.hidden_states
            expected_num_stages = self.model_tester.num_stages

            # +1: presumably the stem/input state precedes the per-stage states -- TODO confirm.
            self.assertEqual(len(hidden_states), expected_num_stages + 1)

        config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
        for model_class in self.all_model_classes:
            inputs_dict["output_hidden_states"] = True
            check_hidden_states_output(inputs_dict.copy(), config, model_class)

            # Check that output_hidden_states also works via config (including backbone subconfig)
            del inputs_dict["output_hidden_states"]
            config.output_hidden_states = True
            if config.backbone_config is not None:
                config.backbone_config.output_hidden_states = True
            check_hidden_states_output(inputs_dict.copy(), config, model_class)
+
+
@require_torch
@require_vision
@slow
class SLANetModelIntegrationTest(unittest.TestCase):
    """Slow end-to-end test against the released SLANet_plus checkpoint."""

    def setUp(self):
        model_path = "PaddlePaddle/SLANet_plus_safetensors"
        self.model = AutoModelForTableRecognition.from_pretrained(model_path, dtype=torch.float32).to(torch_device)
        self.image_processor = AutoImageProcessor.from_pretrained(model_path)
        # The demo image is mirrored for CI by utils/fetch_hub_objects_for_ci.py.
        img_url = url_to_local_path(
            "https://paddle-model-ecology.bj.bcebos.com/paddlex/imgs/demo_image/table_recognition.jpg"
        )
        self.image = load_image(img_url)

    def test_inference_table_recognition_head(self):
        inputs = self.image_processor(images=self.image, return_tensors="pt").to(torch_device)

        with torch.no_grad():
            outputs = self.model(**inputs)

        pred_table_structure = self.image_processor.post_process_table_recognition(outputs)["structure"]
        # NOTE(review): the expected token list below appears corrupted -- the HTML
        # structure tags (e.g. <html>, <table>, <tr>, <td></td>) seem to have been
        # stripped/mangled into empty and broken string literals, some of which are not
        # even valid Python. Regenerate this list from the real post-processed model
        # output before merging.
        expected_table_structure = [
            "",
            "",
            "",
            "",
            "| ",
            " | ",
            "
",
            "",
            " | ",
            " | ",
            " | ",
            " | ",
            "
",
            "",
            " | ",
            " | ",
            " | ",
            " | ",
            "
",
            "",
            " | ",
            " | ",
            " | ",
            " | ",
            "
",
            "
",
            "",
            "",
        ]

        self.assertEqual(pred_table_structure, expected_table_structure)
diff --git a/utils/check_repo.py b/utils/check_repo.py
index 0816e834c64b..5a5e4cea1c74 100644
--- a/utils/check_repo.py
+++ b/utils/check_repo.py
@@ -194,6 +194,8 @@
"SeamlessM4TCodeHifiGan", # Building part of bigger (tested) model.
"SeamlessM4TTextToUnitForConditionalGeneration", # Building part of bigger (tested) model.
"ChameleonVQVAE", # VQVAE here is used only for encoding (discretizing) and is tested as part of bigger model
+    "SLANetSLAHead",  # Building part of bigger (tested) model. Tested implicitly through SLANetForTableRecognition.
+    "SLANetBackbone",  # Building part of bigger (tested) model. Tested implicitly through SLANetForTableRecognition.
"SLANeXtSLAHead", # Buildinere is used only for encoding (discretizing) and is tested as part of bigger model
"SLANeXtBackbone", # Building part of bigger (tested) model. Tested implicitly through SLANeXtForTableRecognition.
"PPOCRV5MobileDetModel", # Building part of bigger (tested) model. Tested implicitly through PPOCRV5MobileDetForObjectDetection.
@@ -460,6 +462,8 @@
"Emu3TextModel", # Building part of bigger (tested) model
"JanusVQVAE", # no autoclass for VQ-VAE models
"JanusVisionModel", # Building part of bigger (tested) model
+ "SLANetSLAHead", # Building part of bigger (tested) model
+ "SLANetBackbone", # Building part of bigger (tested) model
"SLANeXtSLAHead", # Building part of bigger (tested) model
"SLANeXtBackbone", # Building part of bigger (tested) model
"PPOCRV5MobileDetModel", # Building part of bigger (tested) model
diff --git a/utils/fetch_hub_objects_for_ci.py b/utils/fetch_hub_objects_for_ci.py
index 3d229637df70..3b3609d21ce3 100644
--- a/utils/fetch_hub_objects_for_ci.py
+++ b/utils/fetch_hub_objects_for_ci.py
@@ -39,6 +39,7 @@
URLS_FOR_TESTING_DATA = [
# TODO: copy those to our hf-internal-testing dataset and fix all tests using them
+ "https://paddle-model-ecology.bj.bcebos.com/paddlex/imgs/demo_image/table_recognition.jpg",
"http://images.cocodataset.org/val2017/000000000139.jpg",
"http://images.cocodataset.org/val2017/000000000285.jpg",
"http://images.cocodataset.org/val2017/000000000632.jpg",