From 3c20ae0ab2a31b35513afa67bab68a14867b15fa Mon Sep 17 00:00:00 2001 From: yonigozlan Date: Tue, 14 Jan 2025 21:02:30 +0000 Subject: [PATCH 1/7] uniformize owlvit processor --- .../models/owlvit/processing_owlvit.py | 105 ++++++++++++------ 1 file changed, 74 insertions(+), 31 deletions(-) diff --git a/src/transformers/models/owlvit/processing_owlvit.py b/src/transformers/models/owlvit/processing_owlvit.py index 49e913a384eb..2b73f4bc9a49 100644 --- a/src/transformers/models/owlvit/processing_owlvit.py +++ b/src/transformers/models/owlvit/processing_owlvit.py @@ -17,15 +17,40 @@ """ import warnings -from typing import List +from typing import List, Optional, Union import numpy as np -from ...processing_utils import ProcessorMixin -from ...tokenization_utils_base import BatchEncoding +from ...image_processing_utils import BatchFeature +from ...image_utils import ImageInput +from ...processing_utils import ( + ImagesKwargs, + ProcessingKwargs, + ProcessorMixin, + Unpack, + _validate_images_text_input_order, +) +from ...tokenization_utils_base import PreTokenizedInput, TextInput from ...utils import is_flax_available, is_tf_available, is_torch_available +class OwlViTImagesKwargs(ImagesKwargs, total=False): + query_images: Optional[ImageInput] + + +class OwlViTProcessorKwargs(ProcessingKwargs, total=False): + images_kwargs: OwlViTImagesKwargs + _defaults = { + "text_kwargs": { + "padding": "max_length", + }, + "images_kwargs": {}, + "common_kwargs": { + "return_tensors": "np", + }, + } + + class OwlViTProcessor(ProcessorMixin): r""" Constructs an OWL-ViT processor which wraps [`OwlViTImageProcessor`] and [`CLIPTokenizer`]/[`CLIPTokenizerFast`] @@ -42,6 +67,8 @@ class OwlViTProcessor(ProcessorMixin): attributes = ["image_processor", "tokenizer"] image_processor_class = "OwlViTImageProcessor" tokenizer_class = ("CLIPTokenizer", "CLIPTokenizerFast") + # For backward compatibility. See transformers.processing_utils.ProcessorMixin.prepare_and_validate_optional_call_args for more details. + optional_call_args = ["query_images"] def __init__(self, image_processor=None, tokenizer=None, **kwargs): feature_extractor = None @@ -61,7 +88,20 @@ def __init__(self, image_processor=None, tokenizer=None, **kwargs): super().__init__(image_processor, tokenizer) - def __call__(self, text=None, images=None, query_images=None, padding="max_length", return_tensors="np", **kwargs): + def __call__( + self, + images: Optional[ImageInput] = None, + text: Union[TextInput, PreTokenizedInput, List[TextInput], List[PreTokenizedInput]] = None, + # The following is to capture `query_images` argument that may be passed as a positional argument. + # See transformers.processing_utils.ProcessorMixin.prepare_and_validate_optional_call_args for more details, + # or this conversation for more context: https://github.com/huggingface/transformers/pull/32544#discussion_r1720208116 + # This behavior is only needed for backward compatibility and will be removed in future versions. + # + *args, + audio=None, + videos=None, + **kwargs: Unpack[OwlViTProcessorKwargs], + ) -> BatchFeature: """ Main method to prepare for the model one or several text(s) and image(s). This method forwards the `text` and `kwargs` arguments to CLIPTokenizerFast's [`~CLIPTokenizerFast.__call__`] if `text` is not `None` to encode: @@ -70,14 +110,15 @@ def __call__(self, text=None, images=None, query_images=None, padding="max_lengt of the above two methods for more information. 
Args: - text (`str`, `List[str]`, `List[List[str]]`): - The sequence or batch of sequences to be encoded. Each sequence can be a string or a list of strings - (pretokenized string). If the sequences are provided as list of strings (pretokenized), you must set - `is_split_into_words=True` (to lift the ambiguity with a batch of sequences). images (`PIL.Image.Image`, `np.ndarray`, `torch.Tensor`, `List[PIL.Image.Image]`, `List[np.ndarray]`, `List[torch.Tensor]`): The image or batch of images to be prepared. Each image can be a PIL image, NumPy array or PyTorch tensor. Both channels-first and channels-last formats are supported. + + text (`str`, `List[str]`, `List[List[str]]`): + The sequence or batch of sequences to be encoded. Each sequence can be a string or a list of strings + (pretokenized string). If the sequences are provided as list of strings (pretokenized), you must set + `is_split_into_words=True` (to lift the ambiguity with a batch of sequences). query_images (`PIL.Image.Image`, `np.ndarray`, `torch.Tensor`, `List[PIL.Image.Image]`, `List[np.ndarray]`, `List[torch.Tensor]`): The query image to be prepared, one query image is expected per target image to be queried. Each image can be a PIL image, NumPy array or PyTorch tensor. In case of a NumPy array/PyTorch tensor, each image @@ -88,23 +129,36 @@ def __call__(self, text=None, images=None, query_images=None, padding="max_lengt - `'pt'`: Return PyTorch `torch.Tensor` objects. - `'np'`: Return NumPy `np.ndarray` objects. - `'jax'`: Return JAX `jnp.ndarray` objects. + Returns: - [`BatchEncoding`]: A [`BatchEncoding`] with the following fields: + [`BatchFeature`]: A [`BatchFeature`] with the following fields: - **input_ids** -- List of token ids to be fed to a model. Returned when `text` is not `None`. - **attention_mask** -- List of indices specifying which tokens should be attended to by the model (when `return_attention_mask=True` or if *"attention_mask"* is in `self.model_input_names` and if `text` is not `None`). - **pixel_values** -- Pixel values to be fed to a model. Returned when `images` is not `None`. + - **query_pixel_values** -- Pixel values of the query images to be fed to a model. Returned when `query_images` is not `None`. """ + output_kwargs = self._merge_kwargs( + OwlViTProcessorKwargs, + tokenizer_init_kwargs=self.tokenizer.init_kwargs, + **kwargs, + **self.prepare_and_validate_optional_call_args(*args), + ) + query_images = output_kwargs["images_kwargs"].pop("query_images", None) + return_tensors = output_kwargs["common_kwargs"]["return_tensors"] if text is None and query_images is None and images is None: raise ValueError( "You have to specify at least one text or query image or image. All three cannot be none." 
) + # check if images and text inputs are reversed for BC + images, text = _validate_images_text_input_order(images, text) + data = {} if text is not None: if isinstance(text, str) or (isinstance(text, List) and not isinstance(text[0], List)): - encodings = [self.tokenizer(text, padding=padding, return_tensors=return_tensors, **kwargs)] + encodings = [self.tokenizer(text, **output_kwargs["text_kwargs"])] elif isinstance(text, List) and isinstance(text[0], List): encodings = [] @@ -117,7 +171,7 @@ def __call__(self, text=None, images=None, query_images=None, padding="max_lengt if len(t) != max_num_queries: t = t + [" "] * (max_num_queries - len(t)) - encoding = self.tokenizer(t, padding=padding, return_tensors=return_tensors, **kwargs) + encoding = self.tokenizer(t, **output_kwargs["text_kwargs"]) encodings.append(encoding) else: raise TypeError("Input text should be a string, a list of strings or a nested list of strings") @@ -147,30 +201,19 @@ def __call__(self, text=None, images=None, query_images=None, padding="max_lengt else: raise ValueError("Target return tensor type could not be returned") - encoding = BatchEncoding() - encoding["input_ids"] = input_ids - encoding["attention_mask"] = attention_mask + data["input_ids"] = input_ids + data["attention_mask"] = attention_mask if query_images is not None: - encoding = BatchEncoding() - query_pixel_values = self.image_processor( - query_images, return_tensors=return_tensors, **kwargs - ).pixel_values - encoding["query_pixel_values"] = query_pixel_values + query_pixel_values = self.image_processor(query_images, **output_kwargs["images_kwargs"]).pixel_values + # Query images always override the text prompt + data = {"query_pixel_values": query_pixel_values} if images is not None: - image_features = self.image_processor(images, return_tensors=return_tensors, **kwargs) - - if text is not None and images is not None: - encoding["pixel_values"] = image_features.pixel_values - return encoding - elif query_images is not None and images is not None: - encoding["pixel_values"] = image_features.pixel_values - return encoding - elif text is not None or query_images is not None: - return encoding - else: - return BatchEncoding(data=dict(**image_features), tensor_type=return_tensors) + image_features = self.image_processor(images, **output_kwargs["images_kwargs"]) + data["pixel_values"] = image_features.pixel_values + + return BatchFeature(data=data, tensor_type=return_tensors) def post_process(self, *args, **kwargs): """ From 1edc07d9236388b5262db6a1cee5b95c4d2f87fd Mon Sep 17 00:00:00 2001 From: yonigozlan Date: Tue, 14 Jan 2025 21:18:24 +0000 Subject: [PATCH 2/7] uniformize owlv2 --- .../models/owlv2/processing_owlv2.py | 117 ++++++++++++------ tests/models/owlv2/test_processor_owlv2.py | 38 ++++++ 2 files changed, 118 insertions(+), 37 deletions(-) create mode 100644 tests/models/owlv2/test_processor_owlv2.py diff --git a/src/transformers/models/owlv2/processing_owlv2.py b/src/transformers/models/owlv2/processing_owlv2.py index 4a0b5a712e9d..14093a37ec97 100644 --- a/src/transformers/models/owlv2/processing_owlv2.py +++ b/src/transformers/models/owlv2/processing_owlv2.py @@ -16,15 +16,40 @@ Image/Text processor class for OWLv2 """ -from typing import List +from typing import List, Optional, Union import numpy as np -from ...processing_utils import ProcessorMixin -from ...tokenization_utils_base import BatchEncoding +from ...image_processing_utils import BatchFeature +from ...image_utils import ImageInput +from ...processing_utils import ( + 
ImagesKwargs, + ProcessingKwargs, + ProcessorMixin, + Unpack, + _validate_images_text_input_order, +) +from ...tokenization_utils_base import PreTokenizedInput, TextInput from ...utils import is_flax_available, is_tf_available, is_torch_available +class Owlv2ImagesKwargs(ImagesKwargs, total=False): + query_images: Optional[ImageInput] + + +class Owlv2ProcessorKwargs(ProcessingKwargs, total=False): + images_kwargs: Owlv2ImagesKwargs + _defaults = { + "text_kwargs": { + "padding": "max_length", + }, + "images_kwargs": {}, + "common_kwargs": { + "return_tensors": "np", + }, + } + + class Owlv2Processor(ProcessorMixin): r""" Constructs an Owlv2 processor which wraps [`Owlv2ImageProcessor`] and [`CLIPTokenizer`]/[`CLIPTokenizerFast`] into @@ -41,12 +66,27 @@ class Owlv2Processor(ProcessorMixin): attributes = ["image_processor", "tokenizer"] image_processor_class = "Owlv2ImageProcessor" tokenizer_class = ("CLIPTokenizer", "CLIPTokenizerFast") + # For backward compatibility. See transformers.processing_utils.ProcessorMixin.prepare_and_validate_optional_call_args for more details. + optional_call_args = ["query_images"] def __init__(self, image_processor, tokenizer, **kwargs): super().__init__(image_processor, tokenizer) - # Copied from transformers.models.owlvit.processing_owlvit.OwlViTProcessor.__call__ with OWLViT->OWLv2 - def __call__(self, text=None, images=None, query_images=None, padding="max_length", return_tensors="np", **kwargs): + # Copied from transformers.models.owlvit.processing_owlvit.OwlViTProcessor.__call__ with OwlViT->Owlv2 + def __call__( + self, + images: Optional[ImageInput] = None, + text: Union[TextInput, PreTokenizedInput, List[TextInput], List[PreTokenizedInput]] = None, + # The following is to capture `query_images` argument that may be passed as a positional argument. + # See transformers.processing_utils.ProcessorMixin.prepare_and_validate_optional_call_args for more details, + # or this conversation for more context: https://github.com/huggingface/transformers/pull/32544#discussion_r1720208116 + # This behavior is only needed for backward compatibility and will be removed in future versions. + # + *args, + audio=None, + videos=None, + **kwargs: Unpack[Owlv2ProcessorKwargs], + ) -> BatchFeature: """ Main method to prepare for the model one or several text(s) and image(s). This method forwards the `text` and `kwargs` arguments to CLIPTokenizerFast's [`~CLIPTokenizerFast.__call__`] if `text` is not `None` to encode: @@ -55,14 +95,15 @@ def __call__(self, text=None, images=None, query_images=None, padding="max_lengt of the above two methods for more information. Args: - text (`str`, `List[str]`, `List[List[str]]`): - The sequence or batch of sequences to be encoded. Each sequence can be a string or a list of strings - (pretokenized string). If the sequences are provided as list of strings (pretokenized), you must set - `is_split_into_words=True` (to lift the ambiguity with a batch of sequences). images (`PIL.Image.Image`, `np.ndarray`, `torch.Tensor`, `List[PIL.Image.Image]`, `List[np.ndarray]`, `List[torch.Tensor]`): The image or batch of images to be prepared. Each image can be a PIL image, NumPy array or PyTorch tensor. Both channels-first and channels-last formats are supported. + + text (`str`, `List[str]`, `List[List[str]]`): + The sequence or batch of sequences to be encoded. Each sequence can be a string or a list of strings + (pretokenized string). 
If the sequences are provided as list of strings (pretokenized), you must set + `is_split_into_words=True` (to lift the ambiguity with a batch of sequences). query_images (`PIL.Image.Image`, `np.ndarray`, `torch.Tensor`, `List[PIL.Image.Image]`, `List[np.ndarray]`, `List[torch.Tensor]`): The query image to be prepared, one query image is expected per target image to be queried. Each image can be a PIL image, NumPy array or PyTorch tensor. In case of a NumPy array/PyTorch tensor, each image @@ -73,23 +114,36 @@ def __call__(self, text=None, images=None, query_images=None, padding="max_lengt - `'pt'`: Return PyTorch `torch.Tensor` objects. - `'np'`: Return NumPy `np.ndarray` objects. - `'jax'`: Return JAX `jnp.ndarray` objects. + Returns: - [`BatchEncoding`]: A [`BatchEncoding`] with the following fields: + [`BatchFeature`]: A [`BatchFeature`] with the following fields: - **input_ids** -- List of token ids to be fed to a model. Returned when `text` is not `None`. - **attention_mask** -- List of indices specifying which tokens should be attended to by the model (when `return_attention_mask=True` or if *"attention_mask"* is in `self.model_input_names` and if `text` is not `None`). - **pixel_values** -- Pixel values to be fed to a model. Returned when `images` is not `None`. + - **query_pixel_values** -- Pixel values of the query images to be fed to a model. Returned when `query_images` is not `None`. """ + output_kwargs = self._merge_kwargs( + Owlv2ProcessorKwargs, + tokenizer_init_kwargs=self.tokenizer.init_kwargs, + **kwargs, + **self.prepare_and_validate_optional_call_args(*args), + ) + query_images = output_kwargs["images_kwargs"].pop("query_images", None) + return_tensors = output_kwargs["common_kwargs"]["return_tensors"] if text is None and query_images is None and images is None: raise ValueError( "You have to specify at least one text or query image or image. All three cannot be none." 
) + # check if images and text inputs are reversed for BC + images, text = _validate_images_text_input_order(images, text) + data = {} if text is not None: if isinstance(text, str) or (isinstance(text, List) and not isinstance(text[0], List)): - encodings = [self.tokenizer(text, padding=padding, return_tensors=return_tensors, **kwargs)] + encodings = [self.tokenizer(text, **output_kwargs["text_kwargs"])] elif isinstance(text, List) and isinstance(text[0], List): encodings = [] @@ -102,7 +156,7 @@ def __call__(self, text=None, images=None, query_images=None, padding="max_lengt if len(t) != max_num_queries: t = t + [" "] * (max_num_queries - len(t)) - encoding = self.tokenizer(t, padding=padding, return_tensors=return_tensors, **kwargs) + encoding = self.tokenizer(t, **output_kwargs["text_kwargs"]) encodings.append(encoding) else: raise TypeError("Input text should be a string, a list of strings or a nested list of strings") @@ -132,43 +186,32 @@ def __call__(self, text=None, images=None, query_images=None, padding="max_lengt else: raise ValueError("Target return tensor type could not be returned") - encoding = BatchEncoding() - encoding["input_ids"] = input_ids - encoding["attention_mask"] = attention_mask + data["input_ids"] = input_ids + data["attention_mask"] = attention_mask if query_images is not None: - encoding = BatchEncoding() - query_pixel_values = self.image_processor( - query_images, return_tensors=return_tensors, **kwargs - ).pixel_values - encoding["query_pixel_values"] = query_pixel_values + query_pixel_values = self.image_processor(query_images, **output_kwargs["images_kwargs"]).pixel_values + # Query images always override the text prompt + data = {"query_pixel_values": query_pixel_values} if images is not None: - image_features = self.image_processor(images, return_tensors=return_tensors, **kwargs) - - if text is not None and images is not None: - encoding["pixel_values"] = image_features.pixel_values - return encoding - elif query_images is not None and images is not None: - encoding["pixel_values"] = image_features.pixel_values - return encoding - elif text is not None or query_images is not None: - return encoding - else: - return BatchEncoding(data=dict(**image_features), tensor_type=return_tensors) - - # Copied from transformers.models.owlvit.processing_owlvit.OwlViTProcessor.post_process_object_detection with OWLViT->OWLv2 + image_features = self.image_processor(images, **output_kwargs["images_kwargs"]) + data["pixel_values"] = image_features.pixel_values + + return BatchFeature(data=data, tensor_type=return_tensors) + + # Copied from transformers.models.owlvit.processing_owlvit.OwlViTProcessor.post_process_object_detection with OwlViT->Owlv2 def post_process_object_detection(self, *args, **kwargs): """ - This method forwards all its arguments to [`OwlViTImageProcessor.post_process_object_detection`]. Please refer + This method forwards all its arguments to [`Owlv2ImageProcessor.post_process_object_detection`]. Please refer to the docstring of this method for more information. 
""" return self.image_processor.post_process_object_detection(*args, **kwargs) - # Copied from transformers.models.owlvit.processing_owlvit.OwlViTProcessor.post_process_image_guided_detection with OWLViT->OWLv2 + # Copied from transformers.models.owlvit.processing_owlvit.OwlViTProcessor.post_process_image_guided_detection with OwlViT->Owlv2 def post_process_image_guided_detection(self, *args, **kwargs): """ - This method forwards all its arguments to [`OwlViTImageProcessor.post_process_one_shot_object_detection`]. + This method forwards all its arguments to [`Owlv2ImageProcessor.post_process_one_shot_object_detection`]. Please refer to the docstring of this method for more information. """ return self.image_processor.post_process_image_guided_detection(*args, **kwargs) diff --git a/tests/models/owlv2/test_processor_owlv2.py b/tests/models/owlv2/test_processor_owlv2.py new file mode 100644 index 000000000000..7a7543dbb360 --- /dev/null +++ b/tests/models/owlv2/test_processor_owlv2.py @@ -0,0 +1,38 @@ +import shutil +import tempfile +import unittest + +import pytest + +from transformers import Owlv2Processor +from transformers.testing_utils import require_scipy + +from ...test_processing_common import ProcessorTesterMixin + + +@require_scipy +class Owlv2ProcessorTest(ProcessorTesterMixin, unittest.TestCase): + processor_class = Owlv2Processor + + def setUp(self): + self.tmpdirname = tempfile.mkdtemp() + processor = self.processor_class.from_pretrained("google/owlv2-base-patch16-ensemble") + processor.save_pretrained(self.tmpdirname) + + def tearDown(self): + shutil.rmtree(self.tmpdirname) + + def test_processor_query_images_positional(self): + processor_components = self.prepare_components() + processor = Owlv2Processor(**processor_components) + + image_input = self.prepare_image_inputs() + query_images = self.prepare_image_inputs() + + inputs = processor(None, image_input, query_images) + + self.assertListEqual(list(inputs.keys()), ["query_pixel_values", "pixel_values"]) + + # test if it raises when no input is passed + with pytest.raises(ValueError): + processor() From 115dae2091268612843c2de05b1540e273e17521 Mon Sep 17 00:00:00 2001 From: yonigozlan Date: Tue, 14 Jan 2025 21:23:12 +0000 Subject: [PATCH 3/7] nit --- src/transformers/models/owlv2/processing_owlv2.py | 1 - src/transformers/models/owlvit/processing_owlvit.py | 1 - 2 files changed, 2 deletions(-) diff --git a/src/transformers/models/owlv2/processing_owlv2.py b/src/transformers/models/owlv2/processing_owlv2.py index 14093a37ec97..bdec0b1c05df 100644 --- a/src/transformers/models/owlv2/processing_owlv2.py +++ b/src/transformers/models/owlv2/processing_owlv2.py @@ -99,7 +99,6 @@ def __call__( `List[torch.Tensor]`): The image or batch of images to be prepared. Each image can be a PIL image, NumPy array or PyTorch tensor. Both channels-first and channels-last formats are supported. - text (`str`, `List[str]`, `List[List[str]]`): The sequence or batch of sequences to be encoded. Each sequence can be a string or a list of strings (pretokenized string). If the sequences are provided as list of strings (pretokenized), you must set diff --git a/src/transformers/models/owlvit/processing_owlvit.py b/src/transformers/models/owlvit/processing_owlvit.py index 2b73f4bc9a49..38c42affb6e4 100644 --- a/src/transformers/models/owlvit/processing_owlvit.py +++ b/src/transformers/models/owlvit/processing_owlvit.py @@ -114,7 +114,6 @@ def __call__( `List[torch.Tensor]`): The image or batch of images to be prepared. 
Each image can be a PIL image, NumPy array or PyTorch tensor. Both channels-first and channels-last formats are supported. - text (`str`, `List[str]`, `List[List[str]]`): The sequence or batch of sequences to be encoded. Each sequence can be a string or a list of strings (pretokenized string). If the sequences are provided as list of strings (pretokenized), you must set From 02b713649ebdad5eae2a74797161e308b4b9409d Mon Sep 17 00:00:00 2001 From: yonigozlan Date: Wed, 22 Jan 2025 14:40:33 +0000 Subject: [PATCH 4/7] add positional arg test owlvit --- tests/models/owlvit/test_processor_owlvit.py | 15 +++++++++++++++ 1 file changed, 15 insertions(+) diff --git a/tests/models/owlvit/test_processor_owlvit.py b/tests/models/owlvit/test_processor_owlvit.py index f31dbaf9fbcc..5f99a2275f1a 100644 --- a/tests/models/owlvit/test_processor_owlvit.py +++ b/tests/models/owlvit/test_processor_owlvit.py @@ -232,6 +232,21 @@ def test_processor_case2(self): with pytest.raises(ValueError): processor() + def test_processor_query_images_positional(self): + processor_components = self.prepare_components() + processor = OwlViTProcessor(**processor_components) + + image_input = self.prepare_image_inputs() + query_images = self.prepare_image_inputs() + + inputs = processor(None, image_input, query_images) + + self.assertListEqual(list(inputs.keys()), ["query_pixel_values", "pixel_values"]) + + # test if it raises when no input is passed + with pytest.raises(ValueError): + processor() + def test_tokenizer_decode(self): image_processor = self.get_image_processor() tokenizer = self.get_tokenizer() From 6f23b7f314cc0fb9159e1670cb4b068acbddbf0e Mon Sep 17 00:00:00 2001 From: yonigozlan Date: Thu, 23 Jan 2025 15:08:24 +0000 Subject: [PATCH 5/7] run-slow: owlvit, owlv2 From 612ab9e3add1c389de2b50a6739eae9812749a75 Mon Sep 17 00:00:00 2001 From: yonigozlan Date: Thu, 23 Jan 2025 15:40:50 +0000 Subject: [PATCH 6/7] run-slow: owlvit, owlv2 From 7b998e058d888cd56c383f665f6b3d67c4299df7 Mon Sep 17 00:00:00 2001 From: yonigozlan Date: Thu, 13 Feb 2025 22:22:12 +0000 Subject: [PATCH 7/7] remove one letter variable --- src/transformers/models/owlv2/processing_owlv2.py | 10 +++++----- src/transformers/models/owlvit/processing_owlvit.py | 10 +++++----- 2 files changed, 10 insertions(+), 10 deletions(-) diff --git a/src/transformers/models/owlv2/processing_owlv2.py b/src/transformers/models/owlv2/processing_owlv2.py index 63abd28167a5..664c63ffee0e 100644 --- a/src/transformers/models/owlv2/processing_owlv2.py +++ b/src/transformers/models/owlv2/processing_owlv2.py @@ -153,14 +153,14 @@ def __call__( encodings = [] # Maximum number of queries across batch - max_num_queries = max([len(t) for t in text]) + max_num_queries = max([len(text_single) for text_single in text]) # Pad all batch samples to max number of text queries - for t in text: - if len(t) != max_num_queries: - t = t + [" "] * (max_num_queries - len(t)) + for text_single in text: + if len(text_single) != max_num_queries: + text_single = text_single + [" "] * (max_num_queries - len(text_single)) - encoding = self.tokenizer(t, **output_kwargs["text_kwargs"]) + encoding = self.tokenizer(text_single, **output_kwargs["text_kwargs"]) encodings.append(encoding) else: raise TypeError("Input text should be a string, a list of strings or a nested list of strings") diff --git a/src/transformers/models/owlvit/processing_owlvit.py b/src/transformers/models/owlvit/processing_owlvit.py index 1b68142ee959..98c24747b468 100644 --- a/src/transformers/models/owlvit/processing_owlvit.py +++ 
b/src/transformers/models/owlvit/processing_owlvit.py @@ -167,14 +167,14 @@ def __call__( encodings = [] # Maximum number of queries across batch - max_num_queries = max([len(t) for t in text]) + max_num_queries = max([len(text_single) for text_single in text]) # Pad all batch samples to max number of text queries - for t in text: - if len(t) != max_num_queries: - t = t + [" "] * (max_num_queries - len(t)) + for text_single in text: + if len(text_single) != max_num_queries: + text_single = text_single + [" "] * (max_num_queries - len(text_single)) - encoding = self.tokenizer(t, **output_kwargs["text_kwargs"]) + encoding = self.tokenizer(text_single, **output_kwargs["text_kwargs"]) encodings.append(encoding) else: raise TypeError("Input text should be a string, a list of strings or a nested list of strings")
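
Usage sketch (illustrative, not part of the patches above): the snippet below exercises the uniformized call signature on Owlv2Processor; OwlViTProcessor behaves the same way after these patches. The dummy images and the printed key lists are assumptions for demonstration; the checkpoint name is the one used in the new Owlv2 test.

    import numpy as np
    from PIL import Image

    from transformers import Owlv2Processor

    # Loading the processor needs Hub access or a local cache of the checkpoint.
    processor = Owlv2Processor.from_pretrained("google/owlv2-base-patch16-ensemble")

    # Dummy RGB images stand in for real inputs.
    image = Image.fromarray(np.zeros((480, 640, 3), dtype=np.uint8))
    query_image = Image.fromarray(np.zeros((480, 640, 3), dtype=np.uint8))

    # Text-conditioned detection with the new (images, text) argument order.
    # Defaults from Owlv2ProcessorKwargs apply: padding="max_length",
    # return_tensors="np".
    inputs = processor(images=image, text=[["a photo of a cat", "a photo of a dog"]])
    print(sorted(inputs.keys()))  # ['attention_mask', 'input_ids', 'pixel_values']

    # Image-guided detection: query_images overrides any text prompt, matching the
    # "Query images always override the text prompt" branch in the diff.
    inputs = processor(images=image, query_images=query_image, return_tensors="pt")
    print(sorted(inputs.keys()))  # ['pixel_values', 'query_pixel_values']

    # processor(None, image, query_image) still works (query_images passed
    # positionally), but only for backward compatibility; the patches note this
    # path is slated for removal in future versions.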