diff --git a/src/transformers/models/conditional_detr/image_processing_pil_conditional_detr.py b/src/transformers/models/conditional_detr/image_processing_pil_conditional_detr.py
index 359c4c706f7c..30740114d5f0 100644
--- a/src/transformers/models/conditional_detr/image_processing_pil_conditional_detr.py
+++ b/src/transformers/models/conditional_detr/image_processing_pil_conditional_detr.py
@@ -61,6 +61,8 @@
 
 logger = logging.get_logger(__name__)
 
+SUPPORTED_ANNOTATION_FORMATS = (AnnotationFormat.COCO_DETECTION, AnnotationFormat.COCO_PANOPTIC)
+
 
 class ConditionalDetrImageProcessorKwargs(ImagesKwargs, total=False):
     r"""
@@ -76,9 +78,6 @@ class ConditionalDetrImageProcessorKwargs(ImagesKwargs, total=False):
     do_convert_annotations: bool
 
 
-SUPPORTED_ANNOTATION_FORMATS = (AnnotationFormat.COCO_DETECTION, AnnotationFormat.COCO_PANOPTIC)
-
-
 # inspired by https://github.com/facebookresearch/conditional_detr/blob/master/datasets/coco.py#L33
 def convert_coco_poly_to_mask(segmentations, height: int, width: int) -> np.ndarray:
     """
diff --git a/src/transformers/models/conditional_detr/modular_conditional_detr.py b/src/transformers/models/conditional_detr/modular_conditional_detr.py
index ffc1e78bee01..2205b85c5547 100644
--- a/src/transformers/models/conditional_detr/modular_conditional_detr.py
+++ b/src/transformers/models/conditional_detr/modular_conditional_detr.py
@@ -20,13 +20,12 @@
 from ...image_transforms import (
     center_to_corners_format,
 )
-from ...image_utils import AnnotationFormat
 from ...masking_utils import create_bidirectional_mask
 from ...modeling_outputs import (
     BaseModelOutput,
 )
 from ...modeling_utils import ALL_ATTENTION_FUNCTIONS
-from ...processing_utils import ImagesKwargs, Unpack
+from ...processing_utils import Unpack
 from ...utils import (
     TensorType,
     TransformersKwargs,
@@ -66,20 +65,6 @@
 logger = logging.get_logger(__name__)
 
 
-class ConditionalDetrImageProcessorKwargs(ImagesKwargs, total=False):
-    r"""
-    format (`str`, *optional*, defaults to `AnnotationFormat.COCO_DETECTION`):
-        Data format of the annotations. One of "coco_detection" or "coco_panoptic".
-    do_convert_annotations (`bool`, *optional*, defaults to `True`):
-        Controls whether to convert the annotations to the format expected by the CONDITIONAL_DETR model. Converts the
-        bounding boxes to the format `(center_x, center_y, width, height)` and in the range `[0, 1]`.
-        Can be overridden by the `do_convert_annotations` parameter in the `preprocess` method.
- """ - - format: str | AnnotationFormat - do_convert_annotations: bool - - class ConditionalDetrImageProcessor(DetrImageProcessor): def post_process_object_detection( self, outputs, threshold: float = 0.5, target_sizes: TensorType | list[tuple] = None, top_k: int = 100 diff --git a/src/transformers/models/deepseek_vl/modular_deepseek_vl.py b/src/transformers/models/deepseek_vl/modular_deepseek_vl.py index a56da6f3fe0a..be955c6fd41e 100644 --- a/src/transformers/models/deepseek_vl/modular_deepseek_vl.py +++ b/src/transformers/models/deepseek_vl/modular_deepseek_vl.py @@ -20,7 +20,7 @@ from ...configuration_utils import PreTrainedConfig from ...image_processing_utils import BatchFeature from ...image_utils import ImageInput -from ...processing_utils import ImagesKwargs, ProcessingKwargs, ProcessorMixin, Unpack +from ...processing_utils import ProcessingKwargs, ProcessorMixin, Unpack from ...tokenization_utils_base import ( PreTokenizedInput, TextInput, @@ -152,16 +152,6 @@ def generate(self): raise AttributeError("Not needed for DeepseekVL") -class DeepseekVLImageProcessorKwargs(ImagesKwargs, total=False): - r""" - min_size (`int`, *optional*, defaults to 14): - The minimum allowed size for the resized image. Ensures that neither the height nor width - falls below this value after resizing. - """ - - min_size: int - - class DeepseekVLImageProcessorPil(JanusImageProcessorPil): def postprocess(self): raise AttributeError("Not needed for DeepseekVL") diff --git a/src/transformers/models/deformable_detr/image_processing_pil_deformable_detr.py b/src/transformers/models/deformable_detr/image_processing_pil_deformable_detr.py index fcd95fa4647f..9c7ccc213910 100644 --- a/src/transformers/models/deformable_detr/image_processing_pil_deformable_detr.py +++ b/src/transformers/models/deformable_detr/image_processing_pil_deformable_detr.py @@ -57,6 +57,8 @@ if is_torch_available(): import torch +SUPPORTED_ANNOTATION_FORMATS = (AnnotationFormat.COCO_DETECTION, AnnotationFormat.COCO_PANOPTIC) + class DeformableDetrImageProcessorKwargs(ImagesKwargs, total=False): r""" @@ -72,9 +74,6 @@ class DeformableDetrImageProcessorKwargs(ImagesKwargs, total=False): do_convert_annotations: bool -SUPPORTED_ANNOTATION_FORMATS = (AnnotationFormat.COCO_DETECTION, AnnotationFormat.COCO_PANOPTIC) - - # inspired by https://github.com/facebookresearch/deformable_detr/blob/master/datasets/coco.py#L33 def convert_coco_poly_to_mask(segmentations, height: int, width: int) -> np.ndarray: """ diff --git a/src/transformers/models/deformable_detr/modular_deformable_detr.py b/src/transformers/models/deformable_detr/modular_deformable_detr.py index a2f80e8236ad..a4a5b4acd95a 100644 --- a/src/transformers/models/deformable_detr/modular_deformable_detr.py +++ b/src/transformers/models/deformable_detr/modular_deformable_detr.py @@ -23,11 +23,10 @@ from ... 
diff --git a/src/transformers/models/deformable_detr/modular_deformable_detr.py b/src/transformers/models/deformable_detr/modular_deformable_detr.py
index a2f80e8236ad..a4a5b4acd95a 100644
--- a/src/transformers/models/deformable_detr/modular_deformable_detr.py
+++ b/src/transformers/models/deformable_detr/modular_deformable_detr.py
@@ -23,11 +23,10 @@
 from ... import initialization as init
 from ...backbone_utils import load_backbone
 from ...image_transforms import center_to_corners_format
-from ...image_utils import AnnotationFormat
 from ...integrations import use_kernel_forward_from_hub
 from ...modeling_outputs import BaseModelOutput
 from ...modeling_utils import PreTrainedModel
-from ...processing_utils import ImagesKwargs, Unpack
+from ...processing_utils import Unpack
 from ...utils import (
     ModelOutput,
     TensorType,
@@ -61,20 +60,6 @@
 logger = logging.get_logger(__name__)
 
 
-class DeformableDetrImageProcessorKwargs(ImagesKwargs, total=False):
-    r"""
-    format (`str`, *optional*, defaults to `AnnotationFormat.COCO_DETECTION`):
-        Data format of the annotations. One of "coco_detection" or "coco_panoptic".
-    do_convert_annotations (`bool`, *optional*, defaults to `True`):
-        Controls whether to convert the annotations to the format expected by the DEFORMABLE_DETR model. Converts the
-        bounding boxes to the format `(center_x, center_y, width, height)` and in the range `[0, 1]`.
-        Can be overridden by the `do_convert_annotations` parameter in the `preprocess` method.
-    """
-
-    format: str | AnnotationFormat
-    do_convert_annotations: bool
-
-
 class DeformableDetrImageProcessor(DetrImageProcessor):
     def post_process_object_detection(
         self, outputs, threshold: float = 0.5, target_sizes: TensorType | list[tuple] = None, top_k: int = 100
diff --git a/src/transformers/models/efficientloftr/modular_efficientloftr.py b/src/transformers/models/efficientloftr/modular_efficientloftr.py
index 17e3e399a8df..86d8d34eba70 100644
--- a/src/transformers/models/efficientloftr/modular_efficientloftr.py
+++ b/src/transformers/models/efficientloftr/modular_efficientloftr.py
@@ -1,6 +1,5 @@
 from typing import TYPE_CHECKING
 
-from ...processing_utils import ImagesKwargs
 from ...utils import TensorType, is_torch_available
 from ...utils.import_utils import requires
 from ..superglue.image_processing_pil_superglue import SuperGlueImageProcessorPil
@@ -14,15 +13,6 @@
 from .modeling_efficientloftr import EfficientLoFTRKeypointMatchingOutput
 
 
-class EfficientLoFTRImageProcessorKwargs(ImagesKwargs, total=False):
-    r"""
-    do_grayscale (`bool`, *optional*, defaults to `self.do_grayscale`):
-        Whether to convert the image to grayscale. Can be overridden by `do_grayscale` in the `preprocess` method.
- """ - - do_grayscale: bool - - class EfficientLoFTRImageProcessor(SuperGlueImageProcessor): def post_process_keypoint_matching( self, diff --git a/src/transformers/models/ernie4_5_vl_moe/modular_ernie4_5_vl_moe.py b/src/transformers/models/ernie4_5_vl_moe/modular_ernie4_5_vl_moe.py index 42bbb44b70a5..ad47bc0508a3 100644 --- a/src/transformers/models/ernie4_5_vl_moe/modular_ernie4_5_vl_moe.py +++ b/src/transformers/models/ernie4_5_vl_moe/modular_ernie4_5_vl_moe.py @@ -43,7 +43,7 @@ from ...modeling_outputs import BaseModelOutputWithPooling, MoeCausalLMOutputWithPast, MoeModelOutputWithPast from ...modeling_rope_utils import dynamic_rope_update from ...modeling_utils import PreTrainedModel -from ...processing_utils import ImagesKwargs, Unpack +from ...processing_utils import Unpack from ...utils import ( TensorType, TransformersKwargs, @@ -63,7 +63,7 @@ Ernie4_5_MoeStatics, Ernie4_5_MoeTopKRouter, ) -from ..glm4v.image_processing_glm4v import Glm4vImageProcessor +from ..glm4v.image_processing_glm4v import Glm4vImageProcessor, Glm4vImageProcessorKwargs from ..glm4v.image_processing_pil_glm4v import Glm4vImageProcessorPil from ..glm4v.modeling_glm4v import Glm4vForConditionalGeneration from ..mixtral.modeling_mixtral import load_balancing_loss_func @@ -1220,7 +1220,7 @@ def forward( ) -class Ernie4_5_VLMoeImageProcessorKwargs(ImagesKwargs, total=False): +class Ernie4_5_VLMoeImageProcessorKwargs(Glm4vImageProcessorKwargs): r""" patch_size (`int`, *optional*, defaults to 14): The spatial patch size of the vision encoder. @@ -1230,10 +1230,6 @@ class Ernie4_5_VLMoeImageProcessorKwargs(ImagesKwargs, total=False): The merge size of the vision encoder to llm encoder. """ - patch_size: int - temporal_patch_size: int - merge_size: int - class Ernie4_5_VLMoeImageProcessorPil(Glm4vImageProcessorPil): size = {"shortest_edge": 56 * 56, "longest_edge": 28 * 28 * 6177} diff --git a/src/transformers/models/glm_image/image_processing_pil_glm_image.py b/src/transformers/models/glm_image/image_processing_pil_glm_image.py index 2dde18ef2066..355bb04adb67 100644 --- a/src/transformers/models/glm_image/image_processing_pil_glm_image.py +++ b/src/transformers/models/glm_image/image_processing_pil_glm_image.py @@ -30,7 +30,6 @@ from ...utils import TensorType, auto_docstring -# Adapted from transformers.models.glm_image.image_processing_glm_image.GlmImageImageProcessorKwargs class GlmImageImageProcessorKwargs(ImagesKwargs, total=False): r""" min_pixels (`int`, *optional*, defaults to `56 * 56`): diff --git a/src/transformers/models/grounding_dino/image_processing_pil_grounding_dino.py b/src/transformers/models/grounding_dino/image_processing_pil_grounding_dino.py index 31c59e5f3930..c95d7cb386bd 100644 --- a/src/transformers/models/grounding_dino/image_processing_pil_grounding_dino.py +++ b/src/transformers/models/grounding_dino/image_processing_pil_grounding_dino.py @@ -67,6 +67,8 @@ if is_torch_available(): import torch +SUPPORTED_ANNOTATION_FORMATS = (AnnotationFormat.COCO_DETECTION, AnnotationFormat.COCO_PANOPTIC) + class GroundingDinoImageProcessorKwargs(ImagesKwargs, total=False): r""" @@ -82,9 +84,6 @@ class GroundingDinoImageProcessorKwargs(ImagesKwargs, total=False): do_convert_annotations: bool -SUPPORTED_ANNOTATION_FORMATS = (AnnotationFormat.COCO_DETECTION, AnnotationFormat.COCO_PANOPTIC) - - # inspired by https://github.com/facebookresearch/grounding_dino/blob/master/datasets/coco.py#L33 def convert_coco_poly_to_mask(segmentations, height: int, width: int) -> np.ndarray: """ diff --git 
diff --git a/src/transformers/models/grounding_dino/modular_grounding_dino.py b/src/transformers/models/grounding_dino/modular_grounding_dino.py
index 483ad262a602..bd35fd512ffe 100644
--- a/src/transformers/models/grounding_dino/modular_grounding_dino.py
+++ b/src/transformers/models/grounding_dino/modular_grounding_dino.py
@@ -25,8 +25,6 @@
 from transformers.models.detr.image_processing_pil_detr import DetrImageProcessorPil
 
 from ...image_transforms import center_to_corners_format
-from ...image_utils import AnnotationFormat
-from ...processing_utils import ImagesKwargs
 from ...utils import (
     TensorType,
     logging,
@@ -70,20 +68,6 @@ def _scale_boxes(boxes, target_sizes):
     return boxes
 
 
-class GroundingDinoImageProcessorKwargs(ImagesKwargs, total=False):
-    r"""
-    format (`str`, *optional*, defaults to `AnnotationFormat.COCO_DETECTION`):
-        Data format of the annotations. One of "coco_detection" or "coco_panoptic".
-    do_convert_annotations (`bool`, *optional*, defaults to `True`):
-        Controls whether to convert the annotations to the format expected by the GROUNDING_DINO model. Converts the
-        bounding boxes to the format `(center_x, center_y, width, height)` and in the range `[0, 1]`.
-        Can be overridden by the `do_convert_annotations` parameter in the `preprocess` method.
-    """
-
-    format: str | AnnotationFormat
-    do_convert_annotations: bool
-
-
 class GroundingDinoImageProcessor(DetrImageProcessor):
     def post_process_object_detection(
         self,
diff --git a/src/transformers/models/lightglue/modular_lightglue.py b/src/transformers/models/lightglue/modular_lightglue.py
index 62082b678b00..afc8a3efec25 100644
--- a/src/transformers/models/lightglue/modular_lightglue.py
+++ b/src/transformers/models/lightglue/modular_lightglue.py
@@ -23,7 +23,7 @@
 from ...configuration_utils import PreTrainedConfig
 from ...modeling_flash_attention_utils import FlashAttentionKwargs
 from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
-from ...processing_utils import ImagesKwargs, Unpack
+from ...processing_utils import Unpack
 from ...utils import ModelOutput, TensorType, auto_docstring, can_return_tuple, logging
 from ...utils.import_utils import requires
 from ..auto import CONFIG_MAPPING, AutoConfig
@@ -32,7 +32,7 @@
 from ..cohere.modeling_cohere import apply_rotary_pos_emb
 from ..llama.modeling_llama import LlamaAttention, eager_attention_forward
 from ..superglue.image_processing_pil_superglue import SuperGlueImageProcessorPil
-from ..superglue.image_processing_superglue import SuperGlueImageProcessor
+from ..superglue.image_processing_superglue import SuperGlueImageProcessor, SuperGlueImageProcessorKwargs
 from ..superpoint import SuperPointConfig
 
 
@@ -154,13 +154,8 @@ class LightGlueKeypointMatchingOutput(ModelOutput):
     attentions: tuple[torch.FloatTensor] | None = None
 
 
-class LightGlueImageProcessorKwargs(ImagesKwargs, total=False):
-    r"""
-    do_grayscale (`bool`, *optional*, defaults to `self.do_grayscale`):
-        Whether to convert the image to grayscale. Can be overridden by `do_grayscale` in the `preprocess` method.
- """ - - do_grayscale: bool +class LightGlueImageProcessorKwargs(SuperGlueImageProcessorKwargs): + pass class LightGlueImageProcessor(SuperGlueImageProcessor): diff --git a/src/transformers/models/llava_onevision/modular_llava_onevision.py b/src/transformers/models/llava_onevision/modular_llava_onevision.py index a3634aa17cba..f44a4612cdc2 100644 --- a/src/transformers/models/llava_onevision/modular_llava_onevision.py +++ b/src/transformers/models/llava_onevision/modular_llava_onevision.py @@ -34,7 +34,7 @@ ) from ...modeling_flash_attention_utils import FlashAttentionKwargs from ...modeling_outputs import BaseModelOutputWithPooling -from ...processing_utils import ImagesKwargs, Unpack +from ...processing_utils import Unpack from ...utils import TensorType, auto_docstring, logging from ...utils.generic import can_return_tuple, merge_with_config_defaults from ..llava_next.image_processing_llava_next import LlavaNextImageProcessor, LlavaNextImageProcessorKwargs @@ -217,17 +217,6 @@ def _preprocess( ) -class LlavaOnevisionImageProcessorKwargs(ImagesKwargs, total=False): - r""" - image_grid_pinpoints (`list[list[int]]`, *optional*): - A list of possible resolutions to use for processing high resolution images. The best resolution is selected - based on the original size of the image. Can be overridden by `image_grid_pinpoints` in the `preprocess` - method. - """ - - image_grid_pinpoints: list[list[int]] - - class LlavaOnevisionImageProcessorPil(LlavaNextImageProcessorPil): resample = PILImageResampling.BICUBIC image_mean = OPENAI_CLIP_MEAN diff --git a/src/transformers/models/mask2former/modular_mask2former.py b/src/transformers/models/mask2former/modular_mask2former.py index 87f2b834991f..089baffe5df7 100644 --- a/src/transformers/models/mask2former/modular_mask2former.py +++ b/src/transformers/models/mask2former/modular_mask2former.py @@ -15,8 +15,6 @@ import torch from torch import nn -from ...image_utils import SizeDict -from ...processing_utils import ImagesKwargs from ...utils import ( TensorType, logging, @@ -35,32 +33,6 @@ logger = logging.get_logger(__name__) -class Mask2FormerImageProcessorKwargs(ImagesKwargs, total=False): - r""" - ignore_index (`int`, *optional*): - Label to be assigned to background pixels in segmentation maps. If provided, segmentation map pixels - denoted with 0 (background) will be replaced with `ignore_index`. - do_reduce_labels (`bool`, *optional*, defaults to `False`): - Whether or not to decrement all label values of segmentation maps by 1. Usually used for datasets where 0 - is used for background, and background itself is not included in all classes of a dataset (e.g. ADE20k). - The background label will be replaced by `ignore_index`. - num_labels (`int`, *optional*): - The number of labels in the segmentation map. - size_divisor (`int`, *optional*, defaults to `32`): - Some backbones need images divisible by a certain number. If not passed, it defaults to the value used in - Swin Transformer. - pad_size (`SizeDict`, *optional*): - The size to pad the images to. Must be larger than any image size provided for preprocessing. If `pad_size` - is not provided, images will be padded to the largest height and width in the batch. 
- """ - - ignore_index: int | None - do_reduce_labels: bool - num_labels: int | None - size_divisor: int - pad_size: SizeDict | None - - class Mask2FormerImageProcessor(MaskFormerImageProcessor): def post_process_semantic_segmentation( self, outputs, target_sizes: list[tuple[int, int]] | None = None diff --git a/src/transformers/models/paddleocr_vl/modular_paddleocr_vl.py b/src/transformers/models/paddleocr_vl/modular_paddleocr_vl.py index 20a897059a4e..02895d6e2576 100644 --- a/src/transformers/models/paddleocr_vl/modular_paddleocr_vl.py +++ b/src/transformers/models/paddleocr_vl/modular_paddleocr_vl.py @@ -38,9 +38,8 @@ from ...modeling_outputs import BaseModelOutput, BaseModelOutputWithPast, BaseModelOutputWithPooling from ...modeling_utils import PreTrainedModel from ...models.qwen2_vl.image_processing_pil_qwen2_vl import Qwen2VLImageProcessorPil -from ...models.qwen2_vl.image_processing_qwen2_vl import Qwen2VLImageProcessor +from ...models.qwen2_vl.image_processing_qwen2_vl import Qwen2VLImageProcessor, Qwen2VLImageProcessorKwargs from ...processing_utils import ( - ImagesKwargs, ProcessingKwargs, ProcessorMixin, Unpack, @@ -123,7 +122,7 @@ def smart_resize( return h_bar, w_bar -class PaddleOCRVLImageProcessorKwargs(ImagesKwargs, total=False): +class PaddleOCRVLImageProcessorKwargs(Qwen2VLImageProcessorKwargs): r""" patch_size (`int`, *optional*, defaults to 14): The spatial patch size of the vision encoder. @@ -133,12 +132,6 @@ class PaddleOCRVLImageProcessorKwargs(ImagesKwargs, total=False): The merge size of the vision encoder to llm encoder. """ - min_pixels: int - max_pixels: int - patch_size: int - temporal_patch_size: int - merge_size: int - class PaddleOCRVLImageProcessorPil(Qwen2VLImageProcessorPil): size = {"shortest_edge": 384 * 384, "longest_edge": 1536 * 1536} diff --git a/src/transformers/models/rt_detr/image_processing_pil_rt_detr.py b/src/transformers/models/rt_detr/image_processing_pil_rt_detr.py index 1fe55d067653..669843e9f949 100644 --- a/src/transformers/models/rt_detr/image_processing_pil_rt_detr.py +++ b/src/transformers/models/rt_detr/image_processing_pil_rt_detr.py @@ -54,6 +54,8 @@ if is_torch_available(): import torch +SUPPORTED_ANNOTATION_FORMATS = (AnnotationFormat.COCO_DETECTION, AnnotationFormat.COCO_PANOPTIC) + class RTDetrImageProcessorKwargs(ImagesKwargs, total=False): r""" @@ -69,9 +71,6 @@ class RTDetrImageProcessorKwargs(ImagesKwargs, total=False): do_convert_annotations: bool -SUPPORTED_ANNOTATION_FORMATS = (AnnotationFormat.COCO_DETECTION, AnnotationFormat.COCO_PANOPTIC) - - def prepare_coco_detection_annotation_pil( image, target, diff --git a/src/transformers/models/rt_detr/modular_rt_detr.py b/src/transformers/models/rt_detr/modular_rt_detr.py index 97136541d6ec..cd4e8faf3fc2 100644 --- a/src/transformers/models/rt_detr/modular_rt_detr.py +++ b/src/transformers/models/rt_detr/modular_rt_detr.py @@ -426,20 +426,6 @@ def post_process_panoptic_segmentation(self): raise NotImplementedError("Panoptic segmentation post-processing is not implemented for RT-DETR yet.") -class RTDetrImageProcessorKwargs(ImagesKwargs, total=False): - r""" - format (`str`, *optional*, defaults to `AnnotationFormat.COCO_DETECTION`): - Data format of the annotations. One of "coco_detection" or "coco_panoptic". - do_convert_annotations (`bool`, *optional*, defaults to `True`): - Controls whether to convert the annotations to the format expected by the RT_DETR model. Converts the - bounding boxes to the format `(center_x, center_y, width, height)` and in the range `[0, 1]`. 
-        Can be overridden by the `do_convert_annotations` parameter in the `preprocess` method.
-    """
-
-    format: str | AnnotationFormat
-    do_convert_annotations: bool
-
-
 @requires(backends=("torch",))
 class RTDetrImageProcessorPil(DetrImageProcessorPil):
     resample = PILImageResampling.BILINEAR
diff --git a/src/transformers/models/segformer/modular_segformer.py b/src/transformers/models/segformer/modular_segformer.py
index d7f339ea6e42..414dc58e8c52 100644
--- a/src/transformers/models/segformer/modular_segformer.py
+++ b/src/transformers/models/segformer/modular_segformer.py
@@ -31,22 +31,10 @@
     PILImageResampling,
     SizeDict,
 )
-from ...processing_utils import ImagesKwargs
 from ...utils import TensorType
 from ...utils.import_utils import requires
 
 
-class SegformerImageProcessorKwargs(ImagesKwargs, total=False):
-    r"""
-    do_reduce_labels (`bool`, *optional*, defaults to `self.do_reduce_labels`):
-        Whether or not to reduce all label values of segmentation maps by 1. Usually used for datasets where 0
-        is used for background, and background itself is not included in all classes of a dataset (e.g.
-        ADE20k). The background label will be replaced by 255.
-    """
-
-    do_reduce_labels: bool
-
-
 class SegformerImageProcessor(BeitImageProcessor):
     resample = PILImageResampling.BILINEAR
     image_mean = IMAGENET_DEFAULT_MEAN
diff --git a/src/transformers/models/smolvlm/modular_smolvlm.py b/src/transformers/models/smolvlm/modular_smolvlm.py
index 9c572cc9d877..cf91863c56a7 100644
--- a/src/transformers/models/smolvlm/modular_smolvlm.py
+++ b/src/transformers/models/smolvlm/modular_smolvlm.py
@@ -22,7 +22,7 @@
 from ...generation import GenerationConfig
 from ...modeling_flash_attention_utils import FlashAttentionKwargs
 from ...modeling_outputs import BaseModelOutputWithPooling
-from ...processing_utils import ImagesKwargs, Unpack
+from ...processing_utils import Unpack
 from ...utils import TransformersKwargs, auto_docstring, can_return_tuple, logging, torch_compilable_check
 from ..idefics3.configuration_idefics3 import Idefics3Config, Idefics3VisionConfig
 from ..idefics3.image_processing_idefics3 import Idefics3ImageProcessor
@@ -91,22 +91,6 @@ class SmolVLMConfig(Idefics3Config):
     model_type = "smolvlm"
 
 
-class SmolVLMImageProcessorKwargs(ImagesKwargs, total=False):
-    """
-    do_image_splitting (`bool`, *optional*, defaults to `True`):
-        Whether to split the image into sub-images concatenated with the original image. They are split into patches
-        such that each patch has a size of `max_image_size["height"]` x `max_image_size["width"]`.
-    max_image_size (`Dict`, *optional*, defaults to `{"longest_edge": 364}`):
-        Maximum resolution of the patches of images accepted by the model. This is a dictionary containing the key "longest_edge".
-    return_row_col_info (`bool`, *optional*, defaults to `False`):
-        Whether to return the row and column information of the images.
- """ - - do_image_splitting: bool - max_image_size: dict[str, int] - return_row_col_info: bool - - class SmolVLMImageProcessor(Idefics3ImageProcessor): pass diff --git a/src/transformers/models/video_llama_3/modular_video_llama_3.py b/src/transformers/models/video_llama_3/modular_video_llama_3.py index 4eef74580c87..c4a9e40bc8f0 100644 --- a/src/transformers/models/video_llama_3/modular_video_llama_3.py +++ b/src/transformers/models/video_llama_3/modular_video_llama_3.py @@ -37,7 +37,7 @@ ) from ...modeling_outputs import BaseModelOutput, BaseModelOutputWithPooling, ModelOutput from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel -from ...processing_utils import ImagesKwargs, Unpack +from ...processing_utils import Unpack from ...tokenization_utils_base import PreTokenizedInput, TextInput from ...utils import ( TensorType, @@ -55,7 +55,7 @@ from ..auto import CONFIG_MAPPING, AutoConfig from ..auto.modeling_auto import AutoModel from ..qwen2_vl.image_processing_pil_qwen2_vl import Qwen2VLImageProcessorPil -from ..qwen2_vl.image_processing_qwen2_vl import Qwen2VLImageProcessor, smart_resize +from ..qwen2_vl.image_processing_qwen2_vl import Qwen2VLImageProcessor, Qwen2VLImageProcessorKwargs, smart_resize from ..qwen2_vl.modeling_qwen2_vl import ( Qwen2VLForConditionalGeneration, Qwen2VLModel, @@ -1107,25 +1107,8 @@ def model_input_names(self): raise AttributeError("VideoLlama doesn't need to override it") -class VideoLlama3ImageProcessorKwargs(ImagesKwargs, total=False): - r""" - min_pixels (`int`, *optional*, defaults to `56 * 56`): - The min pixels of the image to resize the image. - max_pixels (`int`, *optional*, defaults to `28 * 28 * 1280`): - The max pixels of the image to resize the image. - patch_size (`int`, *optional*, defaults to 14): - The spatial patch size of the vision encoder. - temporal_patch_size (`int`, *optional*, defaults to 2): - The temporal patch size of the vision encoder. - merge_size (`int`, *optional*, defaults to 2): - The merge size of the vision encoder to llm encoder. 
- """ - - min_pixels: int - max_pixels: int - patch_size: int - temporal_patch_size: int - merge_size: int +class VideoLlama3ImageProcessorKwargs(Qwen2VLImageProcessorKwargs): + pass class VideoLlama3ImageProcessorPil(Qwen2VLImageProcessorPil): diff --git a/src/transformers/models/yolos/image_processing_pil_yolos.py b/src/transformers/models/yolos/image_processing_pil_yolos.py index 219348363ea3..f42fb5a63701 100644 --- a/src/transformers/models/yolos/image_processing_pil_yolos.py +++ b/src/transformers/models/yolos/image_processing_pil_yolos.py @@ -44,8 +44,9 @@ import torch from torch import nn +SUPPORTED_ANNOTATION_FORMATS = (AnnotationFormat.COCO_DETECTION, AnnotationFormat.COCO_PANOPTIC) + -# Adapted from transformers.models.yolos.image_processing_yolos.YolosImageProcessorKwargs class YolosImageProcessorKwargs(ImagesKwargs, total=False): r""" format (`str`, *optional*, defaults to `AnnotationFormat.COCO_DETECTION`): @@ -60,9 +61,6 @@ class YolosImageProcessorKwargs(ImagesKwargs, total=False): do_convert_annotations: bool -SUPPORTED_ANNOTATION_FORMATS = (AnnotationFormat.COCO_DETECTION, AnnotationFormat.COCO_PANOPTIC) - - # inspired by https://github.com/facebookresearch/yolos/blob/master/datasets/coco.py#L33 def convert_coco_poly_to_mask(segmentations, height: int, width: int) -> np.ndarray: """ diff --git a/utils/modular_model_converter.py b/utils/modular_model_converter.py index 48dc46b8b593..018316680ece 100644 --- a/utils/modular_model_converter.py +++ b/utils/modular_model_converter.py @@ -1303,6 +1303,47 @@ def _code(node: cst.CSTNode) -> str: return other_imports + result +def replace_unprotected_image_processing_imports(files: dict, all_imports: list) -> dict: + """ + Because `image_processing` file uses non-protected torchvision and torch imports, we need to duplicate the nodes + inside `image_processing_pil` instead of importing them directly from `.image_processing_xxx`, which would crash if + torchvision is not installed. 
+ """ + if not ("image_processing" in files and "image_processing_pil" in files): + return files + + body = files["image_processing_pil"] + needed_imports = get_needed_imports(body, all_imports) + import_from_image_processing = None + for import_node in needed_imports: + if isinstance(import_node, cst.SimpleStatementLine) and isinstance(import_node.body[0], cst.ImportFrom): + import_node = import_node.body[0] + full_name = get_full_attribute_name(import_node.module) + if re.search(r"^image_processing_(?!(?:backends)|(?:utils))", full_name): + import_from_image_processing = import_node + break + + if import_from_image_processing is None: + return files + + imported_objects = [x.name.value for x in import_from_image_processing.names] + nodes_to_add = {name: files["image_processing"][name] for name in imported_objects} + # Update the position inside the final file + for name, node_structure in nodes_to_add.items(): + node_with_same_index = next( + v["node"] for v in body.values() if v["insert_idx"] == node_structure["insert_idx"] + ) + # Insert the new node before the corresponding node if the corresponding node is a class or function + if isinstance(node_with_same_index, (cst.ClassDef, cst.FunctionDef)): + nodes_to_add[name]["insert_idx"] -= 0.5 + # Otherwise, after it + else: + nodes_to_add[name]["insert_idx"] += 0.5 + # Add the nodes inside the body of `image_processing_pil` + body.update(nodes_to_add) + return files + + def split_all_assignment(node: cst.CSTNode, model_name: str) -> dict[str, cst.CSTNode]: """Split the `__all__` assignment found in the modular between each corresponding files.""" all_all_per_file = {} @@ -1692,10 +1733,6 @@ class NewNameModel(LlamaModel): class_file_type = find_file_type(class_name, new_name) # In this case, we need to remove it from the dependencies and create a new import instead if class_file_type != file_type: - # image_processing_pil and image_processing must never depend on each other. - # When a PIL class needs an image_processing class, inline it instead of importing. - if file_type == "image_processing_pil" and class_file_type == "image_processing": - continue corrected_dependencies.remove(class_name) import_statement = f"from .{class_file_type}_{new_name} import {class_name}" new_imports[class_name] = cst.parse_statement(import_statement) @@ -1748,14 +1785,7 @@ class node based on the inherited classes if needed. Also returns any new import # Remove all classes explicitly defined in modular from the dependencies. Otherwise, if a class is referenced # before its new modular definition, it may be wrongly imported from elsewhere as a dependency if it matches # another class from a modeling file after renaming, even though it would be added after anyway (leading to duplicates) - # Exception: for image_processing_pil files, image_processing modular classes must be inlined (not excluded), - # because these two files must never import from each other. - classes_to_exclude = set(modular_mapper.classes.keys()) - if file_type == "image_processing_pil": - classes_to_exclude -= { - k for k in classes_to_exclude if find_file_type(k, model_name) == "image_processing" - } - new_node_dependencies -= classes_to_exclude + new_node_dependencies -= set(modular_mapper.classes.keys()) # The node was modified -> look for all recursive dependencies of the new node all_dependencies_to_add = find_all_dependencies( @@ -1766,13 +1796,7 @@ class node based on the inherited classes if needed. 
     relative_dependency_order = mapper.compute_relative_order(all_dependencies_to_add)
 
     nodes_to_add = {
-        dep: (
-            relative_dependency_order[dep],
-            # If this dependency is explicitly defined in the modular, prefer the modular's version.
-            # This prevents a renamed parent class from overriding a modular-defined class of the same name.
-            modular_mapper.global_nodes[dep] if dep in modular_mapper.classes else mapper.global_nodes[dep],
-        )
-        for dep in all_dependencies_to_add
+        dep: (relative_dependency_order[dep], mapper.global_nodes[dep]) for dep in all_dependencies_to_add
     }
 
     # No transformers (modeling file) super class, just check functions and assignments dependencies
@@ -1862,6 +1886,11 @@ def create_modules(
         all_imports.extend(new_imports)
         all_imports_code.update(new_imports_code)
 
+    # Because the `image_processing` file uses non-protected torchvision and torch imports, we need to duplicate the nodes
+    # here instead of importing from `.image_processing_model`, which would crash if torchvision is not installed
+    if "image_processing" in files and "image_processing_pil" in files:
+        files = replace_unprotected_image_processing_imports(files, all_imports)
+
     # Find the correct imports, and write the new modules
     for file, body in files.items():
         new_body = [k[1]["node"] for k in sorted(body.items(), key=lambda x: x[1]["insert_idx"])]
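# A minimal, standalone sketch (not part of the patch) of the ordering trick that
# `replace_unprotected_image_processing_imports` uses above, reduced to plain dicts: every node carries
# a numeric `insert_idx`, and a node copied from the sibling `image_processing_*` file is slotted in by
# nudging that index by +/- 0.5, so the final sort by `insert_idx` interleaves it without renumbering
# anything else. The data below is invented for illustration; the real converter stores libcst nodes,
# not strings.
body = {
    "imports": {"node": "<import block>", "insert_idx": 0},
    "MyImageProcessorPil": {"node": "<class def>", "insert_idx": 1},
}

# Pretend this entry was duplicated from the `image_processing_*` file and should appear just before
# the class that references it.
copied = {"SUPPORTED_ANNOTATION_FORMATS": {"node": "<assignment>", "insert_idx": 1}}
for name, entry in copied.items():
    entry["insert_idx"] -= 0.5  # before a ClassDef/FunctionDef; `+= 0.5` would place it after
body.update(copied)

# The writer then emits nodes in `insert_idx` order, so the copied assignment lands between the
# imports and the class definition.
ordered = [name for name, entry in sorted(body.items(), key=lambda item: item[1]["insert_idx"])]
print(ordered)  # ['imports', 'SUPPORTED_ANNOTATION_FORMATS', 'MyImageProcessorPil']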