Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions docs/source/en/model_doc/vitpose.md
Original file line number Diff line number Diff line change
Expand Up @@ -292,6 +292,12 @@ Refer to resources below to learn more about using ViTPose.
- preprocess
- post_process_pose_estimation

## VitPoseImageProcessorFast

[[autodoc]] VitPoseImageProcessorFast
- preprocess
- post_process_pose_estimation

## VitPoseConfig

[[autodoc]] VitPoseConfig
Expand Down
149 changes: 147 additions & 2 deletions src/transformers/image_processing_utils_fast.py
Original file line number Diff line number Diff line change
Expand Up @@ -497,9 +497,10 @@ def _process_image(

# Infer the channel dimension format if not provided
if input_data_format is None:
input_data_format = infer_channel_dimension_format(image)
input_data_format = self.infer_channel_dimension_format_fast(image)

if input_data_format == ChannelDimension.LAST:
# Only convert to channels_first if we need to and it's not already in that format
if input_data_format == ChannelDimension.LAST and image.shape[-1] in [1, 3, 4]:
# We force the channel dimension to be first for torch tensors as this is what torchvision expects.
image = image.permute(2, 0, 1).contiguous()

Expand Down Expand Up @@ -733,3 +734,147 @@ def to_dict(self):
encoder_dict.pop("_valid_processor_keys", None)
encoder_dict.pop("_valid_kwargs_names", None)
return encoder_dict

def to_channel_dimension_format_fast(
self,
image: "torch.Tensor",
channel_dim: Union[str, ChannelDimension],
input_channel_dim: Optional[Union[str, ChannelDimension]] = None,
) -> "torch.Tensor":
"""
Convert the image to the target channel dimension format using PyTorch operations.

Args:
image (`torch.Tensor`): Image tensor to convert.
channel_dim (`Union[str, ChannelDimension]`): Target channel dimension format.
input_channel_dim (`Union[str, ChannelDimension]`, *optional*): Input channel dimension format.

Returns:
`torch.Tensor`: Image with the target channel dimension format.
"""
if input_channel_dim is None:
input_channel_dim = infer_channel_dimension_format(image)

if input_channel_dim == channel_dim:
return image

if channel_dim == ChannelDimension.FIRST:
if image.shape[-1] == 3: # (H, W, C) -> (C, H, W)
return image.permute(2, 0, 1)
elif image.shape[0] == 3: # (C, H, W) - already correct
return image
else: # (H, C, W) -> (C, H, W)
return image.permute(1, 0, 2)
elif channel_dim == ChannelDimension.LAST:
if image.shape[0] == 3: # (C, H, W) -> (H, W, C)
return image.permute(1, 2, 0)
elif image.shape[-1] == 3: # (H, W, C) - already correct
return image
else: # (H, C, W) -> (H, W, C)
return image.permute(0, 2, 1)
else:
raise ValueError(f"Unsupported channel dimension: {channel_dim}")

def is_scaled_image_fast(self, image: "torch.Tensor") -> bool:
"""
Check if the image is already scaled (pixel values in [0, 1]) using PyTorch operations.

Args:
image (`torch.Tensor`): Image tensor to check.

Returns:
`bool`: True if the image is already scaled, False otherwise.
"""
if image.dtype == torch.float32 or image.dtype == torch.float64:
return image.min() >= 0.0 and image.max() <= 1.0
elif image.dtype == torch.uint8:
return False
else:
# For other dtypes, assume they're not scaled
return False

def valid_images_fast(self, images: list["torch.Tensor"]) -> bool:
"""
Check if all images in the list are valid PyTorch tensors.

Args:
images (`list[torch.Tensor]`): List of image tensors to validate.

Returns:
`bool`: True if all images are valid, False otherwise.
"""
if not images:
return False

for image in images:
if not torch.is_tensor(image):
return False
if image.ndim not in [2, 3]:
return False
if image.ndim == 3 and image.shape[0] not in [1, 3, 4] and image.shape[-1] not in [1, 3, 4]:
return False

return True

def make_list_of_images_fast(self, images: ImageInput) -> list["torch.Tensor"]:
"""
Convert various image inputs to a list of PyTorch tensors.

Args:
images (`ImageInput`): Images to convert.

Returns:
`list[torch.Tensor]`: List of PyTorch tensor images.
"""
if isinstance(images, (list, tuple)):
# Convert each image to tensor if needed
tensor_images = []
for img in images:
if torch.is_tensor(img):
tensor_images.append(img)
elif hasattr(img, "shape") and len(img.shape) == 3 and img.shape[-1] in [1, 3, 4]:
# For numpy arrays with channels_last format, convert directly to tensor
tensor_images.append(torch.from_numpy(img).contiguous())
else:
# Convert PIL, etc. to tensor
tensor_images.append(self._process_image(img))
return tensor_images
else:
# Single image
if torch.is_tensor(images):
return [images]
elif hasattr(images, "shape") and len(images.shape) == 3 and images.shape[-1] in [1, 3, 4]:
# For numpy arrays with channels_last format, convert directly to tensor
processed = torch.from_numpy(images).contiguous()
logger.debug(f"make_list_of_images_fast: input shape {images.shape}, output shape {processed.shape}")
return [processed]
else:
processed = self._process_image(images)
logger.debug(
f"make_list_of_images_fast: input shape {getattr(images, 'shape', 'N/A')}, output shape {processed.shape}"
)
return [processed]

def infer_channel_dimension_format_fast(self, image: "torch.Tensor") -> ChannelDimension:
"""
Infer the channel dimension format of a PyTorch tensor image.

Args:
image (`torch.Tensor`): Image tensor.

Returns:
`ChannelDimension`: The inferred channel dimension format.
"""
if image.ndim == 2:
return ChannelDimension.FIRST # Single channel image

if image.ndim == 3:
if image.shape[0] in [1, 3, 4]:
return ChannelDimension.FIRST # (C, H, W)
elif image.shape[-1] in [1, 3, 4]:
return ChannelDimension.LAST # (H, W, C)
else:
# Ambiguous case, default to first
return ChannelDimension.FIRST

raise ValueError(f"Unsupported image shape: {image.shape}")
1 change: 1 addition & 0 deletions src/transformers/models/auto/image_processing_auto.py
Original file line number Diff line number Diff line change
Expand Up @@ -187,6 +187,7 @@
("vit_mae", ("ViTImageProcessor", "ViTImageProcessorFast")),
("vit_msn", ("ViTImageProcessor", "ViTImageProcessorFast")),
("vitmatte", ("VitMatteImageProcessor", "VitMatteImageProcessorFast")),
("vitpose", ("VitPoseImageProcessor", "VitPoseImageProcessorFast")),
("xclip", ("CLIPImageProcessor", "CLIPImageProcessorFast")),
("yolos", ("YolosImageProcessor", "YolosImageProcessorFast")),
("zoedepth", ("ZoeDepthImageProcessor", "ZoeDepthImageProcessorFast")),
Expand Down
1 change: 1 addition & 0 deletions src/transformers/models/vitpose/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
if TYPE_CHECKING:
from .configuration_vitpose import *
from .image_processing_vitpose import *
from .image_processing_vitpose_fast import *
from .modeling_vitpose import *
else:
import sys
Expand Down
Loading