Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
57 changes: 50 additions & 7 deletions test/test_transforms_v2.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
import torchvision.transforms.v2 as transforms

from common_utils import (
assert_close,
assert_equal,
cache,
cpu_and_cuda,
Expand All @@ -42,7 +43,6 @@
)

from torch import nn
from torch.testing import assert_close
from torch.utils._pytree import tree_flatten, tree_map
from torch.utils.data import DataLoader, default_collate
from torchvision import tv_tensors
Expand Down Expand Up @@ -1504,6 +1504,7 @@ def test_kernel_video(self):
make_segmentation_mask,
make_video,
make_keypoints,
pytest.param(make_image_cvcuda, marks=pytest.mark.needs_cvcuda),
],
)
def test_functional(self, make_input):
Expand All @@ -1519,9 +1520,16 @@ def test_functional(self, make_input):
(F.affine_mask, tv_tensors.Mask),
(F.affine_video, tv_tensors.Video),
(F.affine_keypoints, tv_tensors.KeyPoints),
pytest.param(
F._geometry._affine_image_cvcuda,
None,
marks=pytest.mark.needs_cvcuda,
),
],
)
def test_functional_signature(self, kernel, input_type):
    """Check that ``kernel`` has a signature compatible with the ``F.affine`` functional."""
    # The CV-CUDA kernel is parametrized with ``None`` as its input type; the
    # actual tensor class is looked up here through ``_import_cvcuda()``.
    expected_type = _import_cvcuda().Tensor if kernel is F._geometry._affine_image_cvcuda else input_type
    check_functional_kernel_signature_match(F.affine, kernel=kernel, input_type=expected_type)

@pytest.mark.parametrize(
Expand All @@ -1534,6 +1542,7 @@ def test_functional_signature(self, kernel, input_type):
make_segmentation_mask,
make_video,
make_keypoints,
pytest.param(make_image_cvcuda, marks=pytest.mark.needs_cvcuda),
],
)
@pytest.mark.parametrize("device", cpu_and_cuda())
Expand All @@ -1551,8 +1560,17 @@ def test_transform(self, make_input, device):
"interpolation", [transforms.InterpolationMode.NEAREST, transforms.InterpolationMode.BILINEAR]
)
@pytest.mark.parametrize("fill", CORRECTNESS_FILLS)
def test_functional_image_correctness(self, angle, translate, scale, shear, center, interpolation, fill):
image = make_image(dtype=torch.uint8, device="cpu")
@pytest.mark.parametrize(
"make_input",
[
make_image,
pytest.param(make_image_cvcuda, marks=pytest.mark.needs_cvcuda),
],
)
def test_functional_image_correctness(
self, angle, translate, scale, shear, center, interpolation, fill, make_input
):
image = make_input(dtype=torch.uint8, device="cpu")

fill = adapt_fill(fill, dtype=torch.uint8)

Expand All @@ -1566,6 +1584,11 @@ def test_functional_image_correctness(self, angle, translate, scale, shear, cent
interpolation=interpolation,
fill=fill,
)

if make_input is make_image_cvcuda:
actual = F.cvcuda_to_tensor(actual)[0].cpu()
image = F.cvcuda_to_tensor(image)[0].cpu()

