Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 0 additions & 1 deletion src/transformers/models/glm4v/configuration_glm4v.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,6 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from ...configuration_utils import PretrainedConfig
from ...modeling_rope_utils import rope_config_validation

Expand Down
4 changes: 2 additions & 2 deletions src/transformers/models/glm4v/image_processing_glm4v.py
Original file line number Diff line number Diff line change
Expand Up @@ -454,11 +454,11 @@ def get_number_of_image_patches(self, height: int, width: int, images_kwargs=Non

factor = patch_size * merge_size
resized_height, resized_width = smart_resize(
t=self.temporal_patch_size,
num_frames=self.temporal_patch_size,
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nice 👍

height=height,
width=width,
factor=factor,
t_factor=self.temporal_patch_size,
temporal_factor=self.temporal_patch_size,
)
grid_h, grid_w = resized_height // patch_size, resized_width // patch_size
return grid_h * grid_w
Expand Down
2 changes: 1 addition & 1 deletion src/transformers/models/glm4v/modeling_glm4v.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,6 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import itertools
from dataclasses import dataclass
from typing import Any, Callable, Optional, Union
Expand Down Expand Up @@ -769,6 +768,7 @@ def forward(
output_attentions=output_attentions,
use_cache=use_cache,
cache_position=cache_position,
**kwargs,
)

hidden_states = self.post_self_attn_layernorm(hidden_states)
Expand Down
10 changes: 9 additions & 1 deletion src/transformers/models/glm4v/modular_glm4v.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,10 +12,10 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import itertools
from typing import Callable, Optional, Union

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
Expand Down Expand Up @@ -838,6 +838,7 @@ def forward(
output_attentions=output_attentions,
use_cache=use_cache,
cache_position=cache_position,
**kwargs,
)

hidden_states = self.post_self_attn_layernorm(hidden_states)
Expand Down Expand Up @@ -1582,6 +1583,7 @@ class Glm4vProcessorKwargs(Qwen2_5_VLProcessorKwargs):
_defaults = {
"text_kwargs": {
"padding": False,
"return_mm_token_type_ids": False,
},
}

Expand Down Expand Up @@ -1723,9 +1725,15 @@ def __call__(

text[i] = text[i].replace("<|placeholder|>", self.image_token)
return_tensors = output_kwargs["text_kwargs"].pop("return_tensors", None)
return_mm_token_type_ids = output_kwargs["text_kwargs"].pop("return_mm_token_type_ids", False)
text_inputs = self.tokenizer(text, **output_kwargs["text_kwargs"])
self._check_special_mm_tokens(text, text_inputs, modalities=["image", "video"])

if return_mm_token_type_ids:
array_ids = np.array(text_inputs["input_ids"])
mm_token_type_ids = np.zeros_like(text_inputs["input_ids"])
mm_token_type_ids[array_ids == self.image_token_id] = 1
text_inputs["mm_token_type_ids"] = mm_token_type_ids.tolist()
return BatchFeature(data={**text_inputs, **image_inputs, **videos_inputs}, tensor_type=return_tensors)


Expand Down
10 changes: 9 additions & 1 deletion src/transformers/models/glm4v/processing_glm4v.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,9 +18,10 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Optional, Union

import numpy as np

from ...feature_extraction_utils import BatchFeature
from ...image_utils import ImageInput
from ...processing_utils import ImagesKwargs, MultiModalData, ProcessingKwargs, ProcessorMixin, Unpack, VideosKwargs
Expand All @@ -44,6 +45,7 @@ class Glm4vProcessorKwargs(ProcessingKwargs, total=False):
_defaults = {
"text_kwargs": {
"padding": False,
"return_mm_token_type_ids": False,
},
}

Expand Down Expand Up @@ -200,9 +202,15 @@ def __call__(

text[i] = text[i].replace("<|placeholder|>", self.image_token)
return_tensors = output_kwargs["text_kwargs"].pop("return_tensors", None)
return_mm_token_type_ids = output_kwargs["text_kwargs"].pop("return_mm_token_type_ids", False)
text_inputs = self.tokenizer(text, **output_kwargs["text_kwargs"])
self._check_special_mm_tokens(text, text_inputs, modalities=["image", "video"])

if return_mm_token_type_ids:
array_ids = np.array(text_inputs["input_ids"])
mm_token_type_ids = np.zeros_like(text_inputs["input_ids"])
mm_token_type_ids[array_ids == self.image_token_id] = 1
text_inputs["mm_token_type_ids"] = mm_token_type_ids.tolist()
return BatchFeature(data={**text_inputs, **image_inputs, **videos_inputs}, tensor_type=return_tensors)

def _get_num_multimodal_tokens(self, image_sizes=None, video_sizes=None, **kwargs):
Expand Down
61 changes: 58 additions & 3 deletions src/transformers/models/perception_lm/processing_perception_lm.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,9 +17,11 @@

from typing import Iterable, Union

import numpy as np

from ...feature_extraction_utils import BatchFeature
from ...image_utils import ImageInput, get_image_size, to_numpy_array
from ...processing_utils import ProcessingKwargs, ProcessorMixin, Unpack
from ...processing_utils import MultiModalData, ProcessingKwargs, ProcessorMixin, Unpack
from ...tokenization_utils_base import PreTokenizedInput, TextInput
from ...utils import logging
from ...video_utils import VideoInput
Expand All @@ -32,6 +34,7 @@ class PerceptionLMProcessorKwargs(ProcessingKwargs, total=False):
_defaults = {
"text_kwargs": {
"padding": False,
"return_mm_token_type_ids": False,
},
}

Expand Down Expand Up @@ -157,9 +160,17 @@ def __call__(
prompt_strings.append(sample)

return_tensors = output_kwargs["text_kwargs"].pop("return_tensors", None)
text_inputs = self.tokenizer(prompt_strings, **output_kwargs["text_kwargs"])
return_mm_token_type_ids = output_kwargs["text_kwargs"].pop("return_mm_token_type_ids", False)
text_inputs = self.tokenizer(prompt_strings, **output_kwargs["text_kwargs"], return_tensors=None)
self._check_special_mm_tokens(prompt_strings, text_inputs, modalities=["image", "video"])
return BatchFeature(data={**text_inputs, **image_inputs, **videos_inputs}, tensor_type=return_tensors)

if return_mm_token_type_ids:
array_ids = np.array(text_inputs["input_ids"])
mm_token_type_ids = np.zeros_like(text_inputs["input_ids"])
mm_token_type_ids[array_ids == self.image_token_id] = 1
text_inputs["mm_token_type_ids"] = mm_token_type_ids.tolist()

return BatchFeature(data={**text_inputs, **image_inputs}, tensor_type=return_tensors)

def _expand_media_tokens(self, sample, media_token: str, media_iter: Iterable):
media_count = sample.count(media_token)
Expand All @@ -183,6 +194,50 @@ def _expand_media_tokens(self, sample, media_token: str, media_iter: Iterable):
sample += sample_splits[-1]
return sample

def _get_num_multimodal_tokens(self, image_sizes=None, **kwargs):
    """
    Computes the number of placeholder tokens needed for multimodal inputs with the given sizes.

    Args:
        image_sizes (`list[list[int]]`, *optional*):
            The input sizes formatted as (height, width) per each image.

    Returns:
        `MultiModalData`: A `MultiModalData` object holding number of tokens per each of the provided
        input modalities, along with other useful data.
    """
    vision_data = {}
    if image_sizes is not None:
        # Merge into a fresh dict instead of calling .update() on the value returned by
        # _defaults.get(...): when "images_kwargs" exists in _defaults, .get() returns the
        # shared class-level dict, and mutating it would leak caller kwargs into every
        # subsequent call of this method.
        images_kwargs = {**PerceptionLMProcessorKwargs._defaults.get("images_kwargs", {}), **kwargs}
        tile_size = images_kwargs.get("tile_size", None) or self.image_processor.tile_size

        # Tokens contributed by a single tile; loop-invariant, so hoist it out of the loop.
        tokens_per_tile = (tile_size // self.patch_size // self.pooling_ratio) * (
            tile_size // self.patch_size // self.pooling_ratio
        )

        num_image_tokens = []
        num_image_patches = []
        for height, width in image_sizes:
            if self.image_processor.vision_input_type == "thumb+tile":
                # Prefer an exact canvas fit; fall back to the closest supported aspect ratio.
                aspect_ratio = self.image_processor._fit_image_to_canvas(
                    img_width=width, img_height=height, tile_size=tile_size
                )
                if aspect_ratio is None:
                    aspect_ratio = self.image_processor._find_closest_aspect_ratio(
                        img_width=width, img_height=height, tile_size=tile_size
                    )
                num_tiles = aspect_ratio[0] * aspect_ratio[1] + 1  # base (thumbnail) image and tiles
            else:
                num_tiles = 1

            num_image_tokens.append(tokens_per_tile * num_tiles)
            num_image_patches.append(num_tiles)

        vision_data.update({"num_image_tokens": num_image_tokens, "num_image_patches": num_image_patches})
    return MultiModalData(**vision_data)

def batch_decode(self, *args, **kwargs):
"""
This method forwards all its arguments to PerceptionLMTokenizerFast's [`~PreTrainedTokenizer.batch_decode`]. Please
Expand Down