diff --git a/src/transformers/models/x_clip/modeling_x_clip.py b/src/transformers/models/x_clip/modeling_x_clip.py
index c0cbc7111f4b..13d4a1ab338b 100644
--- a/src/transformers/models/x_clip/modeling_x_clip.py
+++ b/src/transformers/models/x_clip/modeling_x_clip.py
@@ -422,7 +422,12 @@ class XCLIPPreTrainedModel(PreTrainedModel):
     config: XCLIPConfig
     base_model_prefix = "x_clip"
     input_modalities = ("image", "text")
-    _no_split_modules = ["XCLIPTextEmbeddings", "XCLIPEncoderLayer", "XCLIPVisionEmbeddings"]
+    _no_split_modules = [
+        "XCLIPTextEmbeddings",
+        "XCLIPEncoderLayer",
+        "XCLIPVisionEmbeddings",
+        "XCLIPVisionEncoderLayer",
+    ]
     supports_gradient_checkpointing = True
     _supports_sdpa = True
 
diff --git a/src/transformers/models/x_clip/modular_x_clip.py b/src/transformers/models/x_clip/modular_x_clip.py
index 9d76e97430d1..5980e8b68e07 100644
--- a/src/transformers/models/x_clip/modular_x_clip.py
+++ b/src/transformers/models/x_clip/modular_x_clip.py
@@ -173,6 +173,12 @@ def forward(
 class XCLIPPreTrainedModel(CLIPPreTrainedModel):
     config: XCLIPConfig
     base_model_prefix = "x_clip"
+    _no_split_modules = [
+        "XCLIPTextEmbeddings",
+        "XCLIPEncoderLayer",
+        "XCLIPVisionEmbeddings",
+        "XCLIPVisionEncoderLayer",
+    ]
     _can_record_outputs = {
         "hidden_states": [XCLIPEncoderLayer, XCLIPVisionEncoderLayer],
         "attentions": OutputRecorder(XCLIPAttention, layer_name="self_attn", index=1),
diff --git a/src/transformers/models/x_clip/processing_x_clip.py b/src/transformers/models/x_clip/processing_x_clip.py
index d6b9fcf32736..57ed01f99506 100644
--- a/src/transformers/models/x_clip/processing_x_clip.py
+++ b/src/transformers/models/x_clip/processing_x_clip.py
@@ -25,5 +25,12 @@ def __init__(self, image_processor=None, tokenizer=None, **kwargs):
         super().__init__(image_processor, tokenizer)
         self.video_processor = self.image_processor
 
+    def __call__(self, images=None, text=None, videos=None, **kwargs):
+        # X-CLIP uses the image_processor for video frames. Map videos to images
+        # so the base class processes them through image_processor.
+        if videos is not None and images is None:
+            images = videos
+        return super().__call__(images=images, text=text, **kwargs)
+
 
 __all__ = ["XCLIPProcessor"]
diff --git a/tests/models/x_clip/test_modeling_x_clip.py b/tests/models/x_clip/test_modeling_x_clip.py
index 539ab98a479b..997736901f3a 100644
--- a/tests/models/x_clip/test_modeling_x_clip.py
+++ b/tests/models/x_clip/test_modeling_x_clip.py
@@ -172,6 +172,30 @@ def test_eager_matches_sdpa_inference(
     ):
         pass
 
+    @unittest.skip(reason="X-CLIP needs batch size to match frames, can't crop and create new dummy inputs")
+    def test_flash_attn_2_inference_equivalence(self):
+        pass
+
+    @unittest.skip(reason="X-CLIP needs batch size to match frames, can't crop and create new dummy inputs")
+    def test_flash_attn_2_inference_equivalence_right_padding(self):
+        pass
+
+    @unittest.skip(reason="X-CLIP needs batch size to match frames, can't crop and create new dummy inputs")
+    def test_flash_attn_3_inference_equivalence(self):
+        pass
+
+    @unittest.skip(reason="X-CLIP needs batch size to match frames, can't crop and create new dummy inputs")
+    def test_flash_attn_3_inference_equivalence_right_padding(self):
+        pass
+
+    @unittest.skip(reason="X-CLIP needs batch size to match frames, can't crop and create new dummy inputs")
+    def test_flash_attn_4_inference_equivalence(self):
+        pass
+
+    @unittest.skip(reason="X-CLIP needs batch size to match frames, can't crop and create new dummy inputs")
+    def test_flash_attn_4_inference_equivalence_right_padding(self):
+        pass
+
     def test_model_get_set_embeddings(self):
         config, _ = self.model_tester.prepare_config_and_inputs_for_common()
 
@@ -561,6 +585,10 @@ def test_model_get_set_embeddings(self):
     def test_feed_forward_chunking(self):
         pass
 
+    @unittest.skip(reason="Does not work on the tiny model as we keep hitting edge cases.")
+    def test_model_parallelism(self):
+        pass
+
     def test_load_vision_text_config(self):
         config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
 
diff --git a/tests/test_modeling_common.py b/tests/test_modeling_common.py
index ae502023e7d9..909ddc5bc055 100644
--- a/tests/test_modeling_common.py
+++ b/tests/test_modeling_common.py
@@ -3361,6 +3361,8 @@ def _get_output_logits(outputs):
             return outputs.decoder_hidden_states[-1]
         elif "logits_per_image" in outputs:
             return outputs.logits_per_image
+        elif "logits_per_video" in outputs:
+            return outputs.logits_per_video
         else:
             return outputs.logits
 
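A minimal usage sketch of the processing_x_clip.py change above (not part of the diff): with the new __call__ override applied, frames passed via videos= are forwarded to the shared image_processor exactly as if they had been passed via images=. The checkpoint name is the public X-CLIP base checkpoint; the random frames are illustrative dummies.

    import numpy as np
    from transformers import XCLIPProcessor

    processor = XCLIPProcessor.from_pretrained("microsoft/xclip-base-patch32")

    # Eight 224x224 RGB frames, matching what the base checkpoint expects.
    video = [np.random.randint(0, 256, (224, 224, 3), dtype=np.uint8) for _ in range(8)]

    # videos= is mapped onto images= internally, so both routes go through
    # the same image_processor and yield the same pixel_values.
    inputs = processor(
        text=["playing sports", "reading a book"],
        videos=video,
        return_tensors="pt",
        padding=True,
    )
    print(sorted(inputs.keys()))  # ['attention_mask', 'input_ids', 'pixel_values']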