expected = F.to_image(
F.affine(
F.to_pil_image(image),
Expand All @@ -1580,16 +1603,27 @@ def test_functional_image_correctness(self, angle, translate, scale, shear, cent
)

mae = (actual.float() - expected.float()).abs().mean()
assert mae < 2 if interpolation is transforms.InterpolationMode.NEAREST else 8
if make_input is make_image_cvcuda:
    # CV-CUDA nearest interpolation does not follow the same algorithm as PIL/torch,
    # so the NEAREST bound is deliberately loose.
    threshold = 255 if interpolation is transforms.InterpolationMode.NEAREST else 1
else:
    threshold = 2 if interpolation is transforms.InterpolationMode.NEAREST else 8
# Select the bound first, then compare. `assert mae < a if cond else b` parses as
# `assert ((mae < a) if cond else b)`, which asserts a truthy constant on the else
# branch and therefore never fails for non-NEAREST modes.
# NOTE(review): the non-NEAREST bounds were not enforced before this fix - confirm they hold.
assert mae < threshold, f"mae: {mae}"
Comment on lines +1606 to +1610
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

For enum comparisons, using "==" is preferred.
Also see comment here

The threshold values might need more investigation. @NicolasHug


@pytest.mark.parametrize("center", _CORRECTNESS_AFFINE_KWARGS["center"])
@pytest.mark.parametrize(
"interpolation", [transforms.InterpolationMode.NEAREST, transforms.InterpolationMode.BILINEAR]
)
@pytest.mark.parametrize("fill", CORRECTNESS_FILLS)
@pytest.mark.parametrize("seed", list(range(5)))
def test_transform_image_correctness(self, center, interpolation, fill, seed):
image = make_image(dtype=torch.uint8, device="cpu")
@pytest.mark.parametrize(
"make_input",
[
make_image,
pytest.param(make_image_cvcuda, marks=pytest.mark.needs_cvcuda),
],
)
def test_transform_image_correctness(self, center, interpolation, fill, seed, make_input):
image = make_input(dtype=torch.uint8, device="cpu")

fill = adapt_fill(fill, dtype=torch.uint8)

Expand All @@ -1600,11 +1634,20 @@ def test_transform_image_correctness(self, center, interpolation, fill, seed):
torch.manual_seed(seed)
actual = transform(image)

if make_input is make_image_cvcuda:
actual = F.cvcuda_to_tensor(actual)[0].cpu()
image = F.cvcuda_to_tensor(image)[0].cpu()

torch.manual_seed(seed)
expected = F.to_image(transform(F.to_pil_image(image)))

mae = (actual.float() - expected.float()).abs().mean()
assert mae < 2 if interpolation is transforms.InterpolationMode.NEAREST else 8
mae = (actual.float() - expected.float()).abs().mean()
if make_input is make_image_cvcuda:
    # CV-CUDA nearest interpolation does not follow the same algorithm as PIL/torch,
    # so the NEAREST bound is deliberately loose.
    threshold = 255 if interpolation is transforms.InterpolationMode.NEAREST else 1
else:
    threshold = 2 if interpolation is transforms.InterpolationMode.NEAREST else 8
# Select the bound first, then compare. `assert mae < a if cond else b` parses as
# `assert ((mae < a) if cond else b)`, which asserts a truthy constant on the else
# branch and therefore never fails for non-NEAREST modes.
# NOTE(review): the non-NEAREST bounds were not enforced before this fix - confirm they hold.
assert mae < threshold, f"mae: {mae}"

def _compute_affine_matrix(self, *, angle, translate, scale, shear, center):
rot = math.radians(angle)
Expand Down
2 changes: 2 additions & 0 deletions torchvision/transforms/v2/_geometry.py
Original file line number Diff line number Diff line change
Expand Up @@ -686,6 +686,8 @@ class RandomAffine(Transform):

_v1_transform_cls = _transforms.RandomAffine

_transformed_types = Transform._transformed_types + (_is_cvcuda_tensor,)

def __init__(
self,
degrees: Union[numbers.Number, Sequence],
Expand Down
5 changes: 3 additions & 2 deletions torchvision/transforms/v2/_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@

from torchvision.transforms.transforms import _check_sequence_input, _setup_angle, _setup_size # noqa: F401
from torchvision.transforms.v2.functional import get_dimensions, get_size, is_pure_tensor
from torchvision.transforms.v2.functional._utils import _FillType, _FillTypeJIT
from torchvision.transforms.v2.functional._utils import _FillType, _FillTypeJIT, _is_cvcuda_tensor


def _setup_number_or_seq(arg: int | float | Sequence[int | float], name: str) -> Sequence[float]:
Expand Down Expand Up @@ -182,7 +182,7 @@ def query_chw(flat_inputs: list[Any]) -> tuple[int, int, int]:
chws = {
tuple(get_dimensions(inpt))
for inpt in flat_inputs
if check_type(inpt, (is_pure_tensor, tv_tensors.Image, PIL.Image.Image, tv_tensors.Video))
if check_type(inpt, (is_pure_tensor, tv_tensors.Image, PIL.Image.Image, tv_tensors.Video, _is_cvcuda_tensor))
}
if not chws:
raise TypeError("No image or video was found in the sample")
Expand All @@ -207,6 +207,7 @@ def query_size(flat_inputs: list[Any]) -> tuple[int, int]:
tv_tensors.Mask,
tv_tensors.BoundingBoxes,
tv_tensors.KeyPoints,
_is_cvcuda_tensor,
),
)
}
Expand Down
55 changes: 55 additions & 0 deletions torchvision/transforms/v2/functional/_geometry.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from collections.abc import Sequence
from typing import Any, Optional, TYPE_CHECKING, Union

