From 7e437a6a098119a4aff49819a6b66a2dc2826ae5 Mon Sep 17 00:00:00 2001 From: Marek Dabek Date: Wed, 25 Feb 2026 12:55:41 +0100 Subject: [PATCH 1/7] Torchvision API Initial commit including infrastructure, resize and flip operators and unit tests Signed-off-by: Marek Dabek --- dali/python/nvidia/dali/_conditionals.py | 2 +- .../dali/experimental/torchvision/__init__.py | 26 ++ .../experimental/torchvision/v2/compose.py | 358 ++++++++++++++++++ .../dali/experimental/torchvision/v2/flips.py | 82 ++++ .../torchvision/v2/functional/__init__.py | 22 ++ .../torchvision/v2/functional/flips.py | 39 ++ .../torchvision/v2/functional/resize.py | 68 ++++ .../experimental/torchvision/v2/operator.py | 300 +++++++++++++++ .../experimental/torchvision/v2/resize.py | 182 +++++++++ .../experimental/torchvision/v2/tensor.py | 32 ++ .../python/torchvision/test_tv_compose.py | 89 +++++ dali/test/python/torchvision/test_tv_flips.py | 71 ++++ .../test/python/torchvision/test_tv_resize.py | 220 +++++++++++ 13 files changed, 1490 insertions(+), 1 deletion(-) create mode 100644 dali/python/nvidia/dali/experimental/torchvision/__init__.py create mode 100644 dali/python/nvidia/dali/experimental/torchvision/v2/compose.py create mode 100644 dali/python/nvidia/dali/experimental/torchvision/v2/flips.py create mode 100644 dali/python/nvidia/dali/experimental/torchvision/v2/functional/__init__.py create mode 100644 dali/python/nvidia/dali/experimental/torchvision/v2/functional/flips.py create mode 100644 dali/python/nvidia/dali/experimental/torchvision/v2/functional/resize.py create mode 100644 dali/python/nvidia/dali/experimental/torchvision/v2/operator.py create mode 100644 dali/python/nvidia/dali/experimental/torchvision/v2/resize.py create mode 100644 dali/python/nvidia/dali/experimental/torchvision/v2/tensor.py create mode 100644 dali/test/python/torchvision/test_tv_compose.py create mode 100644 dali/test/python/torchvision/test_tv_flips.py create mode 100644 
dali/test/python/torchvision/test_tv_resize.py diff --git a/dali/python/nvidia/dali/_conditionals.py b/dali/python/nvidia/dali/_conditionals.py index d3300e5b81f..6204c9bb09d 100644 --- a/dali/python/nvidia/dali/_conditionals.py +++ b/dali/python/nvidia/dali/_conditionals.py @@ -711,6 +711,6 @@ def lazy_or(self, a_value, b): _autograph.initialize_autograph( _OVERLOADS, - convert_modules=["nvidia.dali.auto_aug"], + convert_modules=["nvidia.dali.auto_aug", "nvidia.dali.experimental.torchvision"], do_not_convert_modules=["nvidia.dali._autograph", "nvidia.dali"], ) diff --git a/dali/python/nvidia/dali/experimental/torchvision/__init__.py b/dali/python/nvidia/dali/experimental/torchvision/__init__.py new file mode 100644 index 00000000000..31d1c4729b7 --- /dev/null +++ b/dali/python/nvidia/dali/experimental/torchvision/__init__.py @@ -0,0 +1,26 @@ +# Copyright (c) 2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +from .v2.compose import Compose +from .v2.flips import RandomHorizontalFlip, RandomVerticalFlip +from .v2.resize import Resize +from .v2.tensor import ToTensor + +__all__ = [ + "Compose", + "RandomHorizontalFlip", + "RandomVerticalFlip", + "Resize", + "ToTensor", +] diff --git a/dali/python/nvidia/dali/experimental/torchvision/v2/compose.py b/dali/python/nvidia/dali/experimental/torchvision/v2/compose.py new file mode 100644 index 00000000000..9ee06378aa1 --- /dev/null +++ b/dali/python/nvidia/dali/experimental/torchvision/v2/compose.py @@ -0,0 +1,358 @@ +# Copyright (c) 2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +from typing import List, Sequence, Callable + +import nvidia.dali.fn as fn +from nvidia.dali.pipeline import pipeline_def +from nvidia.dali.data_node import DataNode as _DataNode +from nvidia.dali.backend import TensorListCPU, TensorListGPU + +from .tensor import ToTensor +from .operator import VerificationTensorOrImage + +import numpy as np +import multiprocessing +from PIL import Image +import torch + +DEFAULT_BATCH_SIZE = 16 +DEFAULT_NUM_THREADS = multiprocessing.cpu_count() // 2 + + +def _to_torch_tensor(tensor_or_tl: TensorListGPU | TensorListCPU) -> torch.Tensor: + if isinstance(tensor_or_tl, (TensorListGPU, TensorListCPU)): + dali_tensor = tensor_or_tl.as_tensor() + else: + dali_tensor = tensor_or_tl + + return torch.from_dlpack(dali_tensor) + + +def to_torch_tensor(tensor_or_tl: tuple | TensorListGPU | TensorListCPU) -> torch.Tensor: + """ + Converts a DALI tensor or tensor list to a PyTorch tensor. + + Parameters + ---------- + tensor_or_tl : tuple, TensorListGPU, TensorListCPU + DALI tensor or tensor list. + """ + if isinstance(tensor_or_tl, tuple) and len(tensor_or_tl) > 1: + tl = [] + for elem in tensor_or_tl: + tl.append(_to_torch_tensor(elem)) + return tuple(tl) + else: + if len(tensor_or_tl) == 1: + tensor_or_tl = tensor_or_tl[0] + return _to_torch_tensor(tensor_or_tl) + + +@pipeline_def(enable_conditionals=True, exec_dynamic=True, prefetch_queue_depth=1) +def _pipeline_function(op_list, layout="HWC"): + """ + Builds a DALI pipeline from a list of operators. + + Parameters + ---------- + op_list : list + List of DALI operators. + layout : str + Layout of the data. + """ + input_node = fn.external_source(name="input_data", no_copy=True, layout=layout) + for op in op_list: + if isinstance(op, ToTensor) and op != op_list[-1]: + raise NotImplementedError("ToTensor can only be the last operation in the pipeline") + input_node = op(input_node) + return input_node + + +class PipelineLayouted: + """Base class for pipeline layouts. 
+ + This class is a base class for DALI pipelines with a specific layout. It is used to handle + the layout of the data. + Single DALI Pipeline can only use one layout at a time. + + Parameters + ---------- + op_list : list + List of DALI operators. + layout : str + Layout of the data. + batch_size : int, optional, default = DEFAULT_BATCH_SIZE + Batch size. + num_threads : int, optional, default = DEFAULT_NUM_THREADS + Number of threads. + **dali_pipeline_kwargs + Additional keyword arguments for the DALI pipeline. + """ + + def __init__( + self, + op_list: List[Callable[..., Sequence[_DataNode] | _DataNode]], + layout: str, + batch_size: int = DEFAULT_BATCH_SIZE, + num_threads: int = DEFAULT_NUM_THREADS, + **dali_pipeline_kwargs, + ): + self.convert_to_tensor = True if isinstance(op_list[-1], ToTensor) else False + self.pipe = _pipeline_function( + op_list, + layout=layout, + batch_size=batch_size, + num_threads=num_threads, + *dali_pipeline_kwargs, + ) + + def run(self, data_input): + output = None + + stream = torch.cuda.Stream(0) + with torch.cuda.stream(stream): + output = self.pipe.run(stream, input_data=data_input) + + if output is None: + return output + + output = to_torch_tensor(output) + # ToTensor + if self.convert_to_tensor: + if output.shape[-4] > 1: + raise NotImplementedError("ToTensor does not currently work for batches") + + return output + + def get_layout(self) -> str: ... + + def get_channel_reverse_idx(self) -> int: ... + + def is_conversion_to_tensor(self) -> bool: + return self.convert_to_tensor + + +class PipelineHWC(PipelineLayouted): + """Handles ``PIL.Image`` in HWC format. + + This class prepares data to be passed to a DALI pipeline, runs the pipeline and converts + the output to a ``PIL.Image``. + + Parameters + ---------- + op_list : list + List of DALI operators. + batch_size : int, optional, default = DEFAULT_BATCH_SIZE + Batch size. + num_threads : int, optional, default = DEFAULT_NUM_THREADS + Number of threads. 
+ **dali_pipeline_kwargs + Additional keyword arguments for the DALI pipeline. + """ + + def __init__( + self, + op_list: List[Callable[..., Sequence[_DataNode] | _DataNode]], + batch_size: int = DEFAULT_BATCH_SIZE, + num_threads: int = DEFAULT_NUM_THREADS, + **dali_pipeline_kwargs, + ): + super().__init__( + op_list, + layout="HWC", + batch_size=batch_size, + num_threads=num_threads, + *dali_pipeline_kwargs, + ) + + def _convert_tensor_to_image(self, in_tensor: torch.Tensor): + + channels = self.get_channel_reverse_idx() + + # TODO: consider when to convert to PIL.Image - e.g. if it make sense for channels < 3 + if in_tensor.shape[channels] == 1: + mode = "L" + in_tensor = in_tensor.squeeze(-1) + elif in_tensor.shape[channels] == 3: + mode = "RGB" + elif in_tensor.shape[channels] == 4: + mode = "RGBA" + else: + raise ValueError( + f"Unsupported number of channels: {in_tensor.shape[channels]}. Should be 1 or 3." + ) + # We need to convert tensor to CPU, PIL does not support CUDA tensors + return Image.fromarray(in_tensor.cpu().numpy(), mode=mode) + + def run(self, data_input): + if isinstance(data_input, Image.Image): + _input = torch.as_tensor(np.array(data_input, copy=True)).unsqueeze(0) + if data_input.mode == "L": + _input = _input.unsqueeze(-1) + else: + raise ValueError( + "HWC layout is currently supported for PIL Images only.\ + Please check if samples have the same format." 
+ ) + + output = super().run(_input) + + if self.is_conversion_to_tensor(): + return output + + if isinstance(output, tuple): + output = self._convert_tensor_to_image(output[0]) + else: + # batches + if output.shape[0] > 1: + output_list = [] + for i in range(output.shape[0]): + output_list.append(self._convert_tensor_to_image(output[i])) + output = output_list + else: + output = self._convert_tensor_to_image(output[0]) + + return output + + def get_layout(self) -> str: + return "HWC" + + def get_channel_reverse_idx(self) -> int: + return -1 + + +class PipelineCHW(PipelineLayouted): + """Handles ``torch.Tensors`` in CHW format. + + This class prepares data to be passed to a DALI pipeline and runs the pipeline, converting + the output to a ``torch.Tensor``. + + Parameters + ---------- + op_list : list + List of DALI operators. + batch_size : int, optional, default = DEFAULT_BATCH_SIZE + Batch size. + num_threads : int, optional, default = DEFAULT_NUM_THREADS + Number of threads. + **dali_pipeline_kwargs + Additional keyword arguments for the DALI pipeline. + """ + + def __init__( + self, + op_list: List[Callable[..., Sequence[_DataNode] | _DataNode]], + batch_size: int = DEFAULT_BATCH_SIZE, + num_threads: int = DEFAULT_NUM_THREADS, + **dali_pipeline_kwargs, + ): + super().__init__( + op_list, + layout="CHW", + batch_size=batch_size, + num_threads=num_threads, + *dali_pipeline_kwargs, + ) + + def run(self, data_input): + if isinstance(data_input, torch.Tensor): + _input = data_input + if data_input.ndim == 3: + # DALI requires batch size to be present + _input = data_input.unsqueeze(0) + else: + raise ValueError( + "CHW layout is currently supported for torch.Tensor only.\ + Please check if samples have the same format." 
+ ) + output = super().run(_input) + + if data_input.ndim == 3: + # DALI requires batch size to be present + output = output.squeeze(0) + return output + + def get_layout(self) -> str: + return "CHW" + + def get_channel_reverse_idx(self) -> int: + return -3 + + +class Compose: + """ + Composes transforms together in a single pipeline + + This class chains multiple DALI operations in a sequential manner, similar to + ``torchvision.transforms.Compose``. The ``Compose`` class implements a callable which runs + the pipeline. + + Parameters + ---------- + op_list : list + List of DALI operators. + batch_size : int, optional, default = DEFAULT_BATCH_SIZE + Batch size. + num_threads : int, optional, default = DEFAULT_NUM_THREADS + Number of threads. + **dali_pipeline_kwargs + Additional keyword arguments for the DALI pipeline. + """ + + def __init__( + self, + op_list: List[Callable[..., Sequence[_DataNode] | _DataNode]], + batch_size: int = DEFAULT_BATCH_SIZE, + num_threads: int = DEFAULT_NUM_THREADS, + **dali_pipeline_kwargs, + ): + self.op_list = op_list + self.batch_size = batch_size + self.num_threads = num_threads + self.active_pipeline = None + self.dali_pipeline_kwargs = dali_pipeline_kwargs + + def _build_pipeline(self, data_input): + if isinstance(data_input, Image.Image): + self.active_pipeline = PipelineHWC( + self.op_list, self.batch_size, self.num_threads, *self.dali_pipeline_kwargs + ) + elif isinstance(data_input, torch.Tensor): + self.active_pipeline = PipelineCHW( + self.op_list, self.batch_size, self.num_threads, *self.dali_pipeline_kwargs + ) + else: + raise ValueError("Currently only PILImages and torch.Tensors are supported") + + def __call__(self, data_input): + """ + Runs the pipeline + + The ``Pipeline`` class builds a graph based on the operations list passed in + the constructor. Next, whenever the ``Compose`` object is called it starts the pipeline + and returns results. 
+ + Parameters + ---------- + data_input: Tensor or PIL Image + In case of PIL image it will be converted to tensor before sending to pipeline + """ + + VerificationTensorOrImage.verify(data_input) + + if self.active_pipeline is None: + self._build_pipeline(data_input) + + return self.active_pipeline.run(data_input=data_input) diff --git a/dali/python/nvidia/dali/experimental/torchvision/v2/flips.py b/dali/python/nvidia/dali/experimental/torchvision/v2/flips.py new file mode 100644 index 00000000000..adbe02ae81a --- /dev/null +++ b/dali/python/nvidia/dali/experimental/torchvision/v2/flips.py @@ -0,0 +1,82 @@ +# Copyright (c) 2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Literal +from .operator import Operator +import nvidia.dali.fn as fn + + +class RandomFlip(Operator): + """ + Randomly flips the given image randomly with a given probability. + + Parameters + ---------- + p : float + Probability of the image being flipped. Default value is 0.5 + horizontal : int + Flip the horizontal dimension. + device : Literal["cpu", "gpu"], optional, default = "cpu" + Device to use for the flip. Can be ``"cpu"`` or ``"gpu"``. 
+ """ + + def __init__(self, p: float = 0.5, horizontal: int = 1, device: Literal["cpu", "gpu"] = "cpu"): + super().__init__(device=device) + self.prob = p + self.device = device + self.horizontal = horizontal + + def _kernel(self, data_input): + if self.horizontal: + data_input = fn.flip( + data_input, horizontal=fn.random.coin_flip(probability=self.prob), vertical=0 + ) + else: + data_input = fn.flip( + data_input, horizontal=0, vertical=fn.random.coin_flip(probability=self.prob) + ) + + return data_input + + +class RandomHorizontalFlip(RandomFlip): + """ + Randomly horizontally flips the given image randomly with a given probability. + + Parameters + ---------- + p : float + Probability of the image being flipped. Default value is 0.5 + device : Literal["cpu", "gpu"], optional, default = "cpu" + Device to use for the flip. Can be ``"cpu"`` or ``"gpu"``. + """ + + def __init__(self, p: float = 0.5, device: Literal["cpu", "gpu"] = "cpu"): + super().__init__(p, True, device) + + +class RandomVerticalFlip(RandomFlip): + """ + Randomly vertically flips the given image randomly with a given probability. + + Parameters + ---------- + p : float + Probability of the image being flipped. Default value is 0.5 + device : Literal["cpu", "gpu"], optional, default = "cpu" + Device to use for the flip. Can be ``"cpu"`` or ``"gpu"``. + """ + + def __init__(self, p: float = 0.5, device: Literal["cpu", "gpu"] = "cpu"): + super().__init__(p, False, device) diff --git a/dali/python/nvidia/dali/experimental/torchvision/v2/functional/__init__.py b/dali/python/nvidia/dali/experimental/torchvision/v2/functional/__init__.py new file mode 100644 index 00000000000..fedd6bc05bb --- /dev/null +++ b/dali/python/nvidia/dali/experimental/torchvision/v2/functional/__init__.py @@ -0,0 +1,22 @@ +# Copyright (c) 2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
+# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from .flips import horizontal_flip, vertical_flip +from .resize import resize + +__all__ = [ + "horizontal_flip", + "resize", + "vertical_flip", +] diff --git a/dali/python/nvidia/dali/experimental/torchvision/v2/functional/flips.py b/dali/python/nvidia/dali/experimental/torchvision/v2/functional/flips.py new file mode 100644 index 00000000000..f60627f945a --- /dev/null +++ b/dali/python/nvidia/dali/experimental/torchvision/v2/functional/flips.py @@ -0,0 +1,39 @@ +# Copyright (c) 2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import nvidia.dali.experimental.dynamic as ndd + +import sys + +sys.path.append("..") +from ..operator import adjust_input # noqa: E402 + + +@adjust_input +def horizontal_flip(inpt: ndd.Tensor) -> ndd.Tensor: + """ + Horizontally flips the given tensor. + Refer to ``HorizontalFlip`` for more details. 
+ """ + return ndd.flip(inpt, horizontal=1, vertical=0) + + +@adjust_input +def vertical_flip(inpt: ndd.Tensor) -> ndd.Tensor: + """ + Vertically flips the given tensor. + Refer to ``VerticalFlip`` for more details. + """ + return ndd.flip(inpt, horizontal=0, vertical=1) diff --git a/dali/python/nvidia/dali/experimental/torchvision/v2/functional/resize.py b/dali/python/nvidia/dali/experimental/torchvision/v2/functional/resize.py new file mode 100644 index 00000000000..de689549330 --- /dev/null +++ b/dali/python/nvidia/dali/experimental/torchvision/v2/functional/resize.py @@ -0,0 +1,68 @@ +# Copyright (c) 2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Optional, List, Literal +from torch import Tensor +import nvidia.dali.experimental.dynamic as ndd +from torchvision.transforms import InterpolationMode + +import sys + +sys.path.append("..") +from ..operator import adjust_input # noqa: E402 +from ..resize import Resize # noqa: E402 + + +@adjust_input +def resize( + img: Tensor, + size: List[int], + interpolation: InterpolationMode = InterpolationMode.BILINEAR, + max_size: Optional[int] = None, + antialias: Optional[bool] = True, + device: Literal["cpu", "gpu"] = "cpu", +) -> Tensor: + """ + Please refer to the ``Resize`` operator for more details. 
+ """ + Resize.verify_args( + size=size, max_size=max_size, interpolation=interpolation, antialias=antialias + ) + + effective_size, mode = Resize.infer_effective_size(size, max_size) + interpolation = Resize.interpolation_modes[interpolation] + + target_h, target_w = Resize.calculate_target_size( + img.shape, effective_size, max_size, size is None + ) + + # Shorter edge limited by max size + if mode == "resize_shorter": + return ndd.resize( + img, + device=device, + resize_shorter=target_h, + max_size=max_size, + interp_type=interpolation, + antialias=antialias, + ) + + return ndd.resize( + img, + device=device, + size=(target_h, target_w), + mode=mode, + interp_type=interpolation, + antialias=antialias, + ) diff --git a/dali/python/nvidia/dali/experimental/torchvision/v2/operator.py b/dali/python/nvidia/dali/experimental/torchvision/v2/operator.py new file mode 100644 index 00000000000..b84e76149a9 --- /dev/null +++ b/dali/python/nvidia/dali/experimental/torchvision/v2/operator.py @@ -0,0 +1,300 @@ +# Copyright (c) 2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +from abc import ABC, abstractmethod +from typing import Sequence, Literal + +from PIL import Image +import torch +import numpy as np + +import nvidia.dali.experimental.dynamic as ndd + + +class DataVerificationRule(ABC): + """ + Abstract base class for data verification rules + + Implement ``verify`` method in a child class raising an exception in case of failed verification + """ + + @classmethod + @abstractmethod + def verify(cls, data) -> None: + pass + + +class ArgumentVerificationRule(ABC): + """ + Abstract base class for input verification rules + + Implement ``verify`` method in a child class raising an exception in case of failed verification + """ + + @classmethod + @abstractmethod + def verify(cls, **kwargs) -> None: + pass + + +class VerificationIsTensor(DataVerificationRule): + """ + Verify if the data is a ``torch.Tensor``. + + Parameters + ---------- + data : any + Data to verify. Should be a ``torch.Tensor``. + """ + + @classmethod + def verify(cls, data): + if not isinstance(data, (torch.Tensor)): + raise TypeError(f"Data should be Tensor. Got {type(data)}") + + +class VerificationTensorOrImage(DataVerificationRule): + """ + Verify if the data is a ``torch.Tensor`` or ``PIL.Image``. + + Parameters + ---------- + data : any + Data to verify. Should be a ``torch.Tensor`` or ``PIL.Image``. + """ + + @classmethod + def verify(cls, data): + if not isinstance(data, (Image.Image, torch.Tensor)): + raise TypeError(f"inpt should be Tensor or PIL Image. Got {type(data)}") + + +class VerificationChannelCount(DataVerificationRule): + """ + Verify the number of channels for the input data. + + Parameters + ---------- + data : any + Data to verify in CHW format. + """ + + CHANNELS = [1, 2, 3, 4] + + @classmethod + def verify(cls, data): + if ( + isinstance(data, torch.Tensor) + and data.shape[-3] not in VerificationChannelCount.CHANNELS + ): + raise ValueError( + f"Input should be in CHW if Tensor. 
\ + Supported channels: {VerificationChannelCount.CHANNELS} is {data.shape[-3]}" + ) + + +class VerifyIfPositive(ArgumentVerificationRule): + """ + Verify if the value is positive. + + Parameters + ---------- + values : any + Value to verify. Should be a positive number. + """ + + @classmethod + def verify(cls, *, values, name, **_) -> None: + if isinstance(values, (int, float)) and values <= 0: + raise ValueError(f"Value {name} must be positive, got {values}") + elif isinstance(values, (list, tuple)) and any(k <= 0 for k in values): + raise ValueError(f"Values {name} should be positive number, got {values}") + + +class VerifyIfOrderedPair(ArgumentVerificationRule): + """ + Verify if the value is an ordered pair. + + Parameters + ---------- + values : any + Value to verify. Should be an ordered pair. + """ + + @classmethod + def verify(cls, *, values, name, **_) -> None: + if isinstance(values, (list, tuple)) and len(values) == 2 and values[0] > values[1]: + raise ValueError(f"Values {name} should be ordered, got {values}") + + +class VerificationSize(ArgumentVerificationRule): + """ + Verify if the value is an integer or a sequence of length 1 or 2. + + Parameters + ---------- + size : any + Value to verify. Should be an integer or a sequence of length 1 or 2. + """ + + @classmethod + def verify(cls, *, size, **_) -> None: + if not isinstance(size, (int, list, tuple)): + raise TypeError(f"Size must be int or sequence, got {type(size)}") + elif isinstance(size, (list, tuple)) and len(size) > 2: + raise ValueError(f"Size sequence must have length 1 or 2, got {len(size)}") + VerifyIfPositive.verify(values=size, name="size") + + +class Operator(ABC): + """ + Abstract base class for operator specification + + Implement _kernel for algorithm specific processing + + ``arg_rules`` - a sequence of verification rules for algorithm's arguments. + ``input_rules`` - a sequence of verification rules for algorithm's input data. 
+ ``preprocess_data`` - a function to preprocess the input data. + + Parameters + ---------- + device : Literal["cpu", "gpu"], optional, default = "cpu" + Device to use for the operator. Can be ``"cpu"`` or ``"gpu"``. + **kwargs + Additional keyword arguments for the operator. + """ + + arg_rules: Sequence[ArgumentVerificationRule] = [] + input_rules: Sequence[DataVerificationRule] = [] + preprocess_data = None + + @classmethod + def verify_args(cls, **kwargs): + for rule in cls.arg_rules: + rule.verify(**kwargs) + + @classmethod + def verify_data(cls, data_input): + for rule in cls.input_rules: + rule.verify(data_input) + + def __init__(self, device: Literal["cpu", "gpu"] = "cpu", **kwargs): + self.device = device + type(self).verify_args(**kwargs) + + @abstractmethod + def _kernel(self, data_input): + """ + Algorithm's processing + """ + pass + + def __call__(self, data_input): + + if self.device == "gpu": + data_input = data_input.gpu() + + if type(self).preprocess_data: + data_input = type(self).preprocess_data(data_input) + + output = self._kernel(data_input) + + return output + + +def adjust_input(func): + """ + This decorator transforms the 1st argument of a function to internal DALI representation + according to the following rules: + - ``PIL.Image`` -> ``ndd.Tensor(layout = "HWC")`` + - ``torch.Tensor``: + - ``ndim == 3`` -> ``ndd.Tensor(layout = "CHW")`` + - ``ndim > 3`` -> ``ndd.Batch(layout = "CHW")`` + + Note: When new input types are supported this function will be extended. 
+ """ + + def transform_input(inpt) -> ndd.Tensor | ndd.Batch: + """ + Transforms supported inputs to either DALI tensor or batch + The following conversion rules apply: + - PIL Image -> ndd.Tensor(layout="HWC"), depending on the number of channels it outputs: + L, RGB, or RGBA mode + - torch.Tensor: + ndim==3 -> ndd.Tensor(layout = "CHW"), + ndim>3 -> ndd.Batch(layout="NCHW") + """ + mode = "RGB" + if isinstance(inpt, Image.Image): + _input = ndd.Tensor(np.array(inpt, copy=True), layout="HWC") + if _input.shape[-1] == 1: + mode = "L" + elif _input.shape[-1] == 4: + mode = "RGBA" + elif isinstance(inpt, torch.Tensor): + if inpt.ndim == 3: + _input = ndd.Tensor(inpt, layout="CHW") + elif inpt.ndim > 3: + # The following should work, bug: https://jirasw.nvidia.com/browse/DALI-4566 + # _input = ndd.as_batch(inpt, layout="NCHW") + # WAR: + _input = ndd.as_batch(ndd.as_tensor(inpt), layout="CHW") + else: + raise TypeError(f"Tensor has < 3 dimensions: {inpt.ndim} / {inpt.shape}") + else: + raise TypeError(f"Data type: {type(inpt)} is not supported") + + return _input, mode + + def adjust_output( + output: ndd.Tensor | ndd.Batch, inpt, mode: str = "RGB" + ) -> Image.Image | torch.Tensor: + """ + Adjusts output to match the original input type or operator's result + Depending on the inpt: + - PIL Image: output ndd.Tensor -> PIL Image with applicable mode ("L", "RGB", "RGBA") + - torch.Tensor: + ndd.Batch -> torch.Tensor with leading dimension as a number of samples in batch + ndd.Tensor -> torch.Tensor + """ + if isinstance(inpt, Image.Image): + if output.shape[-1] == 1: + output = np.asarray(output).squeeze(2) + mode = "L" + return Image.fromarray(np.asarray(output), mode=mode) + elif isinstance(inpt, torch.Tensor): + # For input being torch.Tensor only ndd.Batch or ndd.Tensor is allowed as output + if isinstance(output, ndd.Batch): + output = ndd.as_tensor(output) + elif not isinstance(output, ndd.Tensor): + raise TypeError(f"Invalid output type: {type(output)}") + + # 
This is WAR for DLPpack not supporting pinned memory + if output.device.device_type == "cpu": + output = np.asarray(output) + + return torch.as_tensor(output) + else: + return output + + def inner_function(inpt, *args, **kwargs): + + _input, mode = transform_input(inpt) + output = func(_input, *args, **kwargs) + + output = output.evaluate() + + return adjust_output(output, inpt, mode) + + return inner_function diff --git a/dali/python/nvidia/dali/experimental/torchvision/v2/resize.py b/dali/python/nvidia/dali/experimental/torchvision/v2/resize.py new file mode 100644 index 00000000000..5c1478e959e --- /dev/null +++ b/dali/python/nvidia/dali/experimental/torchvision/v2/resize.py @@ -0,0 +1,182 @@ +# Copyright (c) 2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Optional, Sequence, Literal + +from .operator import Operator, ArgumentVerificationRule + +import nvidia.dali as dali +import nvidia.dali.fn as fn +from nvidia.dali.types import DALIInterpType + +from torchvision.transforms import InterpolationMode +import numpy as np + + +class VerificationSize(ArgumentVerificationRule): + @classmethod + def verify(cls, *, size, max_size, interpolation, **_): + if size is not None and not isinstance(size, int) and not isinstance(size, (tuple, list)): + raise ValueError( + "Invalid combination: size must be int, None, or sequence of two ints. " + "max_size only applies when size is int or None." 
+ ) + if size is None and max_size is None: + raise ValueError("Must provide max_size if size is None.") + if size is not None and max_size is not None and np.min(size) > max_size: + raise ValueError("max_size should not be smaller than the actual size") + if isinstance(size, (tuple, list)) and len(size) == 2 and max_size is not None: + raise ValueError( + "max_size should only be passed if size specifies the length of the smaller \ + edge, i.e. size should be an int" + ) + if interpolation not in Resize.interpolation_modes.keys(): + raise ValueError(f"Interpolation {type(interpolation)} is not supported") + + +class Resize(Operator): + """ + Resize the input image to the given size + If the image is torch Tensor, it is expected to have […, H, W] shape, where … means a maximum + of two leading dimensions + + Parameters + ---------- + size:sequence or int + Desired output size. If size is a sequence like (h, w), output size will be matched + to this. If size is an int, smaller edge of the image will be matched to this number. + i.e, if height > width, then image will be rescaled to (size * height / width, size). + interpolation : InterpolationMode or int + ``torchvision.transforms.InterpolationMode``. Default is InterpolationMode.BILINEAR. + If input is Tensor, only ``InterpolationMode.NEAREST``, + ``InterpolationMode.NEAREST_EXACT``, ``InterpolationMode.BILINEAR`` and + ``InterpolationMode.BICUBIC`` are supported. + max_size : int, optional + The maximum allowed for the longer edge of the resized image. If the longer edge of + the image is greater than max_size after being resized according to size, size will + be overruled so that the longer edge is equal to max_size. As a result, the smaller + edge may be shorter than size. This is only supported if size is an int. + antialias : bool, optional + Whether to apply antialiasing. If ``True``, antialiasing will be applied. If ``False``, + antialiasing will not be applied. 
+ device : Literal["cpu", "gpu"], optional, default = "cpu" + Device to use for the resize. Can be ``"cpu"`` or ``"gpu"``. + """ + + # 'NEAREST', 'NEAREST_EXACT', 'BILINEAR', 'BICUBIC', 'BOX', 'HAMMING', 'LANCZOS' + interpolation_modes = { + InterpolationMode.NEAREST: DALIInterpType.INTERP_NN, + InterpolationMode.NEAREST_EXACT: DALIInterpType.INTERP_NN, # TODO + InterpolationMode.BILINEAR: DALIInterpType.INTERP_LINEAR, + InterpolationMode.BICUBIC: DALIInterpType.INTERP_CUBIC, + InterpolationMode.BOX: DALIInterpType.INTERP_LINEAR, # TODO: + InterpolationMode.HAMMING: DALIInterpType.INTERP_GAUSSIAN, # TODO: + InterpolationMode.LANCZOS: DALIInterpType.INTERP_LANCZOS3, + } + arg_rules = [VerificationSize] + + @classmethod + def infer_effective_size( + cls, + size: Optional[int | Sequence[int]], + max_size: Optional[int] = None, + ) -> int | Sequence[int]: + + mode = "default" + + if isinstance(size, (tuple, list)) and len(size) == 1: + size = size[0] + + if isinstance(size, int): + # If size is an int, smaller edge of the image will be matched to this number. + # If size is an int: if the longer edge of the image is greater than max_size + # after being resized according to size, size will be overruled so that the + # longer edge is equal to max_size. As a result, the smaller edge may be shorter + # than size. 
+ mode = "resize_shorter" + + return ((size, size), mode) + + if size is None: + mode = "not_larger" + return ((max_size, max_size), mode) + + return size, mode + + @classmethod + def calculate_target_size( + cls, orig_size: Sequence[int], effective_size: Sequence[int], max_size: int, no_size: bool + ): + orig_h = orig_size[0] + orig_w = orig_size[1] + target_h = effective_size[0] + target_w = effective_size[1] + + # If size is None, then effective_size is max_size + if no_size: + if orig_h > orig_w: + target_w = (max_size * orig_w) / orig_h + else: + target_h = (max_size * orig_h) / orig_w + + return target_h, target_w + + def __init__( + self, + size: Optional[int | Sequence[int]], + interpolation: InterpolationMode | int = InterpolationMode.BILINEAR, + max_size: Optional[int] = None, + antialias: Optional[bool] = True, + device: Literal["cpu", "gpu"] = "cpu", + ): + + super().__init__( + device=device, + size=size, + max_size=max_size, + interpolation=interpolation, + ) + + self.size = size + self.max_size = max_size + self.interpolation = Resize.interpolation_modes[interpolation] + self.effective_size, self.mode = Resize.infer_effective_size(size, max_size) + self.antialias = antialias + + def _kernel(self, data_input): + """ + Performs the resize. The method infers the requested size in compliance + with ``torchvision.transforms.Resize`` documentation and applies DALI operator on the + ``data_input``. 
+ """ + + target_h, target_w = Resize.calculate_target_size( + data_input.shape(), self.effective_size, self.max_size, self.size is None + ) + + # Shorter edge limited by max size + if self.mode == "resize_shorter": + return fn.resize( + data_input, device=self.device, resize_shorter=target_h, max_size=self.max_size + ) + + return fn.resize( + data_input, + device=self.device, + size=fn.stack( + fn.cast(target_h, dtype=dali.types.FLOAT), + fn.cast(target_w, dtype=dali.types.FLOAT), + ), + mode=self.mode, + ) diff --git a/dali/python/nvidia/dali/experimental/torchvision/v2/tensor.py b/dali/python/nvidia/dali/experimental/torchvision/v2/tensor.py new file mode 100644 index 00000000000..cc6514566dd --- /dev/null +++ b/dali/python/nvidia/dali/experimental/torchvision/v2/tensor.py @@ -0,0 +1,32 @@ +# Copyright (c) 2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ + +class ToTensor: + """ + Converts a PIL Image or numpy.ndarray (H x W x C) in the range [0, 255] to a torch.FloatTensor + of shape (C x H x W) in the range [0.0, 1.0] if the PIL Image belongs to one of the modes + (L, LA, P, I, F, RGB, YCbCr, RGBA, CMYK, 1) or if the numpy.ndarray has dtype = np.uint8 + + [DEPRECATED but used] + """ + + def __call__(self, data_input): + """ + Performs to tensor conversion it only converts to float, the remaining part is being done + in Compose.__call__ + """ + # TODO: if data_input.dtype==types.DALIDataType.UINT8: + data_input = data_input / 255.0 + return data_input diff --git a/dali/test/python/torchvision/test_tv_compose.py b/dali/test/python/torchvision/test_tv_compose.py new file mode 100644 index 00000000000..1b06434ea41 --- /dev/null +++ b/dali/test/python/torchvision/test_tv_compose.py @@ -0,0 +1,89 @@ +# Copyright (c) 2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import os + +from nvidia.dali.experimental.torchvision import Compose, RandomHorizontalFlip + +from nose_utils import assert_raises +import numpy as np +from PIL import Image +import torchvision.transforms as tv +import torch +import torchvision.transforms.v2 as transforms + + +def read_filepath(path): + return np.frombuffer(path.encode(), dtype=np.int8) + + +def make_test_tensor(shape=(5, 10, 10, 1)): + total = 1 + for s in shape: + total *= s + return torch.arange(total).reshape(shape).to(dtype=torch.uint8) + + +dali_extra = os.environ["DALI_EXTRA_PATH"] +jpeg = os.path.join(dali_extra, "db", "single", "jpeg") +jpeg_113 = os.path.join(jpeg, "113") +test_files = [ + os.path.join(jpeg_113, f) + for f in ["snail-4291306_1280.jpg", "snail-4345504_1280.jpg", "snail-4368154_1280.jpg"] +] +test_input_filenames = [read_filepath(fname) for fname in test_files] + + +def test_compose_tensor(): + test_tensor = make_test_tensor(shape=(5, 5, 5, 3)) + dali_pipeline = Compose([RandomHorizontalFlip(p=1.0)], batch_size=test_tensor.shape[0]) + dali_out = dali_pipeline(test_tensor) + tv_out = tv.RandomHorizontalFlip(p=1.0)(test_tensor) + + assert isinstance(dali_out, torch.Tensor) + assert torch.equal(dali_out, tv_out) + + +def test_compose_invalid_batch_tensor(): + test_tensor = make_test_tensor(shape=(5, 5, 5, 1)) + with assert_raises(RuntimeError): + dali_pipeline = Compose([RandomHorizontalFlip(p=1.0)], batch_size=1) + _ = dali_pipeline(test_tensor) + + +def test_compose_images(): + dali_transform = Compose([RandomHorizontalFlip(p=1.0)]) + tv_transform = tv.RandomHorizontalFlip(p=1.0) + + for fn in test_files: + img = Image.open(fn) + out_dali_img = dali_transform(img) + + assert isinstance(out_dali_img, Image.Image) + + tensor_dali_tv = transforms.functional.pil_to_tensor(out_dali_img) + tensor_tv = transforms.functional.pil_to_tensor(tv_transform(img)) + + assert tensor_dali_tv.shape == tensor_tv.shape + assert torch.equal(tensor_dali_tv, tensor_tv) + + +def 
test_compose_invalid_type_images(): + dali_transform = Compose([RandomHorizontalFlip(p=1.0)]) + + for fn in test_files: + img = Image.open(fn) + with assert_raises(TypeError): + out_dali_img = dali_transform([img, img, img]) + assert isinstance(out_dali_img, Image.Image) diff --git a/dali/test/python/torchvision/test_tv_flips.py b/dali/test/python/torchvision/test_tv_flips.py new file mode 100644 index 00000000000..39bc158b3f7 --- /dev/null +++ b/dali/test/python/torchvision/test_tv_flips.py @@ -0,0 +1,71 @@ +# Copyright (c) 2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import torch +import torchvision.transforms.v2 as tv + +from nose2.tools import params +from nvidia.dali.experimental.torchvision import Compose, RandomHorizontalFlip, RandomVerticalFlip +from nvidia.dali.experimental.torchvision.v2.functional import horizontal_flip, vertical_flip + + +def make_test_tensor(shape=(1, 10, 10, 3)): + total = 1 + for s in shape: + total *= s + return torch.arange(total).reshape(shape) + + +@params("gpu", "cpu") +def test_horizontal_random_flip_probability(device): + img = make_test_tensor() + transform = Compose([RandomHorizontalFlip(p=1.0, device=device)]) # always flip + out = transform(img).cpu() + out_tv = tv.RandomHorizontalFlip(p=1.0)(img) + out_fn = horizontal_flip(img).cpu() + assert torch.equal(out, out_tv) + assert torch.equal(out_fn, out_tv) + + transform = Compose([RandomHorizontalFlip(p=0.0, device=device)]) # never flip + out = transform(img).cpu() + assert torch.equal(out, img) + + +@params("gpu", "cpu") +def test_vertical_random_flip_probability(device): + img = make_test_tensor() + transform = Compose([RandomVerticalFlip(p=1.0, device=device)]) # always flip + out = transform(img).cpu() + out_tv = tv.RandomVerticalFlip(p=1.0)(img) + out_fn = vertical_flip(img).cpu() + assert torch.equal(out, out_tv) + assert torch.equal(out, out_fn) + + transform = Compose([RandomVerticalFlip(p=0.0, device=device)]) # never flip + out = transform(img).cpu() + assert torch.equal(out, img) + + +def test_flip_preserves_shape(): + img = make_test_tensor((1, 15, 20, 3)) + hflip_pipeline = Compose([RandomHorizontalFlip(p=1.0)]) + hflip_fn = horizontal_flip(img).cpu() + hflip = hflip_pipeline(img) + vflip_pipeline = Compose([RandomVerticalFlip(p=1.0)]) + vflip_fn = vertical_flip(img).cpu() + vflip = vflip_pipeline(img) + assert hflip.shape == img.shape + assert vflip.shape == img.shape + assert hflip_fn.shape == img.shape + assert vflip_fn.shape == img.shape diff --git a/dali/test/python/torchvision/test_tv_resize.py 
b/dali/test/python/torchvision/test_tv_resize.py new file mode 100644 index 00000000000..d1f13832578 --- /dev/null +++ b/dali/test/python/torchvision/test_tv_resize.py @@ -0,0 +1,220 @@ +# Copyright (c) 2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os +from typing import Sequence + +import numpy as np +from nose2.tools import params, cartesian_params +from nose_utils import assert_raises +from PIL import Image +import torch +import torchvision.transforms.v2 as transforms +import torchvision.transforms.v2.functional as fn_tv + +from nvidia.dali.experimental.torchvision import Resize, Compose, ToTensor +import nvidia.dali.experimental.torchvision.v2.functional as fn_dali + + +def read_file(path): + return np.fromfile(path, dtype=np.uint8) + + +def read_filepath(path): + return np.frombuffer(path.encode(), dtype=np.int8) + + +dali_extra = os.environ["DALI_EXTRA_PATH"] +jpeg = os.path.join(dali_extra, "db", "single", "jpeg") +jpeg_113 = os.path.join(jpeg, "113") +test_files = [ + os.path.join(jpeg_113, f) + for f in ["snail-4291306_1280.jpg", "snail-4345504_1280.jpg", "snail-4368154_1280.jpg"] +] +test_input_filenames = [read_filepath(fname) for fname in test_files] + + +def build_resize_transform( + resize: int | Sequence[int], + max_size: int = None, + interpolation: transforms.InterpolationMode = transforms.InterpolationMode.BILINEAR, + antialias: bool = False, +): + t = transforms.Compose( + [ + 
transforms.Resize( + size=resize, max_size=max_size, interpolation=interpolation, antialias=antialias + ), + ] + ) + td = Compose( + [ + Resize( + size=resize, max_size=max_size, interpolation=interpolation, antialias=antialias + ), + ] + ) + return t, td + + +def loop_images_test_no_build( + t: transforms.Resize, + td: Resize, + resize: int | Sequence[int], + max_size: int = None, + interpolation: transforms.InterpolationMode = transforms.InterpolationMode.BILINEAR, + antialias: bool = False, +): + for fn in test_files: + img = Image.open(fn) + out_fn = transforms.functional.pil_to_tensor( + fn_tv.resize( + img, + size=resize, + max_size=max_size, + interpolation=interpolation, + antialias=antialias, + ) + ) + out_dali_fn = transforms.functional.pil_to_tensor( + fn_dali.resize( + img, + size=resize, + max_size=max_size, + interpolation=interpolation, + antialias=antialias, + ) + ) + + out_tv = transforms.functional.pil_to_tensor(t(img)).unsqueeze(0).permute(0, 2, 3, 1) + out_dali_tv = transforms.functional.pil_to_tensor(td(img)).unsqueeze(0).permute(0, 2, 3, 1) + tv_shape_lower = torch.Size([out_tv.shape[1] - 1, out_tv.shape[2] - 1]) + tv_shape_upper = torch.Size([out_tv.shape[1] + 1, out_tv.shape[2] + 1]) + + tv_fn_shape_lower = torch.Size([out_fn.shape[1] - 1, out_fn.shape[2] - 1]) + tv_fn_shape_upper = torch.Size([out_fn.shape[1] + 1, out_fn.shape[2] + 1]) + + assert ( + tv_shape_lower[0] <= out_dali_tv.shape[1] <= tv_shape_upper[0] + ), f"Should be:{out_tv.shape} is:{out_dali_tv.shape}" + assert ( + tv_shape_lower[1] <= out_dali_tv.shape[2] <= tv_shape_upper[1] + ), f"Should be:{out_tv.shape} is:{out_dali_tv.shape}" + + assert ( + tv_fn_shape_lower[0] <= out_dali_fn.shape[1] <= tv_fn_shape_upper[0] + ), f"Should be:{out_tv.shape} is:{out_dali_fn.shape}" + assert ( + tv_fn_shape_lower[1] <= out_dali_fn.shape[2] <= tv_fn_shape_upper[1] + ), f"Should be:{out_tv.shape} is:{out_dali_fn.shape}" + # assert torch.equal(out_tv, out_dali_tv) + + +def loop_images_test( + 
resize: int | Sequence[int], + max_size: int = None, + interpolation: transforms.InterpolationMode = transforms.InterpolationMode.BILINEAR, + antialias: bool = False, +): + t, td = build_resize_transform(resize, max_size, interpolation, antialias) + loop_images_test_no_build(t, td, resize, max_size, interpolation, antialias) + + +@cartesian_params( + (512, 2048, ([512, 512]), ([2048, 2048])), + ("cpu", "gpu"), +) +def test_resize_and_tensor(resize, device): + # Resize with single int (preserve aspect ratio) + td = Compose( + [ + Resize(size=resize, device=device), + ToTensor(), + ] + ) + + img = Image.open(test_files[0]) + out = td(img) + + assert isinstance(out, torch.Tensor), f"Should be torch.Tensor type is {type(out)}" + assert torch.all(out <= 1).item(), "Tensor elements should be <0;1>" + + +@params(512, 2048, ([512, 512]), ([2048, 2048])) +def test_resize_sizes(resize): + # Resize with single int (preserve aspect ratio) + loop_images_test(resize=resize) + + +@params((480, 512), (100, 124), (None, 512), (1024, 512), ([256, 256], 512), (None, None)) +def test_resize_max_sizes(resize, max_size): + # Resize with single int (preserve aspect ratio) + if resize is not None and max_size is not None and np.min(np.array(resize, int)) > max_size: + + """ + with assert_raises(ValueError): + _ = transforms.Resize(resize, max_size) + This exception is called later - when executing the operation + """ + + with assert_raises(ValueError): + _ = Compose( + [ + Resize(resize, max_size=max_size), + ] + ) + return + if resize is None and max_size is None: + with assert_raises(ValueError): + _ = transforms.Resize(resize, max_size) + + with assert_raises(ValueError): + _ = Compose( + [ + Resize(resize, max_size=max_size), + ] + ) + return + + if isinstance(resize, Sequence) and len(resize) != 1 and max_size is not None: + """ + with assert_raises(ValueError): + _ = transforms.Resize(resize, max_size) + This exception is called later - when executing the operation + """ + + with 
assert_raises(ValueError): + _ = Compose( + [ + Resize(resize, max_size=max_size), + ] + ) + return + + loop_images_test(resize=resize, max_size=max_size) + + +@params( + ([512, 512], transforms.InterpolationMode.NEAREST), + (1024, transforms.InterpolationMode.NEAREST_EXACT), + ([256, 256], transforms.InterpolationMode.BILINEAR), + (640, transforms.InterpolationMode.BICUBIC), +) +def test_resize_interploation(resize, interpolation): + loop_images_test(resize=resize, interpolation=interpolation) + + +@params((512, True), (2048, True), ([512, 512], True), ([2048, 2048], True)) +def test_resize_antialiasing(resize, antialiasing): + loop_images_test(resize=resize, antialias=antialiasing) From e4ade7a24563c268a34dc586f014e65c453e54ee Mon Sep 17 00:00:00 2001 From: Marek Dabek Date: Wed, 4 Mar 2026 14:14:42 +0100 Subject: [PATCH 2/7] Review fixes Signed-off-by: Marek Dabek --- .../dali/experimental/torchvision/__init__.py | 2 - .../experimental/torchvision/v2/compose.py | 64 +++++++++++-------- .../dali/experimental/torchvision/v2/flips.py | 2 +- .../torchvision/v2/functional/flips.py | 7 +- .../torchvision/v2/functional/resize.py | 3 - .../experimental/torchvision/v2/operator.py | 33 ++++++---- .../experimental/torchvision/v2/tensor.py | 32 ---------- .../python/torchvision/test_tv_compose.py | 42 +++++++++++- .../test/python/torchvision/test_tv_resize.py | 43 ++----------- 9 files changed, 106 insertions(+), 122 deletions(-) delete mode 100644 dali/python/nvidia/dali/experimental/torchvision/v2/tensor.py diff --git a/dali/python/nvidia/dali/experimental/torchvision/__init__.py b/dali/python/nvidia/dali/experimental/torchvision/__init__.py index 31d1c4729b7..a1b0a869eaa 100644 --- a/dali/python/nvidia/dali/experimental/torchvision/__init__.py +++ b/dali/python/nvidia/dali/experimental/torchvision/__init__.py @@ -15,12 +15,10 @@ from .v2.compose import Compose from .v2.flips import RandomHorizontalFlip, RandomVerticalFlip from .v2.resize import Resize -from .v2.tensor 
import ToTensor __all__ = [ "Compose", "RandomHorizontalFlip", "RandomVerticalFlip", "Resize", - "ToTensor", ] diff --git a/dali/python/nvidia/dali/experimental/torchvision/v2/compose.py b/dali/python/nvidia/dali/experimental/torchvision/v2/compose.py index 9ee06378aa1..ebdc457f941 100644 --- a/dali/python/nvidia/dali/experimental/torchvision/v2/compose.py +++ b/dali/python/nvidia/dali/experimental/torchvision/v2/compose.py @@ -12,14 +12,13 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import List, Sequence, Callable +from typing import List, Sequence, Callable, Union import nvidia.dali.fn as fn from nvidia.dali.pipeline import pipeline_def from nvidia.dali.data_node import DataNode as _DataNode from nvidia.dali.backend import TensorListCPU, TensorListGPU -from .tensor import ToTensor from .operator import VerificationTensorOrImage import numpy as np @@ -28,7 +27,7 @@ import torch DEFAULT_BATCH_SIZE = 16 -DEFAULT_NUM_THREADS = multiprocessing.cpu_count() // 2 +DEFAULT_NUM_THREADS = 1 if multiprocessing.cpu_count() == 1 else multiprocessing.cpu_count() // 2 def _to_torch_tensor(tensor_or_tl: TensorListGPU | TensorListCPU) -> torch.Tensor: @@ -40,7 +39,9 @@ def _to_torch_tensor(tensor_or_tl: TensorListGPU | TensorListCPU) -> torch.Tenso return torch.from_dlpack(dali_tensor) -def to_torch_tensor(tensor_or_tl: tuple | TensorListGPU | TensorListCPU) -> torch.Tensor: +def to_torch_tensor( + x: Union[tuple, "TensorListGPU", "TensorListCPU"], +) -> Union[torch.Tensor, tuple]: """ Converts a DALI tensor or tensor list to a PyTorch tensor. @@ -49,15 +50,14 @@ def to_torch_tensor(tensor_or_tl: tuple | TensorListGPU | TensorListCPU) -> torc tensor_or_tl : tuple, TensorListGPU, TensorListCPU DALI tensor or tensor list. 
""" - if isinstance(tensor_or_tl, tuple) and len(tensor_or_tl) > 1: - tl = [] - for elem in tensor_or_tl: - tl.append(_to_torch_tensor(elem)) - return tuple(tl) + if isinstance(x, (TensorListGPU, TensorListCPU)): + return to_torch_tensor(x.as_tensor()) + elif isinstance(x, tuple): + if len(x) == 1: + return _to_torch_tensor(x[0]) + return tuple(to_torch_tensor(elem) for elem in x) else: - if len(tensor_or_tl) == 1: - tensor_or_tl = tensor_or_tl[0] - return _to_torch_tensor(tensor_or_tl) + return torch.from_dlpack(x) @pipeline_def(enable_conditionals=True, exec_dynamic=True, prefetch_queue_depth=1) @@ -74,13 +74,11 @@ def _pipeline_function(op_list, layout="HWC"): """ input_node = fn.external_source(name="input_data", no_copy=True, layout=layout) for op in op_list: - if isinstance(op, ToTensor) and op != op_list[-1]: - raise NotImplementedError("ToTensor can only be the last operation in the pipeline") input_node = op(input_node) return input_node -class PipelineLayouted: +class PipelineWithLayout: """Base class for pipeline layouts. This class is a base class for DALI pipelines with a specific layout. 
It is used to handle @@ -109,21 +107,33 @@ def __init__( num_threads: int = DEFAULT_NUM_THREADS, **dali_pipeline_kwargs, ): - self.convert_to_tensor = True if isinstance(op_list[-1], ToTensor) else False + # TODO: + # convert_to_tensor is currently not supported and requires an user's effort + # to convert to tensor + # ToTensor is deprecated and according to: + # https://docs.pytorch.org/vision/stable/_modules/torchvision/transforms/v2/_deprecated.html#ToTensor + # should be replaced with: + # v2.Compose([v2.ToImage(), v2.ToDtype(torch.float32, scale=True)]) + # + # self.convert_to_tensor = True if isinstance(op_list[-1], ToTensor) else False + self.convert_to_tensor = False self.pipe = _pipeline_function( op_list, layout=layout, batch_size=batch_size, num_threads=num_threads, - *dali_pipeline_kwargs, + **dali_pipeline_kwargs, ) def run(self, data_input): output = None - stream = torch.cuda.Stream(0) - with torch.cuda.stream(stream): - output = self.pipe.run(stream, input_data=data_input) + if torch.cuda.is_available(): + stream = torch.cuda.Stream(0) + with torch.cuda.stream(stream): + output = self.pipe.run(stream, input_data=data_input) + else: + output = self.pipe.run(input_data=data_input) if output is None: return output @@ -144,7 +154,7 @@ def is_conversion_to_tensor(self) -> bool: return self.convert_to_tensor -class PipelineHWC(PipelineLayouted): +class PipelineHWC(PipelineWithLayout): """Handles ``PIL.Image`` in HWC format. This class prepares data to be passed to a DALI pipeline, runs the pipeline and converts @@ -174,7 +184,7 @@ def __init__( layout="HWC", batch_size=batch_size, num_threads=num_threads, - *dali_pipeline_kwargs, + **dali_pipeline_kwargs, ) def _convert_tensor_to_image(self, in_tensor: torch.Tensor): @@ -191,7 +201,7 @@ def _convert_tensor_to_image(self, in_tensor: torch.Tensor): mode = "RGBA" else: raise ValueError( - f"Unsupported number of channels: {in_tensor.shape[channels]}. Should be 1 or 3." 
+ f"Unsupported number of channels: {in_tensor.shape[channels]}. Should be 1, 3 or 4." ) # We need to convert tensor to CPU, PIL does not support CUDA tensors return Image.fromarray(in_tensor.cpu().numpy(), mode=mode) @@ -233,7 +243,7 @@ def get_channel_reverse_idx(self) -> int: return -1 -class PipelineCHW(PipelineLayouted): +class PipelineCHW(PipelineWithLayout): """Handles ``torch.Tensors`` in CHW format. This class prepares data to be passed to a DALI pipeline and runs the pipeline, converting @@ -263,7 +273,7 @@ def __init__( layout="CHW", batch_size=batch_size, num_threads=num_threads, - *dali_pipeline_kwargs, + **dali_pipeline_kwargs, ) def run(self, data_input): @@ -327,11 +337,11 @@ def __init__( def _build_pipeline(self, data_input): if isinstance(data_input, Image.Image): self.active_pipeline = PipelineHWC( - self.op_list, self.batch_size, self.num_threads, *self.dali_pipeline_kwargs + self.op_list, self.batch_size, self.num_threads, **self.dali_pipeline_kwargs ) elif isinstance(data_input, torch.Tensor): self.active_pipeline = PipelineCHW( - self.op_list, self.batch_size, self.num_threads, *self.dali_pipeline_kwargs + self.op_list, self.batch_size, self.num_threads, **self.dali_pipeline_kwargs ) else: raise ValueError("Currently only PILImages and torch.Tensors are supported") diff --git a/dali/python/nvidia/dali/experimental/torchvision/v2/flips.py b/dali/python/nvidia/dali/experimental/torchvision/v2/flips.py index adbe02ae81a..a821a4aa415 100644 --- a/dali/python/nvidia/dali/experimental/torchvision/v2/flips.py +++ b/dali/python/nvidia/dali/experimental/torchvision/v2/flips.py @@ -26,7 +26,7 @@ class RandomFlip(Operator): p : float Probability of the image being flipped. Default value is 0.5 horizontal : int - Flip the horizontal dimension. + Flip the horizontal dimension if 1, vertical if 0 device : Literal["cpu", "gpu"], optional, default = "cpu" Device to use for the flip. Can be ``"cpu"`` or ``"gpu"``. 
""" diff --git a/dali/python/nvidia/dali/experimental/torchvision/v2/functional/flips.py b/dali/python/nvidia/dali/experimental/torchvision/v2/functional/flips.py index f60627f945a..ea1b15440f2 100644 --- a/dali/python/nvidia/dali/experimental/torchvision/v2/functional/flips.py +++ b/dali/python/nvidia/dali/experimental/torchvision/v2/functional/flips.py @@ -15,9 +15,6 @@ import nvidia.dali.experimental.dynamic as ndd -import sys - -sys.path.append("..") from ..operator import adjust_input # noqa: E402 @@ -25,7 +22,7 @@ def horizontal_flip(inpt: ndd.Tensor) -> ndd.Tensor: """ Horizontally flips the given tensor. - Refer to ``HorizontalFlip`` for more details. + Refer to `HorizontalFlip` for more details. """ return ndd.flip(inpt, horizontal=1, vertical=0) @@ -34,6 +31,6 @@ def horizontal_flip(inpt: ndd.Tensor) -> ndd.Tensor: def vertical_flip(inpt: ndd.Tensor) -> ndd.Tensor: """ Vertically flips the given tensor. - Refer to ``VerticalFlip`` for more details. + Refer to `VerticalFlip` for more details. 
""" return ndd.flip(inpt, horizontal=0, vertical=1) diff --git a/dali/python/nvidia/dali/experimental/torchvision/v2/functional/resize.py b/dali/python/nvidia/dali/experimental/torchvision/v2/functional/resize.py index de689549330..38f30f95bc3 100644 --- a/dali/python/nvidia/dali/experimental/torchvision/v2/functional/resize.py +++ b/dali/python/nvidia/dali/experimental/torchvision/v2/functional/resize.py @@ -17,9 +17,6 @@ import nvidia.dali.experimental.dynamic as ndd from torchvision.transforms import InterpolationMode -import sys - -sys.path.append("..") from ..operator import adjust_input # noqa: E402 from ..resize import Resize # noqa: E402 diff --git a/dali/python/nvidia/dali/experimental/torchvision/v2/operator.py b/dali/python/nvidia/dali/experimental/torchvision/v2/operator.py index b84e76149a9..7b6d83d79f5 100644 --- a/dali/python/nvidia/dali/experimental/torchvision/v2/operator.py +++ b/dali/python/nvidia/dali/experimental/torchvision/v2/operator.py @@ -82,7 +82,7 @@ def verify(cls, data): class VerificationChannelCount(DataVerificationRule): """ - Verify the number of channels for the input data. + Verify if input data has <= 4 channels. More channels are not supported in Torchvision Parameters ---------- @@ -100,7 +100,8 @@ def verify(cls, data): ): raise ValueError( f"Input should be in CHW if Tensor. \ - Supported channels: {VerificationChannelCount.CHANNELS} is {data.shape[-3]}" + Supports up to {VerificationChannelCount.CHANNELS[-1]} channels, \ + got: {data.shape[-3]} channels" ) @@ -111,7 +112,7 @@ class VerifyIfPositive(ArgumentVerificationRule): Parameters ---------- values : any - Value to verify. Should be a positive number. + Value to verify. Should be a positive numbers. 
""" @classmethod @@ -119,28 +120,31 @@ def verify(cls, *, values, name, **_) -> None: if isinstance(values, (int, float)) and values <= 0: raise ValueError(f"Value {name} must be positive, got {values}") elif isinstance(values, (list, tuple)) and any(k <= 0 for k in values): - raise ValueError(f"Values {name} should be positive number, got {values}") + raise ValueError(f"Values {name} should be positive numbers, got {values}") -class VerifyIfOrderedPair(ArgumentVerificationRule): +class VerifyIfRange(ArgumentVerificationRule): """ - Verify if the value is an ordered pair. + Verify if the value is a correct range: (min, max) Parameters ---------- values : any - Value to verify. Should be an ordered pair. + Value to verify. Should be a range: (min, max) """ @classmethod def verify(cls, *, values, name, **_) -> None: if isinstance(values, (list, tuple)) and len(values) == 2 and values[0] > values[1]: - raise ValueError(f"Values {name} should be ordered, got {values}") + raise ValueError(f"Values {name} should be (min, max), got {values}") -class VerificationSize(ArgumentVerificationRule): +class VerifSizeDescriptor(ArgumentVerificationRule): """ - Verify if the value is an integer or a sequence of length 1 or 2. 
+ Verify if the value can describe a size argument, which is: + - an integer + - or a sequence of length of 1, + - or a sequence of length of 2 Parameters ---------- @@ -151,7 +155,7 @@ class VerificationSize(ArgumentVerificationRule): @classmethod def verify(cls, *, size, **_) -> None: if not isinstance(size, (int, list, tuple)): - raise TypeError(f"Size must be int or sequence, got {type(size)}") + raise TypeError(f"Size must be int, list, or tuple, got {type(size)}") elif isinstance(size, (list, tuple)) and len(size) > 2: raise ValueError(f"Size sequence must have length 1 or 2, got {len(size)}") VerifyIfPositive.verify(values=size, name="size") @@ -202,6 +206,8 @@ def _kernel(self, data_input): def __call__(self, data_input): + Operator.verify_data(data_input) + if self.device == "gpu": data_input = data_input.gpu() @@ -251,7 +257,7 @@ def transform_input(inpt) -> ndd.Tensor | ndd.Batch: # WAR: _input = ndd.as_batch(ndd.as_tensor(inpt), layout="CHW") else: - raise TypeError(f"Tensor has < 3 dimensions: {inpt.ndim} / {inpt.shape}") + raise TypeError(f"Tensor has < 3 dimensions: {inpt.ndim}, shape: {inpt.shape}") else: raise TypeError(f"Data type: {type(inpt)} is not supported") @@ -280,7 +286,8 @@ def adjust_output( elif not isinstance(output, ndd.Tensor): raise TypeError(f"Invalid output type: {type(output)}") - # This is WAR for DLPpack not supporting pinned memory + # This is WAR for DLPack not supporting pinned memory, see: + # https://github.com/pytorch/pytorch/issues/136250 if output.device.device_type == "cpu": output = np.asarray(output) diff --git a/dali/python/nvidia/dali/experimental/torchvision/v2/tensor.py b/dali/python/nvidia/dali/experimental/torchvision/v2/tensor.py deleted file mode 100644 index cc6514566dd..00000000000 --- a/dali/python/nvidia/dali/experimental/torchvision/v2/tensor.py +++ /dev/null @@ -1,32 +0,0 @@ -# Copyright (c) 2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
-# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - - -class ToTensor: - """ - Converts a PIL Image or numpy.ndarray (H x W x C) in the range [0, 255] to a torch.FloatTensor - of shape (C x H x W) in the range [0.0, 1.0] if the PIL Image belongs to one of the modes - (L, LA, P, I, F, RGB, YCbCr, RGBA, CMYK, 1) or if the numpy.ndarray has dtype = np.uint8 - - [DEPRECATED but used] - """ - - def __call__(self, data_input): - """ - Performs to tensor conversion it only converts to float, the remaining part is being done - in Compose.__call__ - """ - # TODO: if data_input.dtype==types.DALIDataType.UINT8: - data_input = data_input / 255.0 - return data_input diff --git a/dali/test/python/torchvision/test_tv_compose.py b/dali/test/python/torchvision/test_tv_compose.py index 1b06434ea41..9c57684975c 100644 --- a/dali/test/python/torchvision/test_tv_compose.py +++ b/dali/test/python/torchvision/test_tv_compose.py @@ -14,7 +14,12 @@ import os -from nvidia.dali.experimental.torchvision import Compose, RandomHorizontalFlip +from nvidia.dali.experimental.torchvision import ( + Compose, + RandomHorizontalFlip, + RandomVerticalFlip, + Resize, +) from nose_utils import assert_raises import numpy as np @@ -55,6 +60,22 @@ def test_compose_tensor(): assert torch.equal(dali_out, tv_out) +def test_compose_multi_tensor(): + test_tensor = make_test_tensor(shape=(5, 5, 5, 3)) + dali_pipeline = Compose( + [Resize(size=(15, 15)), RandomHorizontalFlip(p=1.0), RandomVerticalFlip(p=1.0)], + 
batch_size=test_tensor.shape[0], + ) + dali_out = dali_pipeline(test_tensor) + tv_pipeline = tv.Compose( + [tv.Resize(size=(15, 15)), tv.RandomHorizontalFlip(p=1.0), tv.RandomVerticalFlip(p=1.0)] + ) + tv_out = tv_pipeline(test_tensor) + + assert isinstance(dali_out, torch.Tensor) + assert torch.equal(dali_out, tv_out) + + def test_compose_invalid_batch_tensor(): test_tensor = make_test_tensor(shape=(5, 5, 5, 1)) with assert_raises(RuntimeError): @@ -64,7 +85,24 @@ def test_compose_invalid_batch_tensor(): def test_compose_images(): dali_transform = Compose([RandomHorizontalFlip(p=1.0)]) - tv_transform = tv.RandomHorizontalFlip(p=1.0) + tv_transform = tv.Compose([tv.RandomHorizontalFlip(p=1.0)]) + + for fn in test_files: + img = Image.open(fn) + out_dali_img = dali_transform(img) + + assert isinstance(out_dali_img, Image.Image) + + tensor_dali_tv = transforms.functional.pil_to_tensor(out_dali_img) + tensor_tv = transforms.functional.pil_to_tensor(tv_transform(img)) + + assert tensor_dali_tv.shape == tensor_tv.shape + assert torch.equal(tensor_dali_tv, tensor_tv) + + +def test_compose_images_multi(): + dali_transform = Compose([RandomVerticalFlip(p=1.0), RandomHorizontalFlip(p=1.0)]) + tv_transform = tv.Compose([tv.RandomVerticalFlip(p=1.0), tv.RandomHorizontalFlip(p=1.0)]) for fn in test_files: img = Image.open(fn) diff --git a/dali/test/python/torchvision/test_tv_resize.py b/dali/test/python/torchvision/test_tv_resize.py index d1f13832578..f323dee4f78 100644 --- a/dali/test/python/torchvision/test_tv_resize.py +++ b/dali/test/python/torchvision/test_tv_resize.py @@ -23,7 +23,7 @@ import torchvision.transforms.v2 as transforms import torchvision.transforms.v2.functional as fn_tv -from nvidia.dali.experimental.torchvision import Resize, Compose, ToTensor +from nvidia.dali.experimental.torchvision import Resize, Compose import nvidia.dali.experimental.torchvision.v2.functional as fn_dali @@ -99,25 +99,14 @@ def loop_images_test_no_build( out_tv = 
transforms.functional.pil_to_tensor(t(img)).unsqueeze(0).permute(0, 2, 3, 1) out_dali_tv = transforms.functional.pil_to_tensor(td(img)).unsqueeze(0).permute(0, 2, 3, 1) - tv_shape_lower = torch.Size([out_tv.shape[1] - 1, out_tv.shape[2] - 1]) - tv_shape_upper = torch.Size([out_tv.shape[1] + 1, out_tv.shape[2] + 1]) - tv_fn_shape_lower = torch.Size([out_fn.shape[1] - 1, out_fn.shape[2] - 1]) - tv_fn_shape_upper = torch.Size([out_fn.shape[1] + 1, out_fn.shape[2] + 1]) - - assert ( - tv_shape_lower[0] <= out_dali_tv.shape[1] <= tv_shape_upper[0] - ), f"Should be:{out_tv.shape} is:{out_dali_tv.shape}" - assert ( - tv_shape_lower[1] <= out_dali_tv.shape[2] <= tv_shape_upper[1] + assert torch.allclose( + torch.tensor(out_tv.shape[1:3]), torch.tensor(out_dali_tv.shape[1:3]), rtol=0, atol=1 ), f"Should be:{out_tv.shape} is:{out_dali_tv.shape}" + assert torch.allclose( + torch.tensor(out_fn.shape[1:3]), torch.tensor(out_dali_fn.shape[1:3]), rtol=0, atol=1 + ), f"Should be:{out_fn.shape} is:{out_dali_fn.shape}" - assert ( - tv_fn_shape_lower[0] <= out_dali_fn.shape[1] <= tv_fn_shape_upper[0] - ), f"Should be:{out_tv.shape} is:{out_dali_fn.shape}" - assert ( - tv_fn_shape_lower[1] <= out_dali_fn.shape[2] <= tv_fn_shape_upper[1] - ), f"Should be:{out_tv.shape} is:{out_dali_fn.shape}" # assert torch.equal(out_tv, out_dali_tv) @@ -131,26 +120,6 @@ def loop_images_test( loop_images_test_no_build(t, td, resize, max_size, interpolation, antialias) -@cartesian_params( - (512, 2048, ([512, 512]), ([2048, 2048])), - ("cpu", "gpu"), -) -def test_resize_and_tensor(resize, device): - # Resize with single int (preserve aspect ratio) - td = Compose( - [ - Resize(size=resize, device=device), - ToTensor(), - ] - ) - - img = Image.open(test_files[0]) - out = td(img) - - assert isinstance(out, torch.Tensor), f"Should be torch.Tensor type is {type(out)}" - assert torch.all(out <= 1).item(), "Tensor elements should be <0;1>" - - @params(512, 2048, ([512, 512]), ([2048, 2048])) def 
test_resize_sizes(resize): # Resize with single int (preserve aspect ratio) From bfed7216d4cf2051e1b2f450ffc341832d227f8d Mon Sep 17 00:00:00 2001 From: Marek Dabek Date: Thu, 5 Mar 2026 15:30:02 +0100 Subject: [PATCH 3/7] Review fixes - 2 Modified resize operator size calculation Signed-off-by: Marek Dabek --- .../torchvision/v2/functional/resize.py | 16 ++- .../experimental/torchvision/v2/operator.py | 2 +- .../experimental/torchvision/v2/resize.py | 48 ++++++- .../test/python/torchvision/test_tv_resize.py | 129 +++++++++++++----- 4 files changed, 157 insertions(+), 38 deletions(-) diff --git a/dali/python/nvidia/dali/experimental/torchvision/v2/functional/resize.py b/dali/python/nvidia/dali/experimental/torchvision/v2/functional/resize.py index 38f30f95bc3..b76a86e7454 100644 --- a/dali/python/nvidia/dali/experimental/torchvision/v2/functional/resize.py +++ b/dali/python/nvidia/dali/experimental/torchvision/v2/functional/resize.py @@ -40,8 +40,22 @@ def resize( effective_size, mode = Resize.infer_effective_size(size, max_size) interpolation = Resize.interpolation_modes[interpolation] + if isinstance(img, ndd.Tensor): + img_shape = img.shape + elif isinstance(img, ndd.Batch): + img_shape = img.shape[0] # Batches have uniform layout + else: + raise TypeError(f"Input must be ndd.Tensor or ndd.Batch got {type(img)}") + + if img.layout in ["HWC", "NHWC"]: + original_h = img_shape[-3] + original_w = img_shape[-2] + elif img.layout in ["CHW", "NCHW"]: + original_h = img_shape[-2] + original_w = img_shape[-1] + target_h, target_w = Resize.calculate_target_size( - img.shape, effective_size, max_size, size is None + (original_h, original_w), effective_size, max_size, size is None ) # Shorter edge limited by max size diff --git a/dali/python/nvidia/dali/experimental/torchvision/v2/operator.py b/dali/python/nvidia/dali/experimental/torchvision/v2/operator.py index 7b6d83d79f5..c202dcc8dfb 100644 --- a/dali/python/nvidia/dali/experimental/torchvision/v2/operator.py +++ 
b/dali/python/nvidia/dali/experimental/torchvision/v2/operator.py @@ -206,7 +206,7 @@ def _kernel(self, data_input): def __call__(self, data_input): - Operator.verify_data(data_input) + type(self).verify_data(data_input) if self.device == "gpu": data_input = data_input.gpu() diff --git a/dali/python/nvidia/dali/experimental/torchvision/v2/resize.py b/dali/python/nvidia/dali/experimental/torchvision/v2/resize.py index 5c1478e959e..3d18a75ebe8 100644 --- a/dali/python/nvidia/dali/experimental/torchvision/v2/resize.py +++ b/dali/python/nvidia/dali/experimental/torchvision/v2/resize.py @@ -24,6 +24,36 @@ import numpy as np +def get_inputHW(data_input): + """ + Gets the height and width of the input data. + + Parameters + ---------- + data_input : Tensor + Input data to get the height and width of. + + Returns + ------- + input_height : int + Height of the input data. + input_width : int + Width of the input data. + """ + layout = data_input.property("layout")[0] + + # CWH + if layout == np.frombuffer(bytes("C", "utf-8"), dtype=np.uint8)[0]: + input_height = data_input.shape()[-1] + input_width = data_input.shape()[-2] + # HWC + else: + input_height = data_input.shape()[-3] + input_width = data_input.shape()[-2] + + return input_height, input_width, data_input + + class VerificationSize(ArgumentVerificationRule): @classmethod def verify(cls, *, size, max_size, interpolation, **_): @@ -84,7 +114,9 @@ class Resize(Operator): InterpolationMode.HAMMING: DALIInterpType.INTERP_GAUSSIAN, # TODO: InterpolationMode.LANCZOS: DALIInterpType.INTERP_LANCZOS3, } + arg_rules = [VerificationSize] + preprocess_data = get_inputHW @classmethod def infer_effective_size( @@ -120,6 +152,7 @@ def calculate_target_size( ): orig_h = orig_size[0] orig_w = orig_size[1] + target_h = effective_size[0] target_w = effective_size[1] @@ -160,15 +193,24 @@ def _kernel(self, data_input): with ``torchvision.transforms.Resize`` documentation and applies DALI operator on the ``data_input``. 
""" + input_height, input_width, data_input = data_input target_h, target_w = Resize.calculate_target_size( - data_input.shape(), self.effective_size, self.max_size, self.size is None + orig_size=(input_height, input_width), + effective_size=self.effective_size, + max_size=self.max_size, + no_size=self.size is None, ) # Shorter edge limited by max size if self.mode == "resize_shorter": return fn.resize( - data_input, device=self.device, resize_shorter=target_h, max_size=self.max_size + data_input, + device=self.device, + resize_shorter=target_h, + max_size=self.max_size, + antialias=self.antialias, + interp_type=self.interpolation, ) return fn.resize( @@ -179,4 +221,6 @@ def _kernel(self, data_input): fn.cast(target_w, dtype=dali.types.FLOAT), ), mode=self.mode, + antialias=self.antialias, + interp_type=self.interpolation, ) diff --git a/dali/test/python/torchvision/test_tv_resize.py b/dali/test/python/torchvision/test_tv_resize.py index f323dee4f78..d2ae2ae23da 100644 --- a/dali/test/python/torchvision/test_tv_resize.py +++ b/dali/test/python/torchvision/test_tv_resize.py @@ -13,7 +13,7 @@ # limitations under the License. 
import os -from typing import Sequence +from typing import Sequence, Literal, Union import numpy as np from nose2.tools import params, cartesian_params @@ -50,6 +50,7 @@ def build_resize_transform( max_size: int = None, interpolation: transforms.InterpolationMode = transforms.InterpolationMode.BILINEAR, antialias: bool = False, + device: Literal["cpu", "gpu"] = "cpu", ): t = transforms.Compose( [ @@ -61,13 +62,59 @@ def build_resize_transform( td = Compose( [ Resize( - size=resize, max_size=max_size, interpolation=interpolation, antialias=antialias + size=resize, + max_size=max_size, + interpolation=interpolation, + antialias=antialias, + device=device, ), ] ) return t, td +def _internal_loop( + input_data: Union[Image.Image, torch.Tensor], + t: transforms.Resize, + td: Resize, + resize: int | Sequence[int], + max_size: int = None, + interpolation: transforms.InterpolationMode = transforms.InterpolationMode.BILINEAR, + antialias: bool = False, +): + out_fn = fn_tv.resize( + input_data, + size=resize, + max_size=max_size, + interpolation=interpolation, + antialias=antialias, + ) + out_dali_fn = fn_dali.resize( + input_data, + size=resize, + max_size=max_size, + interpolation=interpolation, + antialias=antialias, + ) + out_tv = t(input_data) + out_dali_tv = td(input_data) + + if isinstance(input_data, Image.Image): + out_tv = transforms.functional.pil_to_tensor(out_tv).unsqueeze(0).permute(0, 2, 3, 1) + out_dali_tv = ( + transforms.functional.pil_to_tensor(out_dali_tv).unsqueeze(0).permute(0, 2, 3, 1) + ) + out_fn = transforms.functional.pil_to_tensor(out_fn) + out_dali_fn = transforms.functional.pil_to_tensor(out_dali_fn) + + assert torch.allclose( + torch.tensor(out_tv.shape[1:3]), torch.tensor(out_dali_tv.shape[1:3]), rtol=0, atol=1 + ), f"Should be:{out_tv.shape} is:{out_dali_tv.shape}" + assert torch.allclose( + torch.tensor(out_fn.shape[1:3]), torch.tensor(out_dali_fn.shape[1:3]), rtol=0, atol=1 + ), f"Should be:{out_fn.shape} is:{out_dali_fn.shape}" + + def 
loop_images_test_no_build( t: transforms.Resize, td: Resize, @@ -78,36 +125,43 @@ def loop_images_test_no_build( ): for fn in test_files: img = Image.open(fn) - out_fn = transforms.functional.pil_to_tensor( - fn_tv.resize( - img, - size=resize, - max_size=max_size, - interpolation=interpolation, - antialias=antialias, - ) - ) - out_dali_fn = transforms.functional.pil_to_tensor( - fn_dali.resize( - img, - size=resize, - max_size=max_size, - interpolation=interpolation, - antialias=antialias, - ) - ) + _internal_loop(img, t, td, resize, max_size, interpolation, antialias) + # assert torch.equal(out_tv, out_dali_tv) - out_tv = transforms.functional.pil_to_tensor(t(img)).unsqueeze(0).permute(0, 2, 3, 1) - out_dali_tv = transforms.functional.pil_to_tensor(td(img)).unsqueeze(0).permute(0, 2, 3, 1) - assert torch.allclose( - torch.tensor(out_tv.shape[1:3]), torch.tensor(out_dali_tv.shape[1:3]), rtol=0, atol=1 - ), f"Should be:{out_tv.shape} is:{out_dali_tv.shape}" - assert torch.allclose( - torch.tensor(out_fn.shape[1:3]), torch.tensor(out_dali_fn.shape[1:3]), rtol=0, atol=1 - ), f"Should be:{out_fn.shape} is:{out_dali_fn.shape}" +def build_tensors(max_size: int = 512, channels: int = 3): + h = torch.randint(10, max_size, (1,)).item() + w = torch.randint(10, max_size, (1,)).item() + tensors = [ + torch.ones((channels, max_size, max_size)), + torch.ones((1, channels, max_size, max_size)), + torch.ones((10, channels, max_size, max_size)), + torch.ones((channels, max_size // 2, max_size)), + torch.ones((1, channels, max_size // 2, max_size)), + torch.ones((10, channels, max_size // 2, max_size)), + torch.ones((channels, max_size, max_size // 2)), + torch.ones((1, channels, max_size, max_size // 2)), + torch.ones((10, channels, max_size, max_size // 2)), + torch.ones((channels, h, w)), + torch.ones((1, channels, h, w)), + torch.ones((10, channels, h, w)), + ] + + return tensors + + +def loop_tensors_test( + resize: int | Sequence[int], + max_size: int = None, + interpolation: 
transforms.InterpolationMode = transforms.InterpolationMode.BILINEAR, + antialias: bool = False, + device: Literal["cpu", "gpu"] = "cpu", +): + t, td = build_resize_transform(resize, max_size, interpolation, antialias, device) + tensors = build_tensors() - # assert torch.equal(out_tv, out_dali_tv) + for tn in tensors: + _internal_loop(tn, t, td, resize, max_size, interpolation, antialias) def loop_images_test( @@ -115,15 +169,22 @@ def loop_images_test( max_size: int = None, interpolation: transforms.InterpolationMode = transforms.InterpolationMode.BILINEAR, antialias: bool = False, + device: Literal["cpu", "gpu"] = "cpu", ): - t, td = build_resize_transform(resize, max_size, interpolation, antialias) + t, td = build_resize_transform(resize, max_size, interpolation, antialias, device) loop_images_test_no_build(t, td, resize, max_size, interpolation, antialias) -@params(512, 2048, ([512, 512]), ([2048, 2048])) -def test_resize_sizes(resize): +@cartesian_params((512, 2048, ([512, 512]), ([2048, 2048])), ("cpu", "gpu")) +def test_resize_sizes_images(resize, device): + # Resize with single int (preserve aspect ratio) + loop_images_test(resize=resize, device=device) + + +@cartesian_params((512, 2048, ([512, 512]), ([2048, 2048])), ("cpu", "gpu")) +def test_resize_sizes_tensors(resize, device): # Resize with single int (preserve aspect ratio) - loop_images_test(resize=resize) + loop_tensors_test(resize=resize, device=device) @params((480, 512), (100, 124), (None, 512), (1024, 512), ([256, 256], 512), (None, None)) @@ -180,7 +241,7 @@ def test_resize_max_sizes(resize, max_size): ([256, 256], transforms.InterpolationMode.BILINEAR), (640, transforms.InterpolationMode.BICUBIC), ) -def test_resize_interploation(resize, interpolation): +def test_resize_interpoation(resize, interpolation): loop_images_test(resize=resize, interpolation=interpolation) From 9ac174a13a1d9c0633a2c1624351460a29ee6acc Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Marek=20D=C4=85bek?= Date: Thu, 5 Mar 2026 
16:00:42 +0100 Subject: [PATCH 4/7] Apply suggestion from @greptile-apps[bot] Co-authored-by: greptile-apps[bot] <165735046+greptile-apps[bot]@users.noreply.github.com> Signed-off-by: Marek Dabek --- .../dali/experimental/torchvision/v2/functional/resize.py | 4 ++++ .../nvidia/dali/experimental/torchvision/v2/resize.py | 6 +++--- 2 files changed, 7 insertions(+), 3 deletions(-) diff --git a/dali/python/nvidia/dali/experimental/torchvision/v2/functional/resize.py b/dali/python/nvidia/dali/experimental/torchvision/v2/functional/resize.py index b76a86e7454..24913869f3e 100644 --- a/dali/python/nvidia/dali/experimental/torchvision/v2/functional/resize.py +++ b/dali/python/nvidia/dali/experimental/torchvision/v2/functional/resize.py @@ -53,6 +53,10 @@ def resize( elif img.layout in ["CHW", "NCHW"]: original_h = img_shape[-2] original_w = img_shape[-1] + else: + raise ValueError( + f"Unsupported layout: {img.layout!r}. Expected one of HWC, NHWC, CHW, NCHW." + ) target_h, target_w = Resize.calculate_target_size( (original_h, original_w), effective_size, max_size, size is None diff --git a/dali/python/nvidia/dali/experimental/torchvision/v2/resize.py b/dali/python/nvidia/dali/experimental/torchvision/v2/resize.py index 3d18a75ebe8..50d99b82164 100644 --- a/dali/python/nvidia/dali/experimental/torchvision/v2/resize.py +++ b/dali/python/nvidia/dali/experimental/torchvision/v2/resize.py @@ -42,10 +42,10 @@ def get_inputHW(data_input): """ layout = data_input.property("layout")[0] - # CWH + # CHW if layout == np.frombuffer(bytes("C", "utf-8"), dtype=np.uint8)[0]: - input_height = data_input.shape()[-1] - input_width = data_input.shape()[-2] + input_height = data_input.shape()[-2] + input_width = data_input.shape()[-1] # HWC else: input_height = data_input.shape()[-3] From b12f41506fadc2d42d334de1563b11b5a646a7c7 Mon Sep 17 00:00:00 2001 From: Marek Dabek Date: Fri, 6 Mar 2026 16:50:50 +0100 Subject: [PATCH 5/7] Fixed review errors and added resize tests. 
Moved resize output size calculation to a Resize class Signed-off-by: Marek Dabek --- .../experimental/torchvision/v2/compose.py | 47 +++-- .../dali/experimental/torchvision/v2/flips.py | 9 +- .../torchvision/v2/functional/flips.py | 4 +- .../torchvision/v2/functional/resize.py | 52 +++--- .../experimental/torchvision/v2/operator.py | 24 ++- .../experimental/torchvision/v2/resize.py | 169 +++++++++++++----- .../python/torchvision/test_tv_compose.py | 100 +++++++++-- dali/test/python/torchvision/test_tv_flips.py | 4 +- .../test/python/torchvision/test_tv_resize.py | 75 +++++--- 9 files changed, 344 insertions(+), 140 deletions(-) diff --git a/dali/python/nvidia/dali/experimental/torchvision/v2/compose.py b/dali/python/nvidia/dali/experimental/torchvision/v2/compose.py index ebdc457f941..170ab0754f5 100644 --- a/dali/python/nvidia/dali/experimental/torchvision/v2/compose.py +++ b/dali/python/nvidia/dali/experimental/torchvision/v2/compose.py @@ -12,6 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. +from abc import ABC, abstractmethod from typing import List, Sequence, Callable, Union import nvidia.dali.fn as fn @@ -78,7 +79,7 @@ def _pipeline_function(op_list, layout="HWC"): return input_node -class PipelineWithLayout: +class PipelineWithLayout(ABC): """Base class for pipeline layouts. This class is a base class for DALI pipelines with a specific layout. It is used to handle @@ -99,6 +100,17 @@ class PipelineWithLayout: Additional keyword arguments for the DALI pipeline. 
""" + def _cuda_run(self, data_input): + device_id = data_input.device.index if isinstance(data_input, torch.Tensor) else 0 + stream = torch.cuda.Stream(device=device_id) + with torch.cuda.stream(stream): + output = self.pipe.run(stream, input_data=data_input) + + return output + + def _cpu_run(self, data_input): + return self.pipe.run(input_data=data_input) + def __init__( self, op_list: List[Callable[..., Sequence[_DataNode] | _DataNode]], @@ -124,16 +136,11 @@ def __init__( num_threads=num_threads, **dali_pipeline_kwargs, ) + self._internal_run = self._cuda_run if torch.cuda.is_available() else self._cpu_run def run(self, data_input): - output = None - if torch.cuda.is_available(): - stream = torch.cuda.Stream(0) - with torch.cuda.stream(stream): - output = self.pipe.run(stream, input_data=data_input) - else: - output = self.pipe.run(input_data=data_input) + output = self._internal_run(data_input) if output is None: return output @@ -146,10 +153,15 @@ def run(self, data_input): return output + @abstractmethod def get_layout(self) -> str: ... + @abstractmethod def get_channel_reverse_idx(self) -> int: ... + @abstractmethod + def verify_layout(self, data) -> None: ... + def is_conversion_to_tensor(self) -> bool: return self.convert_to_tensor @@ -192,9 +204,14 @@ def _convert_tensor_to_image(self, in_tensor: torch.Tensor): channels = self.get_channel_reverse_idx() # TODO: consider when to convert to PIL.Image - e.g. if it make sense for channels < 3 - if in_tensor.shape[channels] == 1: + # There is noi certain method to determine if the tensor is HW, HWC, or NHWC. 
+ # The method below checks if tensor's shape is HW or ...HWC with a single channel + if len(in_tensor.shape) == 2 or ( + len(in_tensor.shape) >= 3 and in_tensor.shape[channels] == 1 + ): mode = "L" - in_tensor = in_tensor.squeeze(-1) + if len(in_tensor.shape) != 2: + in_tensor = in_tensor.squeeze(-1) elif in_tensor.shape[channels] == 3: mode = "RGB" elif in_tensor.shape[channels] == 4: @@ -242,6 +259,10 @@ def get_layout(self) -> str: def get_channel_reverse_idx(self) -> int: return -1 + def verify_layout(self, data_input) -> None: + if not isinstance(data_input, Image.Image): + raise TypeError(f"The pipeline expects PIL.Images as input got {type(data_input)}") + class PipelineCHW(PipelineWithLayout): """Handles ``torch.Tensors`` in CHW format. @@ -300,6 +321,10 @@ def get_layout(self) -> str: def get_channel_reverse_idx(self) -> int: return -3 + def verify_layout(self, data_input) -> None: + if not isinstance(data_input, torch.Tensor): + raise TypeError(f"The pipeline expects torch.Tensor as input got {type(data_input)}") + class Compose: """ @@ -365,4 +390,6 @@ def __call__(self, data_input): if self.active_pipeline is None: self._build_pipeline(data_input) + self.active_pipeline.verify_layout(data_input) + return self.active_pipeline.run(data_input=data_input) diff --git a/dali/python/nvidia/dali/experimental/torchvision/v2/flips.py b/dali/python/nvidia/dali/experimental/torchvision/v2/flips.py index a821a4aa415..a3c6b209477 100644 --- a/dali/python/nvidia/dali/experimental/torchvision/v2/flips.py +++ b/dali/python/nvidia/dali/experimental/torchvision/v2/flips.py @@ -25,16 +25,17 @@ class RandomFlip(Operator): ---------- p : float Probability of the image being flipped. Default value is 0.5 - horizontal : int - Flip the horizontal dimension if 1, vertical if 0 + horizontal : bool + Flip the horizontal dimension if True, vertical otherwise device : Literal["cpu", "gpu"], optional, default = "cpu" Device to use for the flip. Can be ``"cpu"`` or ``"gpu"``. 
""" - def __init__(self, p: float = 0.5, horizontal: int = 1, device: Literal["cpu", "gpu"] = "cpu"): + def __init__( + self, p: float = 0.5, horizontal: bool = True, device: Literal["cpu", "gpu"] = "cpu" + ): super().__init__(device=device) self.prob = p - self.device = device self.horizontal = horizontal def _kernel(self, data_input): diff --git a/dali/python/nvidia/dali/experimental/torchvision/v2/functional/flips.py b/dali/python/nvidia/dali/experimental/torchvision/v2/functional/flips.py index ea1b15440f2..48e9760e5e5 100644 --- a/dali/python/nvidia/dali/experimental/torchvision/v2/functional/flips.py +++ b/dali/python/nvidia/dali/experimental/torchvision/v2/functional/flips.py @@ -19,7 +19,7 @@ @adjust_input -def horizontal_flip(inpt: ndd.Tensor) -> ndd.Tensor: +def horizontal_flip(inpt: ndd.Tensor | ndd.Batch) -> ndd.Tensor | ndd.Batch: """ Horizontally flips the given tensor. Refer to `HorizontalFlip` for more details. @@ -28,7 +28,7 @@ def horizontal_flip(inpt: ndd.Tensor) -> ndd.Tensor: @adjust_input -def vertical_flip(inpt: ndd.Tensor) -> ndd.Tensor: +def vertical_flip(inpt: ndd.Tensor | ndd.Batch) -> ndd.Tensor | ndd.Batch: """ Vertically flips the given tensor. Refer to `VerticalFlip` for more details. diff --git a/dali/python/nvidia/dali/experimental/torchvision/v2/functional/resize.py b/dali/python/nvidia/dali/experimental/torchvision/v2/functional/resize.py index 24913869f3e..297217d9f11 100644 --- a/dali/python/nvidia/dali/experimental/torchvision/v2/functional/resize.py +++ b/dali/python/nvidia/dali/experimental/torchvision/v2/functional/resize.py @@ -13,23 +13,25 @@ # limitations under the License. 
from typing import Optional, List, Literal -from torch import Tensor import nvidia.dali.experimental.dynamic as ndd from torchvision.transforms import InterpolationMode +import torch +from PIL import Image + from ..operator import adjust_input # noqa: E402 from ..resize import Resize # noqa: E402 @adjust_input def resize( - img: Tensor, + inpt: Image.Image | torch.Tensor, size: List[int], interpolation: InterpolationMode = InterpolationMode.BILINEAR, max_size: Optional[int] = None, antialias: Optional[bool] = True, device: Literal["cpu", "gpu"] = "cpu", -) -> Tensor: +) -> Image.Image | torch.Tensor: """ Please refer to the ``Resize`` operator for more details. """ @@ -37,47 +39,35 @@ def resize( size=size, max_size=max_size, interpolation=interpolation, antialias=antialias ) - effective_size, mode = Resize.infer_effective_size(size, max_size) + size_normalized = Resize.infer_effective_size(size, max_size) interpolation = Resize.interpolation_modes[interpolation] - if isinstance(img, ndd.Tensor): - img_shape = img.shape - elif isinstance(img, ndd.Batch): - img_shape = img.shape[0] # Batches have uniform layout + if isinstance(inpt, ndd.Tensor): + inpt_shape = inpt.shape + elif isinstance(inpt, ndd.Batch): + inpt_shape = inpt.shape[0] # Batches have uniform layout else: - raise TypeError(f"Input must be ndd.Tensor or ndd.Batch got {type(img)}") + raise TypeError(f"Input must be ndd.Tensor or ndd.Batch got {type(inpt)}") - if img.layout in ["HWC", "NHWC"]: - original_h = img_shape[-3] - original_w = img_shape[-2] - elif img.layout in ["CHW", "NCHW"]: - original_h = img_shape[-2] - original_w = img_shape[-1] + if inpt.layout in ["HWC", "NHWC"]: + original_h = inpt_shape[-3] + original_w = inpt_shape[-2] + elif inpt.layout in ["HW", "CHW", "NCHW"]: + original_h = inpt_shape[-2] + original_w = inpt_shape[-1] else: raise ValueError( - f"Unsupported layout: {img.layout!r}. Expected one of HWC, NHWC, CHW, NCHW." + f"Unsupported layout: {inpt.layout!r}. 
Expected one of HWC, NHWC, CHW, NCHW." ) - target_h, target_w = Resize.calculate_target_size( - (original_h, original_w), effective_size, max_size, size is None + target_h, target_w = Resize.calculate_target_size_dynamic_mode( + (original_h, original_w), size_normalized, max_size ) - # Shorter edge limited by max size - if mode == "resize_shorter": - return ndd.resize( - img, - device=device, - resize_shorter=target_h, - max_size=max_size, - interp_type=interpolation, - antialias=antialias, - ) - return ndd.resize( - img, + inpt, device=device, size=(target_h, target_w), - mode=mode, interp_type=interpolation, antialias=antialias, ) diff --git a/dali/python/nvidia/dali/experimental/torchvision/v2/operator.py b/dali/python/nvidia/dali/experimental/torchvision/v2/operator.py index c202dcc8dfb..8e612ac7479 100644 --- a/dali/python/nvidia/dali/experimental/torchvision/v2/operator.py +++ b/dali/python/nvidia/dali/experimental/torchvision/v2/operator.py @@ -208,6 +208,8 @@ def __call__(self, data_input): type(self).verify_data(data_input) + # Original input is transfered to GPU, before being preprocess_data. 
+ # The preprocess_data creates an arbitrary tuple if self.device == "gpu": data_input = data_input.gpu() @@ -239,22 +241,26 @@ def transform_input(inpt) -> ndd.Tensor | ndd.Batch: L, RGB, or RGBA mode - torch.Tensor: ndim==3 -> ndd.Tensor(layout = "CHW"), - ndim>3 -> ndd.Batch(layout="NCHW") + ndim>3 -> ndd.Batch(layout="CHW") (workaround for DALI-4566; intended: layout="NCHW") """ mode = "RGB" if isinstance(inpt, Image.Image): - _input = ndd.Tensor(np.array(inpt, copy=True), layout="HWC") - if _input.shape[-1] == 1: - mode = "L" - elif _input.shape[-1] == 4: - mode = "RGBA" + mode = inpt.mode + if mode == "L": + _input = ndd.Tensor(np.array(inpt, copy=True), layout="HW") + elif mode in ["RGB", "RGBA"]: # Modes RGB, RGBA + _input = ndd.Tensor(np.array(inpt, copy=True), layout="HWC") + else: + raise ValueError(f"Mode {mode} is not supported, expected, L, RGB, RGBA") elif isinstance(inpt, torch.Tensor): if inpt.ndim == 3: _input = ndd.Tensor(inpt, layout="CHW") elif inpt.ndim > 3: - # The following should work, bug: https://jirasw.nvidia.com/browse/DALI-4566 + # Creating baches of NHWC does not work, because of: + # https://jirasw.nvidia.com/browse/DALI-4566 + # It should be implemented as: # _input = ndd.as_batch(inpt, layout="NCHW") - # WAR: + # currently workarounded as: _input = ndd.as_batch(ndd.as_tensor(inpt), layout="CHW") else: raise TypeError(f"Tensor has < 3 dimensions: {inpt.ndim}, shape: {inpt.shape}") @@ -276,7 +282,7 @@ def adjust_output( """ if isinstance(inpt, Image.Image): if output.shape[-1] == 1: - output = np.asarray(output).squeeze(2) + output = np.asarray(output).squeeze(-1) mode = "L" return Image.fromarray(np.asarray(output), mode=mode) elif isinstance(inpt, torch.Tensor): diff --git a/dali/python/nvidia/dali/experimental/torchvision/v2/resize.py b/dali/python/nvidia/dali/experimental/torchvision/v2/resize.py index 50d99b82164..89cba766d85 100644 --- a/dali/python/nvidia/dali/experimental/torchvision/v2/resize.py +++ 
b/dali/python/nvidia/dali/experimental/torchvision/v2/resize.py @@ -42,6 +42,10 @@ def get_inputHW(data_input): """ layout = data_input.property("layout")[0] + # If data layout is NHWC or NCHW, check the next character + if layout == np.frombuffer(bytes("N", "utf-8"), dtype=np.uint8)[0]: + layout = data_input.property("layout")[1] + # CHW if layout == np.frombuffer(bytes("C", "utf-8"), dtype=np.uint8)[0]: input_height = data_input.shape()[-2] @@ -107,14 +111,21 @@ class Resize(Operator): # 'NEAREST', 'NEAREST_EXACT', 'BILINEAR', 'BICUBIC', 'BOX', 'HAMMING', 'LANCZOS' interpolation_modes = { InterpolationMode.NEAREST: DALIInterpType.INTERP_NN, - InterpolationMode.NEAREST_EXACT: DALIInterpType.INTERP_NN, # TODO InterpolationMode.BILINEAR: DALIInterpType.INTERP_LINEAR, InterpolationMode.BICUBIC: DALIInterpType.INTERP_CUBIC, - InterpolationMode.BOX: DALIInterpType.INTERP_LINEAR, # TODO: - InterpolationMode.HAMMING: DALIInterpType.INTERP_GAUSSIAN, # TODO: InterpolationMode.LANCZOS: DALIInterpType.INTERP_LANCZOS3, + # Not supported, but need to be here to not generate ValueError during VerificationSize + InterpolationMode.NEAREST_EXACT: DALIInterpType.INTERP_NN, + InterpolationMode.BOX: DALIInterpType.INTERP_NN, + InterpolationMode.HAMMING: DALIInterpType.INTERP_NN, } + not_supported_interpolation_modes = [ + InterpolationMode.NEAREST_EXACT, + InterpolationMode.BOX, + InterpolationMode.HAMMING, + ] + arg_rules = [VerificationSize] preprocess_data = get_inputHW @@ -123,47 +134,117 @@ def infer_effective_size( cls, size: Optional[int | Sequence[int]], max_size: Optional[int] = None, - ) -> int | Sequence[int]: + ) -> Optional[int | Sequence[int]]: + """Normalizes the size parameter. Called once at initialization. 
- mode = "default" + Returns the size in a canonical form: + - ``int`` — resize the shorter edge to this value (aspect-ratio preserving) + - ``None`` — use ``max_size`` only (resize so longer edge equals ``max_size``) + - ``(h, w)`` tuple/list — resize to the exact target dimensions + """ if isinstance(size, (tuple, list)) and len(size) == 1: size = size[0] + return size + + @classmethod + def calculate_target_size_dynamic_mode( + cls, + orig_size: Sequence[int], + size: Optional[int | Sequence[int]], + max_size: Optional[int], + ): + """Computes the output ``(out_h, out_w)`` compatible with ``torchvision.v2.Resize``. - if isinstance(size, int): - # If size is an int, smaller edge of the image will be matched to this number. - # If size is an int: if the longer edge of the image is greater than max_size - # after being resized according to size, size will be overruled so that the - # longer edge is equal to max_size. As a result, the smaller edge may be shorter - # than size. - mode = "resize_shorter" + Called per resize invocation with the actual input shape. 
- return ((size, size), mode) + Note: This method needs to be called only when in Dynamic Mode + Unfortunately, both method are needed because of graph creation struggles with proper + translation of class methods calls + """ + orig_h = orig_size[0] + orig_w = orig_size[1] - if size is None: - mode = "not_larger" - return ((max_size, max_size), mode) + if isinstance(size, (tuple, list)): + # Exact target dimensions — return directly + return size[0], size[1] - return size, mode + if size is None: + # Only max_size given: resize so the longer edge equals max_size + if orig_h >= orig_w: + return max_size, int(max_size * orig_w / orig_h) + else: + return int(max_size * orig_h / orig_w), max_size + + # size is int: resize the shorter edge to size, maintaining aspect ratio + s = size + if orig_h <= orig_w: + # height is the shorter (or equal) edge + out_h = s + out_w = int(s * orig_w / orig_h) + if max_size is not None and out_w > max_size: + out_h = int(max_size * out_h / out_w) + out_w = max_size + else: + # width is the shorter edge + out_h = int(s * orig_h / orig_w) + out_w = s + if max_size is not None and out_h > max_size: + out_w = int(max_size * out_w / out_h) + out_h = max_size + + return out_h, out_w @classmethod - def calculate_target_size( - cls, orig_size: Sequence[int], effective_size: Sequence[int], max_size: int, no_size: bool + def calculate_target_size_pipeline_mode( + cls, + orig_size: Sequence[int], + size: Optional[int | Sequence[int]], + max_size: Optional[int], ): + """Computes the output ``(out_h, out_w)`` compatible with ``torchvision.v2.Resize``. + + Called per resize invocation with the actual input shape. 
+ + Note: This method needs to be called only when in Pipeline Mode + """ orig_h = orig_size[0] orig_w = orig_size[1] - target_h = effective_size[0] - target_w = effective_size[1] + if isinstance(size, (tuple, list)): + # Exact target dimensions — return directly + return size[0], size[1] - # If size is None, then effective_size is max_size - if no_size: - if orig_h > orig_w: - target_w = (max_size * orig_w) / orig_h + if size is None: + # Only max_size given: resize so the longer edge equals max_size + if orig_h >= orig_w: + return max_size, fn.cast( + dali.math.floor(max_size * orig_w / orig_h), dtype=dali.types.INT32 + ) else: - target_h = (max_size * orig_h) / orig_w - - return target_h, target_w + return ( + fn.cast(dali.math.floor(max_size * orig_h / orig_w), dtype=dali.types.INT32), + max_size, + ) + + # size is int: resize the shorter edge to size, maintaining aspect ratio + s = size + if orig_h <= orig_w: + # height is the shorter (or equal) edge + out_h = s + out_w = fn.cast(dali.math.floor(s * orig_w / orig_h), dtype=dali.types.INT32) + if max_size is not None and out_w > max_size: + out_h = fn.cast(dali.math.floor(max_size * out_h / out_w), dtype=dali.types.INT32) + out_w = max_size + else: + # width is the shorter edge + out_h = fn.cast(dali.math.floor(s * orig_h / orig_w), dtype=dali.types.INT32) + out_w = s + if max_size is not None and out_h > max_size: + out_w = fn.cast(dali.math.floor(max_size * out_w / out_h), dtype=dali.types.INT32) + out_h = max_size + + return out_h, out_w def __init__( self, @@ -183,8 +264,12 @@ def __init__( self.size = size self.max_size = max_size + + if interpolation in Resize.not_supported_interpolation_modes: + raise NotImplementedError(f"Interpolation mode: {interpolation} is not supported") + self.interpolation = Resize.interpolation_modes[interpolation] - self.effective_size, self.mode = Resize.infer_effective_size(size, max_size) + self.size_normalized = Resize.infer_effective_size(size, max_size) self.antialias = 
antialias def _kernel(self, data_input): @@ -193,25 +278,14 @@ def _kernel(self, data_input): with ``torchvision.transforms.Resize`` documentation and applies DALI operator on the ``data_input``. """ - input_height, input_width, data_input = data_input - target_h, target_w = Resize.calculate_target_size( - orig_size=(input_height, input_width), - effective_size=self.effective_size, - max_size=self.max_size, - no_size=self.size is None, - ) + in_h, in_w, data_input = data_input - # Shorter edge limited by max size - if self.mode == "resize_shorter": - return fn.resize( - data_input, - device=self.device, - resize_shorter=target_h, - max_size=self.max_size, - antialias=self.antialias, - interp_type=self.interpolation, - ) + target_h, target_w = Resize.calculate_target_size_pipeline_mode( + (in_h, in_w), + self.size_normalized, + self.max_size, + ) return fn.resize( data_input, @@ -220,7 +294,6 @@ def _kernel(self, data_input): fn.cast(target_h, dtype=dali.types.FLOAT), fn.cast(target_w, dtype=dali.types.FLOAT), ), - mode=self.mode, - antialias=self.antialias, interp_type=self.interpolation, + antialias=self.antialias, ) diff --git a/dali/test/python/torchvision/test_tv_compose.py b/dali/test/python/torchvision/test_tv_compose.py index 9c57684975c..cc2792dcd59 100644 --- a/dali/test/python/torchvision/test_tv_compose.py +++ b/dali/test/python/torchvision/test_tv_compose.py @@ -21,12 +21,12 @@ Resize, ) +from nose2.tools import params from nose_utils import assert_raises import numpy as np from PIL import Image -import torchvision.transforms as tv +import torchvision.transforms.v2 as tv import torch -import torchvision.transforms.v2 as transforms def read_filepath(path): @@ -51,7 +51,7 @@ def make_test_tensor(shape=(5, 10, 10, 1)): def test_compose_tensor(): - test_tensor = make_test_tensor(shape=(5, 5, 5, 3)) + test_tensor = make_test_tensor(shape=(5, 3, 5, 5)) dali_pipeline = Compose([RandomHorizontalFlip(p=1.0)], batch_size=test_tensor.shape[0]) dali_out = 
dali_pipeline(test_tensor) tv_out = tv.RandomHorizontalFlip(p=1.0)(test_tensor) @@ -61,7 +61,7 @@ def test_compose_tensor(): def test_compose_multi_tensor(): - test_tensor = make_test_tensor(shape=(5, 5, 5, 3)) + test_tensor = make_test_tensor(shape=(5, 3, 5, 5)) dali_pipeline = Compose( [Resize(size=(15, 15)), RandomHorizontalFlip(p=1.0), RandomVerticalFlip(p=1.0)], batch_size=test_tensor.shape[0], @@ -73,11 +73,12 @@ def test_compose_multi_tensor(): tv_out = tv_pipeline(test_tensor) assert isinstance(dali_out, torch.Tensor) - assert torch.equal(dali_out, tv_out) + # All close, because there are pixel differences due to resize + assert torch.allclose(dali_out, tv_out, rtol=0, atol=1), f"Should be {tv_out} is {dali_out}" def test_compose_invalid_batch_tensor(): - test_tensor = make_test_tensor(shape=(5, 5, 5, 1)) + test_tensor = make_test_tensor(shape=(5, 1, 5, 5)) with assert_raises(RuntimeError): dali_pipeline = Compose([RandomHorizontalFlip(p=1.0)], batch_size=1) _ = dali_pipeline(test_tensor) @@ -93,8 +94,8 @@ def test_compose_images(): assert isinstance(out_dali_img, Image.Image) - tensor_dali_tv = transforms.functional.pil_to_tensor(out_dali_img) - tensor_tv = transforms.functional.pil_to_tensor(tv_transform(img)) + tensor_dali_tv = tv.functional.pil_to_tensor(out_dali_img) + tensor_tv = tv.functional.pil_to_tensor(tv_transform(img)) assert tensor_dali_tv.shape == tensor_tv.shape assert torch.equal(tensor_dali_tv, tensor_tv) @@ -110,8 +111,8 @@ def test_compose_images_multi(): assert isinstance(out_dali_img, Image.Image) - tensor_dali_tv = transforms.functional.pil_to_tensor(out_dali_img) - tensor_tv = transforms.functional.pil_to_tensor(tv_transform(img)) + tensor_dali_tv = tv.functional.pil_to_tensor(out_dali_img) + tensor_tv = tv.functional.pil_to_tensor(tv_transform(img)) assert tensor_dali_tv.shape == tensor_tv.shape assert torch.equal(tensor_dali_tv, tensor_tv) @@ -123,5 +124,80 @@ def test_compose_invalid_type_images(): for fn in test_files: img = 
Image.open(fn) with assert_raises(TypeError): - out_dali_img = dali_transform([img, img, img]) - assert isinstance(out_dali_img, Image.Image) + _ = dali_transform([img, img, img]) + + +def _make_pil_image(mode, h=50, w=60, seed=42): + rng = np.random.default_rng(seed) + if mode == "L": + data = rng.integers(0, 256, (h, w), dtype=np.uint8) + elif mode == "RGB": + data = rng.integers(0, 256, (h, w, 3), dtype=np.uint8) + elif mode == "RGBA": + data = rng.integers(0, 256, (h, w, 4), dtype=np.uint8) + else: + raise ValueError(f"Unsupported mode: {mode}") + return Image.fromarray(data, mode=mode) + + +@params("RGB", "L", "RGBA") +def test_compose_pil_mode_flip(mode): + """Horizontal flip must produce a pixel-exact match with torchvision for all PIL modes.""" + img = _make_pil_image(mode) + dali_transform = Compose([RandomHorizontalFlip(p=1.0)]) + tv_transform = tv.Compose([tv.RandomHorizontalFlip(p=1.0)]) + + out_dali = dali_transform(img) + out_tv = tv_transform(img) + + assert isinstance(out_dali, Image.Image) + assert out_dali.mode == mode, f"Mode changed: expected {mode}, got {out_dali.mode}" + assert torch.equal( + tv.functional.pil_to_tensor(out_dali), + tv.functional.pil_to_tensor(out_tv), + ), f"Pixel mismatch for mode {mode}" + + +@params("RGB", "L", "RGBA") +def test_compose_pil_mode_resize(mode): + """Resize must produce the correct output shape and preserve PIL mode.""" + img = _make_pil_image(mode) + target = (30, 40) + dali_transform = Compose([Resize(size=target)]) + tv_transform = tv.Compose([tv.Resize(size=target)]) + + out_dali = dali_transform(img) + out_tv = tv_transform(img) + + assert isinstance(out_dali, Image.Image) + assert out_dali.mode == mode, f"Mode changed: expected {mode}, got {out_dali.mode}" + # PIL size is (w, h); compare as (h, w) to match the target convention + assert ( + out_dali.size == out_tv.size + ), f"Size mismatch for mode {mode}: {out_dali.size} != {out_tv.size}" + + +@params("RGB", "L", "RGBA") +def 
test_compose_pil_mode_multi_op(mode): + """Chained flip+resize must preserve mode and match torchvision output shape.""" + img = _make_pil_image(mode) + dali_transform = Compose([Resize(size=(30, 40)), RandomHorizontalFlip(p=1.0)]) + tv_transform = tv.Compose([tv.Resize(size=(30, 40)), tv.RandomHorizontalFlip(p=1.0)]) + + out_dali = dali_transform(img) + out_tv = tv_transform(img) + + assert isinstance(out_dali, Image.Image) + assert out_dali.mode == mode, f"Mode changed: expected {mode}, got {out_dali.mode}" + assert ( + out_dali.size == out_tv.size + ), f"Size mismatch for mode {mode}: {out_dali.size} != {out_tv.size}" + + +@params("RGB", "L", "RGBA") +def test_compose_pil_invalid_input_type_raises(mode): + """Passing a list instead of a PIL Image must raise TypeError regardless of mode.""" + img = _make_pil_image(mode) + dali_transform = Compose([RandomHorizontalFlip(p=1.0)]) + with assert_raises(TypeError): + _ = dali_transform([img, img]) diff --git a/dali/test/python/torchvision/test_tv_flips.py b/dali/test/python/torchvision/test_tv_flips.py index 39bc158b3f7..67869169989 100644 --- a/dali/test/python/torchvision/test_tv_flips.py +++ b/dali/test/python/torchvision/test_tv_flips.py @@ -20,7 +20,7 @@ from nvidia.dali.experimental.torchvision.v2.functional import horizontal_flip, vertical_flip -def make_test_tensor(shape=(1, 10, 10, 3)): +def make_test_tensor(shape=(1, 3, 10, 10)): total = 1 for s in shape: total *= s @@ -58,7 +58,7 @@ def test_vertical_random_flip_probability(device): def test_flip_preserves_shape(): - img = make_test_tensor((1, 15, 20, 3)) + img = make_test_tensor((1, 3, 15, 20)) hflip_pipeline = Compose([RandomHorizontalFlip(p=1.0)]) hflip_fn = horizontal_flip(img).cpu() hflip = hflip_pipeline(img) diff --git a/dali/test/python/torchvision/test_tv_resize.py b/dali/test/python/torchvision/test_tv_resize.py index d2ae2ae23da..cadaa331090 100644 --- a/dali/test/python/torchvision/test_tv_resize.py +++ 
b/dali/test/python/torchvision/test_tv_resize.py @@ -27,10 +27,6 @@ import nvidia.dali.experimental.torchvision.v2.functional as fn_dali -def read_file(path): - return np.fromfile(path, dtype=np.uint8) - - def read_filepath(path): return np.frombuffer(path.encode(), dtype=np.int8) @@ -107,13 +103,17 @@ def _internal_loop( out_fn = transforms.functional.pil_to_tensor(out_fn) out_dali_fn = transforms.functional.pil_to_tensor(out_dali_fn) - assert torch.allclose( - torch.tensor(out_tv.shape[1:3]), torch.tensor(out_dali_tv.shape[1:3]), rtol=0, atol=1 + assert ( + out_tv.shape[1:3] == out_dali_tv.shape[1:3] ), f"Should be:{out_tv.shape} is:{out_dali_tv.shape}" - assert torch.allclose( - torch.tensor(out_fn.shape[1:3]), torch.tensor(out_dali_fn.shape[1:3]), rtol=0, atol=1 + assert ( + out_fn.shape[1:3] == out_dali_fn.shape[1:3] ), f"Should be:{out_fn.shape} is:{out_dali_fn.shape}" + # TODO: + # assert torch.allclose(out_tv, out_dali_tv, rtol=1, atol=1) + # assert torch.allclose(out_fn, out_dali_fn, rtol=1, atol=1) + def loop_images_test_no_build( t: transforms.Resize, @@ -126,10 +126,12 @@ def loop_images_test_no_build( for fn in test_files: img = Image.open(fn) _internal_loop(img, t, td, resize, max_size, interpolation, antialias) - # assert torch.equal(out_tv, out_dali_tv) -def build_tensors(max_size: int = 512, channels: int = 3): +def build_tensors(max_size: int = 512, channels: int = 3, seed=12345): + + torch.manual_seed(seed) + h = torch.randint(10, max_size, (1,)).item() w = torch.randint(10, max_size, (1,)).item() tensors = [ @@ -175,13 +177,13 @@ def loop_images_test( loop_images_test_no_build(t, td, resize, max_size, interpolation, antialias) -@cartesian_params((512, 2048, ([512, 512]), ([2048, 2048])), ("cpu", "gpu")) +@cartesian_params((512, 1125, 2048, ([512, 512]), ([2048, 2048])), ("cpu", "gpu")) def test_resize_sizes_images(resize, device): # Resize with single int (preserve aspect ratio) loop_images_test(resize=resize, device=device) 
-@cartesian_params((512, 2048, ([512, 512]), ([2048, 2048])), ("cpu", "gpu")) +@cartesian_params((512, 1125, 2048, ([512, 512]), ([2048, 2048])), ("cpu", "gpu")) def test_resize_sizes_tensors(resize, device): # Resize with single int (preserve aspect ratio) loop_tensors_test(resize=resize, device=device) @@ -235,16 +237,45 @@ def test_resize_max_sizes(resize, max_size): loop_images_test(resize=resize, max_size=max_size) -@params( - ([512, 512], transforms.InterpolationMode.NEAREST), - (1024, transforms.InterpolationMode.NEAREST_EXACT), - ([256, 256], transforms.InterpolationMode.BILINEAR), - (640, transforms.InterpolationMode.BICUBIC), +@cartesian_params( + ( + 640, + 768, + 1024, + ([512, 512]), + ([256, 256]), + ), + ( + transforms.InterpolationMode.NEAREST, + transforms.InterpolationMode.NEAREST_EXACT, + transforms.InterpolationMode.BILINEAR, + transforms.InterpolationMode.BICUBIC, + ), + ("cpu", "gpu"), ) -def test_resize_interpoation(resize, interpolation): - loop_images_test(resize=resize, interpolation=interpolation) +def test_resize_interpolation(resize, interpolation, device): + if interpolation == transforms.InterpolationMode.NEAREST_EXACT: + with assert_raises(NotImplementedError): + loop_images_test(resize=resize, interpolation=interpolation, device=device) + else: + loop_images_test(resize=resize, interpolation=interpolation, device=device) + + +@cartesian_params((512, 768, 2048, ([512, 512]), ([2048, 2048])), (True, False), ("cpu", "gpu")) +def test_resize_antialiasing(resize, antialiasing, device): + loop_images_test(resize=resize, antialias=antialiasing, device=device) -@params((512, True), (2048, True), ([512, 512], True), ([2048, 2048], True)) -def test_resize_antialiasing(resize, antialiasing): - loop_images_test(resize=resize, antialias=antialiasing) +@cartesian_params((8192, 8193, 10243), ("cpu", "gpu")) +def test_large_sizes_images(resize, device): + loop_images_test(resize=resize, device=device) + + +""" +These tests are too heavy they would 
cause timeouts + +@cartesian_params((8192, 8193, 10243), ("cpu", "gpu")) +def test_large_sizes_tensors(resize, device): + # Resize with single int (preserve aspect ratio) + loop_tensors_test(resize=resize, device=device) +""" From fbd815f2201c908fe0aad317c296a6618739c268 Mon Sep 17 00:00:00 2001 From: Marek Dabek Date: Mon, 9 Mar 2026 12:28:38 +0100 Subject: [PATCH 6/7] Interpolation verification - review comment Signed-off-by: Marek Dabek --- .../dali/experimental/torchvision/v2/compose.py | 8 ++++++-- .../torchvision/v2/functional/flips.py | 7 +++++-- .../dali/experimental/torchvision/v2/resize.py | 15 +++++++++------ 3 files changed, 20 insertions(+), 10 deletions(-) diff --git a/dali/python/nvidia/dali/experimental/torchvision/v2/compose.py b/dali/python/nvidia/dali/experimental/torchvision/v2/compose.py index 170ab0754f5..f45784bb2a8 100644 --- a/dali/python/nvidia/dali/experimental/torchvision/v2/compose.py +++ b/dali/python/nvidia/dali/experimental/torchvision/v2/compose.py @@ -101,8 +101,12 @@ class PipelineWithLayout(ABC): """ def _cuda_run(self, data_input): - device_id = data_input.device.index if isinstance(data_input, torch.Tensor) else 0 + if isinstance(data_input, torch.Tensor) and data_input.is_cuda: + device_id = data_input.device_index + else: + device_id = torch.cuda.current_device() stream = torch.cuda.Stream(device=device_id) + with torch.cuda.stream(stream): output = self.pipe.run(stream, input_data=data_input) @@ -204,7 +208,7 @@ def _convert_tensor_to_image(self, in_tensor: torch.Tensor): channels = self.get_channel_reverse_idx() # TODO: consider when to convert to PIL.Image - e.g. if it make sense for channels < 3 - # There is noi certain method to determine if the tensor is HW, HWC, or NHWC. + # There is no certain method to determine if the tensor is HW, HWC, or NHWC. 
# The method below checks if tensor's shape is HW or ...HWC with a single channel if len(in_tensor.shape) == 2 or ( len(in_tensor.shape) >= 3 and in_tensor.shape[channels] == 1 diff --git a/dali/python/nvidia/dali/experimental/torchvision/v2/functional/flips.py b/dali/python/nvidia/dali/experimental/torchvision/v2/functional/flips.py index 48e9760e5e5..bb6763ff250 100644 --- a/dali/python/nvidia/dali/experimental/torchvision/v2/functional/flips.py +++ b/dali/python/nvidia/dali/experimental/torchvision/v2/functional/flips.py @@ -15,11 +15,14 @@ import nvidia.dali.experimental.dynamic as ndd +import torch +from PIL import Image + from ..operator import adjust_input # noqa: E402 @adjust_input -def horizontal_flip(inpt: ndd.Tensor | ndd.Batch) -> ndd.Tensor | ndd.Batch: +def horizontal_flip(inpt: Image.Image | torch.Tensor) -> Image.Image | torch.Tensor: """ Horizontally flips the given tensor. Refer to `HorizontalFlip` for more details. @@ -28,7 +31,7 @@ def horizontal_flip(inpt: ndd.Tensor | ndd.Batch) -> ndd.Tensor | ndd.Batch: @adjust_input -def vertical_flip(inpt: ndd.Tensor | ndd.Batch) -> ndd.Tensor | ndd.Batch: +def vertical_flip(inpt: Image.Image | torch.Tensor) -> Image.Image | torch.Tensor: """ Vertically flips the given tensor. Refer to `VerticalFlip` for more details. diff --git a/dali/python/nvidia/dali/experimental/torchvision/v2/resize.py b/dali/python/nvidia/dali/experimental/torchvision/v2/resize.py index 89cba766d85..f9f4351a27d 100644 --- a/dali/python/nvidia/dali/experimental/torchvision/v2/resize.py +++ b/dali/python/nvidia/dali/experimental/torchvision/v2/resize.py @@ -66,8 +66,11 @@ def verify(cls, *, size, max_size, interpolation, **_): "Invalid combination: size must be int, None, or sequence of two ints. " "max_size only applies when size is int or None." 
) - if size is None and max_size is None: - raise ValueError("Must provide max_size if size is None.") + if size is None and max_size is None and max_size > 0: + raise ValueError( + f"Must provide max_size if size is None and max_size must be > 0 \ + got {max_size}" + ) if size is not None and max_size is not None and np.min(size) > max_size: raise ValueError("max_size should not be smaller than the actual size") if isinstance(size, (tuple, list)) and len(size) == 2 and max_size is not None: @@ -75,6 +78,10 @@ def verify(cls, *, size, max_size, interpolation, **_): "max_size should only be passed if size specifies the length of the smaller \ edge, i.e. size should be an int" ) + + if interpolation in Resize.not_supported_interpolation_modes: + raise NotImplementedError(f"Interpolation mode: {interpolation} is not supported") + if interpolation not in Resize.interpolation_modes.keys(): raise ValueError(f"Interpolation {type(interpolation)} is not supported") @@ -264,10 +271,6 @@ def __init__( self.size = size self.max_size = max_size - - if interpolation in Resize.not_supported_interpolation_modes: - raise NotImplementedError(f"Interpolation mode: {interpolation} is not supported") - self.interpolation = Resize.interpolation_modes[interpolation] self.size_normalized = Resize.infer_effective_size(size, max_size) self.antialias = antialias From e0e98fd0e60965231ac22ce745c39395c83613e6 Mon Sep 17 00:00:00 2001 From: Marek Dabek Date: Mon, 9 Mar 2026 13:50:16 +0100 Subject: [PATCH 7/7] Fixed max_size check Signed-off-by: Marek Dabek --- .../nvidia/dali/experimental/torchvision/v2/compose.py | 2 +- .../nvidia/dali/experimental/torchvision/v2/resize.py | 9 ++++----- 2 files changed, 5 insertions(+), 6 deletions(-) diff --git a/dali/python/nvidia/dali/experimental/torchvision/v2/compose.py b/dali/python/nvidia/dali/experimental/torchvision/v2/compose.py index f45784bb2a8..d3140a29c7a 100644 --- a/dali/python/nvidia/dali/experimental/torchvision/v2/compose.py +++ 
b/dali/python/nvidia/dali/experimental/torchvision/v2/compose.py @@ -102,7 +102,7 @@ class PipelineWithLayout(ABC): def _cuda_run(self, data_input): if isinstance(data_input, torch.Tensor) and data_input.is_cuda: - device_id = data_input.device_index + device_id = data_input.device.index else: device_id = torch.cuda.current_device() stream = torch.cuda.Stream(device=device_id) diff --git a/dali/python/nvidia/dali/experimental/torchvision/v2/resize.py b/dali/python/nvidia/dali/experimental/torchvision/v2/resize.py index f9f4351a27d..a8450734805 100644 --- a/dali/python/nvidia/dali/experimental/torchvision/v2/resize.py +++ b/dali/python/nvidia/dali/experimental/torchvision/v2/resize.py @@ -66,13 +66,12 @@ def verify(cls, *, size, max_size, interpolation, **_): "Invalid combination: size must be int, None, or sequence of two ints. " "max_size only applies when size is int or None." ) - if size is None and max_size is None and max_size > 0: - raise ValueError( - f"Must provide max_size if size is None and max_size must be > 0 \ - got {max_size}" - ) + if size is None and max_size is None: + raise ValueError("Must provide max_size if size is None.") if size is not None and max_size is not None and np.min(size) > max_size: raise ValueError("max_size should not be smaller than the actual size") + if max_size is not None and np.min(max_size) < 0: + raise ValueError(f"max_size must not be smaller than 0, got{max_size}") if isinstance(size, (tuple, list)) and len(size) == 2 and max_size is not None: raise ValueError( "max_size should only be passed if size specifies the length of the smaller \