30 changes: 29 additions & 1 deletion fast_llm/models/multimodal/model.py
@@ -5,6 +5,7 @@

from fast_llm.core.distributed import all_gather_scalar
from fast_llm.data.sample.language_model import LanguageModelBatch
from fast_llm.data.sample.patch import PatchBatch
from fast_llm.engine.config_utils.tensor_dim import TensorDim, scalar_dim
from fast_llm.engine.distributed.config import DistributedDim, DistributedDimNames, PhaseType
from fast_llm.engine.inference.runner import InferenceRunner
@@ -150,6 +151,30 @@ def preprocess_meta(

return preprocessed_meta

def _get_empty_image_patches(self, tokens: torch.Tensor, kwargs: dict[str, typing.Any]) -> PatchBatch:
Collaborator

This should probably go in preprocessing/image_patch. Also, it's very similar to ImagePatchConfig.get_patches_from_images; maybe it can be reused.
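
For illustration, a minimal sketch of the reuse being suggested here, assuming a hypothetical get_empty_patches method on ImagePatchConfig; the field names (input_channels, patch_height, patch_width) are taken from the diff below, and the exact Fast-LLM API is not confirmed:

# Hypothetical sketch, not the actual Fast-LLM API: the empty-batch
# construction moved next to ImagePatchConfig.get_patches_from_images.
import torch

from fast_llm.data.sample.patch import PatchBatch


class ImagePatchConfig:
    ...

    def get_empty_patches(
        self, num_samples: int, sample_size: int, device: torch.device, dtype: torch.dtype
    ) -> PatchBatch:
        # Zero patches, but the metadata still describes the token batch, so
        # downstream padding logic behaves the same as the non-empty path.
        return PatchBatch(
            patches=torch.empty(
                (0, self.input_channels, self.patch_height, self.patch_width),
                device=device,
                dtype=dtype,
            ),
            sample_map=torch.empty(0, device=device, dtype=torch.int32),
            token_map=torch.empty(0, device=device, dtype=torch.int32),
            positions=torch.empty((0, 2), device=device, dtype=torch.int32),
            num_samples=num_samples,
            sample_size=sample_size,
            lengths=[],
        )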

Contributor Author

Hey, thanks @jlamypoirier! Yes, this is about using test-only data for the multimodal model. Are you planning to address it in #402? I am fine with creating those in the dataset instead of the model.

Collaborator

I'll add it myself; it's not much effort and will affect the PR.

patch_embeddings_config = self._config.vision_encoder.embeddings
sequence_first = kwargs[AttentionKwargs.sequence_first]
device = tokens.device
dtype = self._distributed.config.compute_dtype.torch
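        # No image patches in this batch: build an empty PatchBatch whose
        # metadata (num_samples, sample_size) still matches the token batch.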
return PatchBatch(
patches=torch.empty(
(
0,
patch_embeddings_config.input_channels,
patch_embeddings_config.patch_height,
patch_embeddings_config.patch_width,
),
device=device,
dtype=dtype,
),
sample_map=torch.empty(0, device=device, dtype=torch.int32),
token_map=torch.empty(0, device=device, dtype=torch.int32),
positions=torch.empty((0, 2), device=device, dtype=torch.int32),
num_samples=tokens.shape[1] if sequence_first else tokens.shape[0],
sample_size=kwargs[AttentionKwargs.sequence_q_dim].size,
lengths=[],
)

def preprocess_batch(
self,
batch: LanguageModelBatch,
@@ -172,7 +197,10 @@ def preprocess_batch(
# TODO: Handle earlier.
tokens_end = kwargs[AttentionKwargs.sequence_k_dim].size
tokens_begin = tokens_end - kwargs[AttentionKwargs.sequence_q_dim].size
cropped_image_patches = batch.image_patches.crop(tokens_begin, tokens_end)
if batch.image_patches is None:
cropped_image_patches = self._get_empty_image_patches(tokens, kwargs)
else:
cropped_image_patches = batch.image_patches.crop(tokens_begin, tokens_end)

sequence_length = tokens.shape[:2].numel()
pad_size = sequence_length - cropped_image_patches.patches.size(0)
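
For context, a small illustration of the arithmetic in this hunk, using made-up shapes; when the batch carries no images, the empty PatchBatch contributes zero patches, so the entire sequence becomes padding:

import torch

tokens = torch.zeros((4, 512), dtype=torch.int64)  # hypothetical (batch, sequence) shape
sequence_length = tokens.shape[:2].numel()  # 4 * 512 = 2048
num_patches = 0  # an empty PatchBatch has patches.size(0) == 0
pad_size = sequence_length - num_patches  # all 2048 patch slots are padding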