import numpy as np
import PIL.Image
import torch
from torch.nn.functional import grid_sample, interpolate, pad as torch_pad
Expand All @@ -28,6 +29,7 @@

from ._utils import (
_FillTypeJIT,
_get_cvcuda_interp,
_get_kernel,
_import_cvcuda,
_is_cvcuda_available,
Expand Down Expand Up @@ -1331,6 +1333,59 @@ def affine_video(
)


def _affine_image_cvcuda(
    image: "cvcuda.Tensor",
    angle: Union[int, float],
    translate: list[float],
    scale: float,
    shear: list[float],
    interpolation: Union[InterpolationMode, int] = InterpolationMode.NEAREST,
    fill: _FillTypeJIT = None,
    center: Optional[list[float]] = None,
) -> "cvcuda.Tensor":
    """Apply an affine transformation to a CV-CUDA tensor using ``cvcuda.warp_affine``.

    Args:
        image: Input CV-CUDA tensor. ``image.shape[1:]`` is unpacked as
            ``(height, width, num_channels)``, so a batch-first, channels-last
            layout is presumably expected - TODO confirm against callers.
        angle: Rotation angle, validated by ``_affine_parse_args``.
        translate: Horizontal and vertical translation.
        scale: Overall scale factor.
        shear: Shear angles, validated by ``_affine_parse_args``.
        interpolation: Interpolation mode; checked by ``_check_interpolation`` and
            mapped to a CV-CUDA value by ``_get_cvcuda_interp``.
        fill: Constant border value. ``None`` fills with zeros; a scalar or a
            single-element sequence is replicated over all channels; a longer
            sequence is truncated to ``num_channels`` entries.
        center: Transformation center in pixel coordinates. Defaults to
            ``(width / 2, height / 2)``.

    Returns:
        The tensor returned by ``cvcuda.warp_affine``.
    """
    cvcuda = _import_cvcuda()

    interpolation = _check_interpolation(interpolation)
    angle, translate, shear, center = _affine_parse_args(angle, translate, scale, shear, interpolation, center)

    height, width, num_channels = image.shape[1:]

    # torchvision rotates around the image center by default, whereas CV-CUDA
    # transforms around the upper-left corner (0, 0) in absolute pixel coordinates.
    # The actual center is therefore handed to _get_inverse_affine_matrix directly
    # instead of the normalized, center-relative offset the tensor kernel uses.
    if center is None:
        cx, cy = width / 2.0, height / 2.0
    else:
        cx, cy = float(center[0]), float(center[1])

    translate_f = [float(t) for t in translate]
    matrix = _get_inverse_affine_matrix([cx, cy], angle, translate_f, scale, shear)

    interp = _get_cvcuda_interp(interpolation)

    # The first six coefficients, laid out row-major, form the 2x3 affine matrix.
    xform = np.array(matrix, dtype=np.float32).reshape(2, 3)

    if fill is None:
        border_value = np.zeros(num_channels, dtype=np.float32)
    else:
        border_value = np.asarray(fill, dtype=np.float32).reshape(-1)
        if border_value.size == 1:
            # A scalar and a one-element sequence mean the same thing: use that
            # value for every channel rather than only the first one.
            border_value = np.full(num_channels, border_value[0], dtype=np.float32)
        else:
            border_value = border_value[:num_channels]

    # The matrix above is already the inverse (output -> input) mapping, so tell
    # CV-CUDA not to invert it again.
    return cvcuda.warp_affine(
        image,
        xform,
        flags=interp | cvcuda.Interp.WARP_INVERSE_MAP,
        border_mode=cvcuda.Border.CONSTANT,
        border_value=border_value,
    )


if CVCUDA_AVAILABLE:
_register_kernel_internal(affine, _import_cvcuda().Tensor)(_affine_image_cvcuda)


def rotate(
inpt: torch.Tensor,
angle: float,
Expand Down
40 changes: 39 additions & 1 deletion torchvision/transforms/v2/functional/_utils.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,13 @@
import functools
from collections.abc import Sequence
from typing import Any, Callable, Optional, Union
from typing import Any, Callable, Optional, TYPE_CHECKING, Union

import torch
from torchvision import tv_tensors
from torchvision.transforms.functional import InterpolationMode

if TYPE_CHECKING:
import cvcuda # type: ignore[import-not-found]

_FillType = Union[int, float, Sequence[int], Sequence[float], None]
_FillTypeJIT = Optional[list[float]]
Expand Down Expand Up @@ -177,3 +181,37 @@ def _is_cvcuda_tensor(inpt: Any) -> bool:
return isinstance(inpt, cvcuda.Tensor)
except ImportError:
return False


_interpolation_mode_to_cvcuda_interp: dict[InterpolationMode | str | int, "cvcuda.Interp"] = {}


def _get_cvcuda_interp(interpolation: InterpolationMode | str | int) -> "cvcuda.Interp":
    """Translate a torchvision interpolation spec into the matching ``cvcuda.Interp`` value.

    The lookup table is built on first use, because ``cvcuda`` is only imported
    through ``_import_cvcuda()``, and is cached in the module-level
    ``_interpolation_mode_to_cvcuda_interp`` dict.

    Args:
        interpolation: An ``InterpolationMode`` member, its string name
            (e.g. ``"bilinear"``), or one of the integer codes ``0``-``5``.

    Returns:
        The corresponding ``cvcuda.Interp`` value.

    Raises:
        ValueError: If ``interpolation`` has no entry in the lookup table.
    """
    if not _interpolation_mode_to_cvcuda_interp:
        modes = _import_cvcuda().Interp
        # Build the full table locally and publish it in a single update. If any
        # attribute lookup above or below raises, the shared cache stays empty and
        # the next call retries, instead of leaving a partially filled table that
        # would report valid modes as unsupported.
        table = {
            InterpolationMode.NEAREST: modes.NEAREST,
            InterpolationMode.NEAREST_EXACT: modes.NEAREST,
            InterpolationMode.BILINEAR: modes.LINEAR,
            InterpolationMode.BICUBIC: modes.CUBIC,
            InterpolationMode.BOX: modes.BOX,
            InterpolationMode.HAMMING: modes.HAMMING,
            InterpolationMode.LANCZOS: modes.LANCZOS,
            "nearest": modes.NEAREST,
            "nearest-exact": modes.NEAREST,
            "bilinear": modes.LINEAR,
            "bicubic": modes.CUBIC,
            "box": modes.BOX,
            "hamming": modes.HAMMING,
            "lanczos": modes.LANCZOS,
            0: modes.NEAREST,
            2: modes.LINEAR,
            3: modes.CUBIC,
            4: modes.BOX,
            5: modes.HAMMING,
            1: modes.LANCZOS,
        }
        _interpolation_mode_to_cvcuda_interp.update(table)

    resolved = _interpolation_mode_to_cvcuda_interp.get(interpolation)
    if resolved is None:
        raise ValueError(f"Interpolation mode {interpolation} is not supported with CV-CUDA")

    return resolved
Loading