From e88cf2ed17034f340a520c9875e3047c5c0561b5 Mon Sep 17 00:00:00 2001 From: oleksost Date: Thu, 4 Dec 2025 19:40:28 +0000 Subject: [PATCH] fallback empty patch batch --- fast_llm/models/multimodal/model.py | 30 ++++++++++++++++++++++++++++- 1 file changed, 29 insertions(+), 1 deletion(-) diff --git a/fast_llm/models/multimodal/model.py b/fast_llm/models/multimodal/model.py index f8251e212..a5dd08306 100644 --- a/fast_llm/models/multimodal/model.py +++ b/fast_llm/models/multimodal/model.py @@ -5,6 +5,7 @@ from fast_llm.core.distributed import all_gather_scalar from fast_llm.data.sample.language_model import LanguageModelBatch +from fast_llm.data.sample.patch import PatchBatch from fast_llm.engine.config_utils.tensor_dim import TensorDim, scalar_dim from fast_llm.engine.distributed.config import DistributedDim, DistributedDimNames, PhaseType from fast_llm.engine.inference.runner import InferenceRunner @@ -150,6 +151,30 @@ def preprocess_meta( return preprocessed_meta + def _get_empty_image_patches(self, tokens: torch.Tensor, kwargs: dict[str, typing.Any]) -> PatchBatch: + patch_embeddings_config = self._config.vision_encoder.embeddings + sequence_first = kwargs[AttentionKwargs.sequence_first] + device = tokens.device + dtype = self._distributed.config.compute_dtype.torch + return PatchBatch( + patches=torch.empty( + ( + 0, + patch_embeddings_config.input_channels, + patch_embeddings_config.patch_height, + patch_embeddings_config.patch_width, + ), + device=device, + dtype=dtype, + ), + sample_map=torch.empty(0, device=device, dtype=torch.int32), + token_map=torch.empty(0, device=device, dtype=torch.int32), + positions=torch.empty((0, 2), device=device, dtype=torch.int32), + num_samples=tokens.shape[1] if sequence_first else tokens.shape[0], + sample_size=kwargs[AttentionKwargs.sequence_q_dim].size, + lengths=[], + ) + def preprocess_batch( self, batch: LanguageModelBatch, @@ -172,7 +197,10 @@ def preprocess_batch( # TODO: Handle earlier. tokens_end = kwargs[AttentionKwargs.sequence_k_dim].size tokens_begin = tokens_end - kwargs[AttentionKwargs.sequence_q_dim].size - cropped_image_patches = batch.image_patches.crop(tokens_begin, tokens_end) + if batch.image_patches is None: + cropped_image_patches = self._get_empty_image_patches(tokens, kwargs) + else: + cropped_image_patches = batch.image_patches.crop(tokens_begin, tokens_end) sequence_length = tokens.shape[:2].numel() pad_size = sequence_length - cropped_image_patches.patches.size(0)