From 0956ec75d007bbd738bb09c6a8d5c3d3de147ad2 Mon Sep 17 00:00:00 2001
From: cyr0930
Date: Fri, 8 Aug 2025 02:28:33 +0000
Subject: [PATCH 1/3] [fix] llava onevision batch inference

---
 docs/source/en/_toctree.yml                 |  2 +-
 docs/source/en/model_doc/llava_onevision.md | 19 ++++++--------
 .../image_processing_llava_onevision.py     |  6 +++--
 .../modeling_llava_onevision.py             |  1 -
 .../modular_llava_onevision.py              |  1 -
 .../processing_llava_onevision.py           |  5 ++--
 .../test_modeling_llava_onevision.py        | 25 +++++++++++--------
 7 files changed, 31 insertions(+), 28 deletions(-)

diff --git a/docs/source/en/_toctree.yml b/docs/source/en/_toctree.yml
index 778d4255e6df..9a7c9df26962 100644
--- a/docs/source/en/_toctree.yml
+++ b/docs/source/en/_toctree.yml
@@ -1046,7 +1046,7 @@
       - local: model_doc/llama4
         title: Llama4
       - local: model_doc/llava
-        title: Llava
+        title: LLaVA
       - local: model_doc/llava_next
         title: LLaVA-NeXT
       - local: model_doc/llava_next_video
diff --git a/docs/source/en/model_doc/llava_onevision.md b/docs/source/en/model_doc/llava_onevision.md
index a8b63c9016d9..4d15e2a621d5 100644
--- a/docs/source/en/model_doc/llava_onevision.md
+++ b/docs/source/en/model_doc/llava_onevision.md
@@ -38,7 +38,7 @@
 yielding new emerging capabilities. In particular, strong video understanding and
 cross-scenario capabilities are demonstrated through task transfer from images to
 videos.*
 
-drawing
 
 LLaVA-OneVision architecture. Taken from the original paper.
@@ -165,20 +165,20 @@ conversation_1 = [
         "content": [
             {"type": "image", "url": "https://www.ilankelman.org/stopsigns/australia.jpg"},
             {"type": "text", "text": "What is shown in this image?"},
-            ],
+        ],
     },
     {
         "role": "assistant",
         "content": [
             {"type": "text", "text": "There is a red stop sign in the image."},
-            ],
+        ],
     },
     {
         "role": "user",
         "content": [
             {"type": "image", "url": "http://images.cocodataset.org/val2017/000000039769.jpg"},
             {"type": "text", "text": "What about this image? How many cats do you see?"},
-            ],
+        ],
     },
 ]
 
@@ -188,7 +188,7 @@ conversation_2 = [
         "content": [
             {"type": "image", "url": "https://huggingface.co/microsoft/kosmos-2-patch14-224/resolve/main/snowman.jpg"},
             {"type": "text", "text": "What is shown in this image?"},
-            ],
+        ],
     },
 ]
 
@@ -198,13 +198,14 @@ inputs = processor.apply_chat_template(
     tokenize=True,
     return_dict=True,
     padding=True,
-    return_tensors="pt"
+    padding_side="left",
+    return_tensors="pt",
 ).to(model.device, torch.float16)
 
 # Generate
 generate_ids = model.generate(**inputs, max_new_tokens=30)
 processor.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)
-['user\n\nWhat is shown in this image?\nassistant\nThere is a red stop sign in the image.\nuser\n\nWhat about this image? How many cats do you see?\nassistant\ntwo', 'user\n\nWhat is shown in this image?\nassistant\n']
+['user\n\nWhat is shown in this image?\nassistant\nThere is a red stop sign in the image.\nuser\n\nWhat about this image? How many cats do you see?\nassistant\ntwo', 'user\n\nWhat is shown in this image?\nassistant\nThe image shows a whimsical scene of a snowman sitting by a campfire. The snowman is anthropomorphized, wearing a hat and']
 ```
 
 ### Video inference
@@ -312,10 +313,6 @@ model = LlavaOnevisionForConditionalGeneration.from_pretrained(
 
 [[autodoc]] LlavaOnevisionVideoProcessor
 
-## LlavaOnevisionVideoProcessor
-
-[[autodoc]] LlavaOnevisionVideoProcessor
-
 ## LlavaOnevisionModel
 
 [[autodoc]] LlavaOnevisionModel
diff --git a/src/transformers/models/llava_onevision/image_processing_llava_onevision.py b/src/transformers/models/llava_onevision/image_processing_llava_onevision.py
index f7f108729eac..40687456174f 100644
--- a/src/transformers/models/llava_onevision/image_processing_llava_onevision.py
+++ b/src/transformers/models/llava_onevision/image_processing_llava_onevision.py
@@ -680,8 +680,10 @@ def preprocess(
         do_pad = do_pad if do_pad is not None else self.do_pad
         do_convert_rgb = do_convert_rgb if do_convert_rgb is not None else self.do_convert_rgb
 
-        if isinstance(images, (tuple, list)) and isinstance(images[0], (tuple, list)):
-            # if the first element is a list, we assume that all elements are lists
+        if isinstance(images, (tuple, list)) and any(isinstance(image, (tuple, list)) for image in images):
+            for i, image in enumerate(images):
+                if not isinstance(image, (tuple, list)):
+                    images[i] = [image]
             batch_num_images = [len(x) for x in images]
         elif isinstance(images, (tuple, list)):
             # treat this as a single-image case for backward compatibility
diff --git a/src/transformers/models/llava_onevision/modeling_llava_onevision.py b/src/transformers/models/llava_onevision/modeling_llava_onevision.py
index d63cde2b5789..0e2e5434dab2 100644
--- a/src/transformers/models/llava_onevision/modeling_llava_onevision.py
+++ b/src/transformers/models/llava_onevision/modeling_llava_onevision.py
@@ -394,7 +394,6 @@ def pack_image_features(self, image_features, image_sizes, image_newline=None, v
                 image_feature = image_feature[0]
                 if image_newline is not None:
                     image_feature = torch.cat((image_feature, image_newline[None].to(image_feature)), dim=0)
-                image_feature = image_feature.flatten(0, 1)
             new_image_features.append(image_feature)
             feature_lens.append(image_feature.size(0))
         feature_lens = torch.tensor(feature_lens, dtype=torch.long, device=image_features[0].device)
diff --git a/src/transformers/models/llava_onevision/modular_llava_onevision.py b/src/transformers/models/llava_onevision/modular_llava_onevision.py
index 14a8f39915f0..ef32f4b8625f 100644
--- a/src/transformers/models/llava_onevision/modular_llava_onevision.py
+++ b/src/transformers/models/llava_onevision/modular_llava_onevision.py
@@ -320,7 +320,6 @@ def pack_image_features(self, image_features, image_sizes, image_newline=None, v
                 image_feature = image_feature[0]
                 if image_newline is not None:
                     image_feature = torch.cat((image_feature, image_newline[None].to(image_feature)), dim=0)
-                image_feature = image_feature.flatten(0, 1)
             new_image_features.append(image_feature)
             feature_lens.append(image_feature.size(0))
         feature_lens = torch.tensor(feature_lens, dtype=torch.long, device=image_features[0].device)
diff --git a/src/transformers/models/llava_onevision/processing_llava_onevision.py b/src/transformers/models/llava_onevision/processing_llava_onevision.py
index 5a17455a5247..017b4887daab 100644
--- a/src/transformers/models/llava_onevision/processing_llava_onevision.py
+++ b/src/transformers/models/llava_onevision/processing_llava_onevision.py
@@ -214,15 +214,16 @@ def _expand_image_tokens(
         prompt_strings = []
         max_num_vision_tokens = 0
         for sample in text:
+            num_images = next(batch_num_images) # should consume iterable
             if special_token in sample:
-                is_multi_image = next(batch_num_images) != 1
+                is_multi_image = num_images != 1
             else:
                 is_multi_image = False
             while special_token in sample:
+                original_size = next(image_sizes) # should consume iterable
                 if is_multi_image:
                     num_image_tokens = self.num_image_tokens + 1  # one for image_newline
                 else:
-                    original_size = next(image_sizes)
                     if not isinstance(original_size, (list, tuple)):
                         # cast to list to avoid numerical precision errors when calculating unpadding
                         original_size = original_size.tolist()
diff --git a/tests/models/llava_onevision/test_modeling_llava_onevision.py b/tests/models/llava_onevision/test_modeling_llava_onevision.py
index 56fe8dcbb293..33532eea95c7 100644
--- a/tests/models/llava_onevision/test_modeling_llava_onevision.py
+++ b/tests/models/llava_onevision/test_modeling_llava_onevision.py
@@ -446,20 +446,25 @@ def test_small_model_integration_test_multi_image_nested(self):
 
         url = "https://www.ilankelman.org/stopsigns/australia.jpg"
         image = Image.open(requests.get(url, stream=True).raw)
-        prompt = (
-            "user\n\nWhat is the difference between these images?<|im_end|>\n<|im_start|>assistant\n"
-        )
-        images_nested = [[self.image, image]]
-        inputs = self.processor(text=prompt, images=images_nested, return_tensors="pt").to(torch_device, torch.float16)
+        prompts = [
+            self.prompt_image,
+            "user\n\nWhat is the difference between these images?<|im_end|>\n<|im_start|>assistant\n",
+            self.prompt_image,
+        ]
+        images_nested = [self.image, [image, self.image], self.image]
+        inputs = self.processor(
+            text=prompts,
+            images=images_nested,
+            return_tensors="pt",
+            padding=True,
+        ).to(torch_device, torch.float16)
 
         # verify generation
         output = model.generate(**inputs, max_new_tokens=40)
-        EXPECTED_DECODED_TEXT = "user\n\nWhat is the difference between these images?\nassistant\nThe first image is a radar chart showing the performance of different models in a specific task, while the second image is a street scene with a stop sign in the foreground."  # fmt: skip
+        EXPECTED_DECODED_TEXT = ["user\n\nWhat do you see in this image?\nassistant\nThe image is a radar chart that compares the performance of different models in a specific task, likely related to natural language processing or machine learning. The chart is divided into several axes, each representing a different", "user\n\nWhat is the difference between these images?\nassistant\nThe first image shows a stop sign with a traditional Chinese architectural background, while the second image displays a radar chart with various algorithms and models, including BLIP-2, InstructBLIP, Q", "user\n\nWhat do you see in this image?\nassistant\nThe image is a radar chart that compares the performance of different models in a specific task, likely related to natural language processing or machine learning. The chart is divided into several axes, each representing a different"] # fmt: skip
+        DECODED_TEXT = self.processor.batch_decode(output, skip_special_tokens=True)
 
-        self.assertEqual(
-            self.processor.decode(output[0], skip_special_tokens=True),
-            EXPECTED_DECODED_TEXT,
-        )
+        self.assertListEqual(DECODED_TEXT, EXPECTED_DECODED_TEXT)
 
     @slow
     @require_bitsandbytes

From 68bf594ae90191b3e3d1a2782aed264ec82fcef6 Mon Sep 17 00:00:00 2001
From: cyr0930
Date: Fri, 8 Aug 2025 05:58:22 +0000
Subject: [PATCH 2/3] style

---
 .../models/llava_onevision/processing_llava_onevision.py      | 4 ++--
 tests/models/llava_onevision/test_modeling_llava_onevision.py | 2 +-
 2 files changed, 3 insertions(+), 3 deletions(-)

diff --git a/src/transformers/models/llava_onevision/processing_llava_onevision.py b/src/transformers/models/llava_onevision/processing_llava_onevision.py
index 017b4887daab..bdfaf9328b30 100644
--- a/src/transformers/models/llava_onevision/processing_llava_onevision.py
+++ b/src/transformers/models/llava_onevision/processing_llava_onevision.py
@@ -214,13 +214,13 @@ def _expand_image_tokens(
         prompt_strings = []
         max_num_vision_tokens = 0
         for sample in text:
-            num_images = next(batch_num_images) # should consume iterable
+            num_images = next(batch_num_images)  # should consume iterable
             if special_token in sample:
                 is_multi_image = num_images != 1
             else:
                 is_multi_image = False
             while special_token in sample:
-                original_size = next(image_sizes) # should consume iterable
+                original_size = next(image_sizes)  # should consume iterable
                 if is_multi_image:
                     num_image_tokens = self.num_image_tokens + 1  # one for image_newline
                 else:
diff --git a/tests/models/llava_onevision/test_modeling_llava_onevision.py b/tests/models/llava_onevision/test_modeling_llava_onevision.py
index 33532eea95c7..86314716a679 100644
--- a/tests/models/llava_onevision/test_modeling_llava_onevision.py
+++ b/tests/models/llava_onevision/test_modeling_llava_onevision.py
@@ -461,7 +461,7 @@ def test_small_model_integration_test_multi_image_nested(self):
 
         # verify generation
         output = model.generate(**inputs, max_new_tokens=40)
-        EXPECTED_DECODED_TEXT = ["user\n\nWhat do you see in this image?\nassistant\nThe image is a radar chart that compares the performance of different models in a specific task, likely related to natural language processing or machine learning. The chart is divided into several axes, each representing a different", "user\n\nWhat is the difference between these images?\nassistant\nThe first image shows a stop sign with a traditional Chinese architectural background, while the second image displays a radar chart with various algorithms and models, including BLIP-2, InstructBLIP, Q", "user\n\nWhat do you see in this image?\nassistant\nThe image is a radar chart that compares the performance of different models in a specific task, likely related to natural language processing or machine learning. The chart is divided into several axes, each representing a different"] # fmt: skip
+        EXPECTED_DECODED_TEXT = ["user\n\nWhat do you see in this image?\nassistant\nThe image is a radar chart that compares the performance of different models in a specific task, likely related to natural language processing or machine learning. The chart is divided into several axes, each representing a different", "user\n\nWhat is the difference between these images?\nassistant\nThe first image shows a stop sign with a traditional Chinese architectural background, while the second image displays a radar chart with various algorithms and models, including BLIP-2, InstructBLIP, Q", "user\n\nWhat do you see in this image?\nassistant\nThe image is a radar chart that compares the performance of different models in a specific task, likely related to natural language processing or machine learning. The chart is divided into several axes, each representing a different"]  # fmt: skip
         DECODED_TEXT = self.processor.batch_decode(output, skip_special_tokens=True)
 
         self.assertListEqual(DECODED_TEXT, EXPECTED_DECODED_TEXT)

From 63d203d8ec323edcb98139b9ac711872607eb5ef Mon Sep 17 00:00:00 2001
From: cyr0930
Date: Mon, 11 Aug 2025 08:26:51 +0000
Subject: [PATCH 3/3] cannot pass inconsistent list & handle text-only case

---
 .../llava_onevision/image_processing_llava_onevision.py  | 7 +++----
 .../models/llava_onevision/processing_llava_onevision.py | 2 +-
 .../llava_onevision/test_modeling_llava_onevision.py     | 6 +++---
 3 files changed, 7 insertions(+), 8 deletions(-)

diff --git a/src/transformers/models/llava_onevision/image_processing_llava_onevision.py b/src/transformers/models/llava_onevision/image_processing_llava_onevision.py
index 40687456174f..8729c7444d29 100644
--- a/src/transformers/models/llava_onevision/image_processing_llava_onevision.py
+++ b/src/transformers/models/llava_onevision/image_processing_llava_onevision.py
@@ -680,10 +680,9 @@ def preprocess(
         do_pad = do_pad if do_pad is not None else self.do_pad
         do_convert_rgb = do_convert_rgb if do_convert_rgb is not None else self.do_convert_rgb
 
-        if isinstance(images, (tuple, list)) and any(isinstance(image, (tuple, list)) for image in images):
-            for i, image in enumerate(images):
-                if not isinstance(image, (tuple, list)):
-                    images[i] = [image]
+        if isinstance(images, (tuple, list)) and isinstance(images[0], (tuple, list)):
+            # if the first element is a list, we assume that all elements are lists
+            images = [x for x in images if x]  # handle text-only case
             batch_num_images = [len(x) for x in images]
         elif isinstance(images, (tuple, list)):
             # treat this as a single-image case for backward compatibility
diff --git a/src/transformers/models/llava_onevision/processing_llava_onevision.py b/src/transformers/models/llava_onevision/processing_llava_onevision.py
index bdfaf9328b30..7deadb9131b1 100644
--- a/src/transformers/models/llava_onevision/processing_llava_onevision.py
+++ b/src/transformers/models/llava_onevision/processing_llava_onevision.py
@@ -214,8 +214,8 @@ def _expand_image_tokens(
         prompt_strings = []
         max_num_vision_tokens = 0
         for sample in text:
-            num_images = next(batch_num_images)  # should consume iterable
             if special_token in sample:
+                num_images = next(batch_num_images)  # should consume iterable
                 is_multi_image = num_images != 1
             else:
                 is_multi_image = False
diff --git a/tests/models/llava_onevision/test_modeling_llava_onevision.py b/tests/models/llava_onevision/test_modeling_llava_onevision.py
index 86314716a679..9fd4845513cb 100644
--- a/tests/models/llava_onevision/test_modeling_llava_onevision.py
+++ b/tests/models/llava_onevision/test_modeling_llava_onevision.py
@@ -447,11 +447,11 @@ def test_small_model_integration_test_multi_image_nested(self):
         url = "https://www.ilankelman.org/stopsigns/australia.jpg"
         image = Image.open(requests.get(url, stream=True).raw)
         prompts = [
-            self.prompt_image,
+            "user\nTell me about the french revolution.<|im_end|>\n<|im_start|>assistant\n",  # text-only case
             "user\n\nWhat is the difference between these images?<|im_end|>\n<|im_start|>assistant\n",
             self.prompt_image,
         ]
-        images_nested = [self.image, [image, self.image], self.image]
+        images_nested = [[], [image, self.image], [self.image]]
         inputs = self.processor(
             text=prompts,
             images=images_nested,
@@ -461,7 +461,7 @@ def test_small_model_integration_test_multi_image_nested(self):
 
         # verify generation
         output = model.generate(**inputs, max_new_tokens=40)
-        EXPECTED_DECODED_TEXT = ["user\n\nWhat do you see in this image?\nassistant\nThe image is a radar chart that compares the performance of different models in a specific task, likely related to natural language processing or machine learning. The chart is divided into several axes, each representing a different", "user\n\nWhat is the difference between these images?\nassistant\nThe first image shows a stop sign with a traditional Chinese architectural background, while the second image displays a radar chart with various algorithms and models, including BLIP-2, InstructBLIP, Q", "user\n\nWhat do you see in this image?\nassistant\nThe image is a radar chart that compares the performance of different models in a specific task, likely related to natural language processing or machine learning. The chart is divided into several axes, each representing a different"]  # fmt: skip
+        EXPECTED_DECODED_TEXT = ["user\nTell me about the french revolution.\nassistant\nThe French Revolution! A pivotal event in modern history that had a profound impact on the course of Western civilization. Here's a brief overview:\n\n**Background**\n\nIn the late 18th century,", "user\n\nWhat is the difference between these images?\nassistant\nThe first image shows a stop sign with a traditional Chinese architectural background, while the second image displays a radar chart with various algorithms and models, including BLIP-2, InstructBLIP, Q", "user\n\nWhat do you see in this image?\nassistant\nThe image is a radar chart that compares the performance of different models in a specific task, likely related to natural language processing or machine learning. The chart is divided into several axes, each representing a different"]  # fmt: skip
         DECODED_TEXT = self.processor.batch_decode(output, skip_special_tokens=True)
 
         self.assertListEqual(DECODED_TEXT, EXPECTED_DECODED_TEXT)
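For reference, the end state of the series can be exercised with a short script: one prompt per batch entry, a nested image list per prompt, and an empty list marking a text-only entry. This is a minimal sketch, not part of the patches; the checkpoint name is the small OneVision checkpoint used in the documentation example, and the raw prompt strings are illustrative only (in practice `processor.apply_chat_template` builds them from a conversation).

```python
# Minimal sketch of batched inference with nested per-prompt image lists.
# Assumptions: checkpoint name from the docs example; prompt strings are illustrative.
import requests
import torch
from PIL import Image
from transformers import AutoProcessor, LlavaOnevisionForConditionalGeneration

model_id = "llava-hf/llava-onevision-qwen2-0.5b-ov-hf"
processor = AutoProcessor.from_pretrained(model_id)
processor.tokenizer.padding_side = "left"  # left-pad so generation continues from each prompt's end
model = LlavaOnevisionForConditionalGeneration.from_pretrained(
    model_id, torch_dtype=torch.float16, device_map="auto"
)

stop_sign = Image.open(requests.get("https://www.ilankelman.org/stopsigns/australia.jpg", stream=True).raw)
cats = Image.open(requests.get("http://images.cocodataset.org/val2017/000000039769.jpg", stream=True).raw)

# One prompt per batch entry; the number of <image> placeholders must match the
# number of images supplied for that entry.
prompts = [
    "<|im_start|>user Tell me about the french revolution.<|im_end|><|im_start|>assistant\n",
    "<|im_start|>user <image><image>\nWhat is the difference between these images?<|im_end|><|im_start|>assistant\n",
    "<|im_start|>user <image>\nWhat is shown in this image?<|im_end|><|im_start|>assistant\n",
]
# Nested per-prompt image lists; the empty list marks the text-only entry.
images = [[], [stop_sign, cats], [cats]]

inputs = processor(text=prompts, images=images, padding=True, return_tensors="pt").to(
    model.device, torch.float16
)
output = model.generate(**inputs, max_new_tokens=30)
print(processor.batch_decode(output, skip_special_tokens=True))
```

With patch 1 the nested list is routed through the image processor as one image list per prompt, and with patch 3 an empty entry is dropped on the vision side while its text-only prompt is left untouched.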