diff --git a/monai/data/image_reader.py b/monai/data/image_reader.py
index d03028d198..a892b5b8fd 100644
--- a/monai/data/image_reader.py
+++ b/monai/data/image_reader.py
@@ -1031,7 +1031,7 @@ def _get_array_data(self, img):
             img: a Nibabel image object loaded from an image file.
 
         """
-        return np.asanyarray(img.dataobj)
+        return np.asanyarray(img.dataobj, order="C")
 
 
 class NumpyReader(ImageReader):
diff --git a/monai/data/meta_obj.py b/monai/data/meta_obj.py
index 67f4109c86..86ce7e33fb 100644
--- a/monai/data/meta_obj.py
+++ b/monai/data/meta_obj.py
@@ -19,8 +19,7 @@
 import numpy as np
 import torch
 
-from monai.utils.enums import TraceKeys
-from monai.utils.misc import first
+from monai.utils import TraceKeys, first, is_immutable
 
 _TRACK_META = True
 
@@ -107,13 +106,15 @@ def flatten_meta_objs(*args: Iterable):
     @staticmethod
     def copy_items(data):
         """returns a copy of the data. list and dict are shallow copied for efficiency purposes."""
+        if is_immutable(data):
+            return data
         if isinstance(data, (list, dict, np.ndarray)):
            return data.copy()
         if isinstance(data, torch.Tensor):
             return data.detach().clone()
         return deepcopy(data)
 
-    def copy_meta_from(self, input_objs, copy_attr=True) -> None:
+    def copy_meta_from(self, input_objs, copy_attr=True, keys=None):
         """
         Copy metadata from a `MetaObj` or an iterable of `MetaObj` instances.
 
@@ -121,13 +122,19 @@ def copy_meta_from(self, input_objs, copy_attr=True) -> None:
             input_objs: list of `MetaObj` to copy data from.
             copy_attr: whether to copy each attribute with `MetaObj.copy_item`.
                 note that if the attribute is a nested list or dict, only a shallow copy will be done.
+            keys: the keys of attributes to copy from the ``input_objs``.
+                If None, all keys from the input_objs will be copied.
         """
         first_meta = input_objs if isinstance(input_objs, MetaObj) else first(input_objs, default=self)
+        if not hasattr(first_meta, "__dict__"):
+            return self
         first_meta = first_meta.__dict__
+        keys = first_meta.keys() if keys is None else keys
         if not copy_attr:
-            self.__dict__ = first_meta.copy()  # shallow copy for performance
+            self.__dict__ = {a: first_meta[a] for a in keys if a in first_meta}  # shallow copy for performance
         else:
-            self.__dict__.update({a: MetaObj.copy_items(first_meta[a]) for a in first_meta})
+            self.__dict__.update({a: MetaObj.copy_items(first_meta[a]) for a in keys if a in first_meta})
+        return self
 
     @staticmethod
     def get_default_meta() -> dict:
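Note: the `is_immutable` fast path above lets `copy_items` hand back scalars, strings, and tuples without any copying. A minimal standalone sketch of the resulting dispatch (illustrative only, not part of the patch):

    # Sketch of the post-patch copy_items dispatch (names mirror the diff above).
    from copy import deepcopy
    import numpy as np
    import torch

    def copy_items(data):
        if isinstance(data, (type(None), int, float, bool, complex, str, tuple, bytes)):
            return data  # immutable: sharing is safe, no copy needed
        if isinstance(data, (list, dict, np.ndarray)):
            return data.copy()  # shallow copy for efficiency
        if isinstance(data, torch.Tensor):
            return data.detach().clone()
        return deepcopy(data)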
diff --git a/monai/data/meta_tensor.py b/monai/data/meta_tensor.py
index 22f9502708..d77fd782c2 100644
--- a/monai/data/meta_tensor.py
+++ b/monai/data/meta_tensor.py
@@ -503,15 +503,15 @@ def clone(self):
 
     @staticmethod
     def ensure_torch_and_prune_meta(
-        im: NdarrayTensor, meta: dict, simple_keys: bool = False, pattern: str | None = None, sep: str = "."
+        im: NdarrayTensor, meta: dict | None, simple_keys: bool = False, pattern: str | None = None, sep: str = "."
     ):
         """
-        Convert the image to `torch.Tensor`. If `affine` is in the `meta` dictionary,
+        Convert the image to MetaTensor (when meta is not None). If `affine` is in the `meta` dictionary,
         convert that to `torch.Tensor`, too. Remove any superfluous metadata.
 
         Args:
             im: Input image (`np.ndarray` or `torch.Tensor`)
-            meta: Metadata dictionary.
+            meta: Metadata dictionary. When it's None, the metadata is not tracked and this method returns a torch.Tensor.
             simple_keys: whether to keep only a simple subset of metadata keys.
             pattern: combined with `sep`, a regular expression used to match and prune keys
                 in the metadata (nested dictionary), default to None, no key deletion.
@@ -521,14 +521,17 @@ def ensure_torch_and_prune_meta(
 
         Returns:
             By default, a `MetaTensor` is returned.
-            However, if `get_track_meta()` is `False`, a `torch.Tensor` is returned.
+            However, if `get_track_meta()` is `False` or ``meta=None``, a `torch.Tensor` is returned.
         """
-        img = convert_to_tensor(im)  # potentially ascontiguousarray
+        img = convert_to_tensor(im, track_meta=get_track_meta() and meta is not None)  # potentially ascontiguousarray
 
         # if not tracking metadata, return `torch.Tensor`
-        if not get_track_meta() or meta is None:
+        if not isinstance(img, MetaTensor):
             return img
 
+        if meta is None:
+            meta = {}
+
         # remove any superfluous metadata.
         if simple_keys:
             # ensure affine is of type `torch.Tensor`
@@ -540,7 +543,14 @@ def ensure_torch_and_prune_meta(
             meta = monai.transforms.DeleteItemsd(keys=pattern, sep=sep, use_re=True)(meta)
 
         # return the `MetaTensor`
-        return MetaTensor(img, meta=meta)
+        if meta is None:
+            meta = {}
+        img.meta = meta
+        if MetaKeys.AFFINE in meta:
+            img.affine = meta[MetaKeys.AFFINE]  # this uses the affine property setter
+        else:
+            img.affine = MetaTensor.get_default_affine()
+        return img
 
     def __repr__(self):
         """
diff --git a/monai/data/utils.py b/monai/data/utils.py
index 5d6869334b..91358b2c63 100644
--- a/monai/data/utils.py
+++ b/monai/data/utils.py
@@ -46,6 +46,7 @@
     ensure_tuple_size,
     fall_back_tuple,
     first,
+    get_equivalent_dtype,
     issequenceiterable,
     look_up_option,
     optional_import,
@@ -924,6 +925,7 @@ def to_affine_nd(r: np.ndarray | int, affine: NdarrayTensor, dtype=np.float64)
         - an (r+1) x (r+1) matrix (tensor or ndarray depends on the input ``affine`` data type)
 
     """
+    dtype = get_equivalent_dtype(dtype, np.ndarray)
    affine_np = convert_data_type(affine, output_type=np.ndarray, dtype=dtype, wrap_sequence=True)[0]
     affine_np = affine_np.copy()
     if affine_np.ndim != 2:
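Note: after this change the `meta` argument fully controls the return type. A hypothetical usage check (assumes `get_track_meta()` is left at its default of True):

    import numpy as np
    import torch
    from monai.data import MetaTensor

    img = np.zeros((1, 8, 8), dtype=np.float32)
    # a metadata dict yields a MetaTensor whose affine comes from the dict
    t1 = MetaTensor.ensure_torch_and_prune_meta(img, meta={"affine": np.eye(4)})
    assert isinstance(t1, MetaTensor)
    # meta=None disables tracking and yields a plain torch.Tensor
    t2 = MetaTensor.ensure_torch_and_prune_meta(img, meta=None)
    assert isinstance(t2, torch.Tensor) and not isinstance(t2, MetaTensor)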
+ """ + vals = ( + self.__class__.__name__, + id(self), + self.tracing, + self.lazy_evaluation if isinstance(self, LazyTransform) else False, + self._do_transform if hasattr(self, "_do_transform") else True, + ) + return dict(zip(self.transform_info_keys(), vals)) - Args: - data: input data. Can be dictionary or MetaTensor. We can use `shape` to - determine the original size of the object (unless that has been given - explicitly, see `orig_size`). - key: if data is a dictionary, data[key] will be modified. - extra_info: if desired, any extra information pertaining to the applied - transform can be stored in this dictionary. These are often needed for - computing the inverse transformation. - orig_size: sometimes during the inverse it is useful to know what the size - of the original image was, in which case it can be supplied here. + def push_transform(self, data, *args, **kwargs): + """ + Push to a stack of applied transforms of ``data``. - Returns: - Dictionary of data pertaining to the applied transformation. + Args: + data: dictionary of data or `MetaTensor`. + args: additional positional arguments to track_transform_meta. + kwargs: additional keyword arguments to track_transform_meta, + set ``replace=True`` (default False) to rewrite the last transform infor in + applied_operation/pending_operation based on ``self.get_transform_info()``. """ - info = {TraceKeys.CLASS_NAME: self.__class__.__name__, TraceKeys.ID: id(self)} - if orig_size is not None: - info[TraceKeys.ORIG_SIZE] = orig_size - elif isinstance(data, Mapping) and key in data and hasattr(data[key], "shape"): - info[TraceKeys.ORIG_SIZE] = data[key].shape[1:] - elif hasattr(data, "shape"): - info[TraceKeys.ORIG_SIZE] = data.shape[1:] - if extra_info is not None: - info[TraceKeys.EXTRA_INFO] = extra_info - # If class is randomizable transform, store whether the transform was actually performed (based on `prob`) - if hasattr(self, "_do_transform"): # RandomizableTransform - info[TraceKeys.DO_TRANSFORM] = self._do_transform - return info - - def push_transform( - self, data, key: Hashable = None, extra_info: dict | None = None, orig_size: tuple | None = None - ) -> None: + transform_info = self.get_transform_info() + lazy_eval = transform_info.get(TraceKeys.LAZY_EVALUATION, False) + do_transform = transform_info.get(TraceKeys.DO_TRANSFORM, True) + kwargs = kwargs or {} + replace = kwargs.pop("replace", False) # whether to rewrite the most recently pushed transform info + if replace and get_track_meta() and isinstance(data, MetaTensor): + if not lazy_eval: + xform = self.pop_transform(data, check=False) if do_transform else {} + meta_obj = self.push_transform(data, orig_size=xform.get(TraceKeys.ORIG_SIZE), extra_info=xform) + return data.copy_meta_from(meta_obj) + if do_transform: + xform = data.pending_operations.pop() # type: ignore + xform.update(transform_info) + meta_obj = self.push_transform(data, transform_info=xform, lazy_evaluation=lazy_eval) + return data.copy_meta_from(meta_obj) + return data + kwargs["lazy_evaluation"] = lazy_eval + if "transform_info" in kwargs and isinstance(kwargs["transform_info"], dict): + kwargs["transform_info"].update(transform_info) + else: + kwargs["transform_info"] = transform_info + meta_obj = TraceableTransform.track_transform_meta(data, *args, **kwargs) + return data.copy_meta_from(meta_obj) if isinstance(data, MetaTensor) else data + + @classmethod + def track_transform_meta( + cls, + data, + key: Hashable = None, + sp_size=None, + affine=None, + extra_info: dict | None = None, + 
+
+    @classmethod
+    def track_transform_meta(
+        cls,
+        data,
+        key: Hashable = None,
+        sp_size=None,
+        affine=None,
+        extra_info: dict | None = None,
+        orig_size: tuple | None = None,
+        transform_info=None,
+        lazy_evaluation=False,
+    ):
         """
-        Push to a stack of applied transforms.
+        Update a stack of applied/pending transforms metadata of ``data``.
 
         Args:
             data: dictionary of data or `MetaTensor`.
             key: if data is a dictionary, data[key] will be modified.
+            sp_size: the expected output spatial size when the transform is applied.
+                It can be a tensor or a numpy array, but will be converted to a list of integers.
+            affine: the affine representation of the (spatial) transform in the image space.
+                When the transform is applied, meta_tensor.affine will be updated to ``meta_tensor.affine @ affine``.
             extra_info: if desired, any extra information pertaining to the applied
                 transform can be stored in this dictionary. These are often needed for
                 computing the inverse transformation.
             orig_size: sometimes during the inverse it is useful to know what the size
                 of the original image was, in which case it can be supplied here.
+            transform_info: info from self.get_transform_info().
+            lazy_evaluation: whether to push the transform to pending_operations or applied_operations.
 
         Returns:
-            None, but data has been updated to store the applied transformation.
+
+            For backward compatibility, if ``data`` is a dictionary, it returns the dictionary with
+            updated ``data[key]``. Otherwise, this function returns a MetaObj with updated transform metadata.
         """
-        if not self.tracing:
-            return
-        info = self.get_transform_info(data, key, extra_info, orig_size)
+        data_t = data[key] if key is not None else data  # compatible with the dict data representation
+        out_obj = MetaObj()
+        # after deprecating metadict, we should always convert data_t to metatensor here
+        if isinstance(data_t, MetaTensor):
+            out_obj.copy_meta_from(data_t, keys=out_obj.__dict__.keys())
+
+        if not lazy_evaluation and affine is not None and isinstance(data_t, MetaTensor):
+            # not lazy evaluation, directly update the metatensor affine (don't push to the stack)
+            orig_affine = data_t.peek_pending_affine()
+            orig_affine = convert_to_dst_type(orig_affine, affine)[0]
+            affine = orig_affine @ to_affine_nd(len(orig_affine) - 1, affine, dtype=affine.dtype)
+            out_obj.meta[MetaKeys.AFFINE] = convert_to_tensor(affine, device=torch.device("cpu"))
+
+        if not (get_track_meta() and transform_info and transform_info.get(TraceKeys.TRACING)):
+            if isinstance(data, Mapping):
+                if not isinstance(data, dict):
+                    data = dict(data)
+                data[key] = data_t.copy_meta_from(out_obj) if isinstance(data_t, MetaTensor) else data_t
+                return data
+            return out_obj  # return with data_t as tensor if get_track_meta() is False
+
+        info = transform_info
+        # track the current spatial shape
+        if orig_size is not None:
+            info[TraceKeys.ORIG_SIZE] = orig_size
+        elif isinstance(data_t, MetaTensor):
+            info[TraceKeys.ORIG_SIZE] = data_t.peek_pending_shape()
+        elif hasattr(data_t, "shape"):
+            info[TraceKeys.ORIG_SIZE] = data_t.shape[1:]
+        # include extra_info
+        if extra_info is not None:
+            info[TraceKeys.EXTRA_INFO] = extra_info
 
-        if isinstance(data, MetaTensor):
-            data.push_applied_operation(info)
-        elif isinstance(data, Mapping):
-            if key in data and isinstance(data[key], MetaTensor):
-                data[key].push_applied_operation(info)
+        # push the transform info to the applied_operation or pending_operation stack
+        if lazy_evaluation:
+            if sp_size is None:
+                if LazyAttr.SHAPE not in info:
+                    warnings.warn("spatial size is None in push transform.")
+            else:
+                info[LazyAttr.SHAPE] = tuple(convert_to_numpy(sp_size, wrap_sequence=True).tolist())
+            if affine is None:
+                if LazyAttr.AFFINE not in info:
+                    warnings.warn("affine is None in push transform.")
             else:
-                # If this is the first, create list
-                if self.trace_key(key) not in data:
-                    if not isinstance(data, dict):
-                        data = dict(data)
-                    data[self.trace_key(key)] = []
-                data[self.trace_key(key)].append(info)
+                info[LazyAttr.AFFINE] = convert_to_tensor(affine, device=torch.device("cpu"))
+            out_obj.push_pending_operation(info)
         else:
-            warnings.warn(f"`data` should be either `MetaTensor` or dictionary, got {type(data)}. {info} not tracked.")
+            out_obj.push_applied_operation(info)
+        if isinstance(data, Mapping):
+            if not isinstance(data, dict):
+                data = dict(data)
+            if isinstance(data_t, MetaTensor):
+                data[key] = data_t.copy_meta_from(out_obj)
+            else:
+                x_k = TraceableTransform.trace_key(key)
+                if x_k not in data:
+                    data[x_k] = []  # If this is the first, create list
+                data[x_k].append(info)
+            return data
+        return out_obj
 
     def check_transforms_match(self, transform: Mapping) -> None:
         """Check transforms are of same instance."""
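A rough illustration of the lazy branch above (assumed API; this mirrors what `push_transform` does when `lazy_evaluation=True`). The pending record carries `LazyAttr.SHAPE`/`LazyAttr.AFFINE` instead of mutating `img.affine`:

    import torch
    from monai.data import MetaTensor
    from monai.transforms.inverse import TraceableTransform
    from monai.utils import LazyAttr, TraceKeys

    img = MetaTensor(torch.zeros(1, 8, 8))
    rot90 = torch.tensor([[0.0, -1.0, 0.0], [1.0, 0.0, 0.0], [0.0, 0.0, 1.0]])
    info = {TraceKeys.CLASS_NAME: "Example", TraceKeys.TRACING: True}  # minimal transform_info
    meta_obj = TraceableTransform.track_transform_meta(
        img, sp_size=(8, 8), affine=rot90, transform_info=info, lazy_evaluation=True
    )
    img.copy_meta_from(meta_obj)
    pending = img.pending_operations[-1]
    # pending[LazyAttr.SHAPE] == (8, 8); pending[LazyAttr.AFFINE] is a cpu tensor;
    # img.affine itself stays unchanged until apply_transforms() consumes the stack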
diff --git a/monai/transforms/lazy/functional.py b/monai/transforms/lazy/functional.py
index 13aa753a55..44e46d4bdb 100644
--- a/monai/transforms/lazy/functional.py
+++ b/monai/transforms/lazy/functional.py
@@ -11,10 +11,12 @@
 
 from __future__ import annotations
 
+from typing import Any
+
+import numpy as np
 import torch
 
 from monai.data.meta_tensor import MetaTensor
-from monai.data.utils import to_affine_nd
 from monai.transforms.lazy.utils import (
     affine_from_pending,
     combine_transforms,
@@ -22,40 +24,77 @@
     kwargs_from_pending,
     resample,
 )
+from monai.utils import LazyAttr
 
 __all__ = ["apply_transforms"]
 
 
-def apply_transforms(data: torch.Tensor | MetaTensor, pending: list | None = None):
+def apply_transforms(
+    data: torch.Tensor | MetaTensor,
+    pending: list | None = None,
+    mode: str | int | None = None,
+    padding_mode: str | None = None,
+    dtype=np.float64,
+    align_corners: bool | None = None,
+):
     """
     This method applies pending transforms to `data` tensors.
 
     Args:
         data: A torch Tensor or a monai MetaTensor.
         pending: pending transforms. This must be set if data is a Tensor, but is optional if data is a MetaTensor.
+        mode: {``"bilinear"``, ``"nearest"``} or spline interpolation order 0-5 (integers).
+            Interpolation mode to calculate output values. Defaults to None.
+            See also: https://pytorch.org/docs/stable/generated/torch.nn.functional.grid_sample.html
+            When it's an integer, the numpy (cpu tensor)/cupy (cuda tensor) backends will be used
+            and the value represents the order of the spline interpolation.
+            See also: https://docs.scipy.org/doc/scipy/reference/generated/scipy.ndimage.map_coordinates.html
+        padding_mode: {``"zeros"``, ``"border"``, ``"reflection"``}
+            Padding mode for outside grid values. Defaults to None.
+            See also: https://pytorch.org/docs/stable/generated/torch.nn.functional.grid_sample.html
+            When `mode` is an integer, using numpy/cupy backends, this argument accepts
+            {'reflect', 'grid-mirror', 'constant', 'grid-constant', 'nearest', 'mirror', 'grid-wrap', 'wrap'}.
+            See also: https://docs.scipy.org/doc/scipy/reference/generated/scipy.ndimage.map_coordinates.html
+        dtype: data type for resampling computation. Defaults to ``float64``.
+            If ``None``, use the data type of the input data.
+        align_corners: Geometrically, we consider the pixels of the input as squares rather than points, when using
+            the PyTorch resampling backend. Defaults to ``None``.
+            See also: https://pytorch.org/docs/stable/generated/torch.nn.functional.grid_sample.html
     """
     if isinstance(data, MetaTensor) and pending is None:
-        pending = data.pending_operations
+        pending = data.pending_operations.copy()
+        data.clear_pending_operations()
     pending = [] if pending is None else pending
 
     if not pending:
-        return data
+        return data, []
 
     cumulative_xform = affine_from_pending(pending[0])
     cur_kwargs = kwargs_from_pending(pending[0])
+    override_kwargs: dict[str, Any] = {}
+    if mode is not None:
+        override_kwargs[LazyAttr.INTERP_MODE] = mode
+    if padding_mode is not None:
+        override_kwargs[LazyAttr.PADDING_MODE] = padding_mode
+    if align_corners is not None:
+        override_kwargs[LazyAttr.ALIGN_CORNERS] = align_corners
+    override_kwargs[LazyAttr.DTYPE] = data.dtype if dtype is None else dtype
 
     for p in pending[1:]:
         new_kwargs = kwargs_from_pending(p)
         if not is_compatible_apply_kwargs(cur_kwargs, new_kwargs):
             # carry out an intermediate resample here due to incompatibility between arguments
-            data = resample(data, cumulative_xform, cur_kwargs)
+            _cur_kwargs = cur_kwargs.copy()
+            _cur_kwargs.update(override_kwargs)
+            sp_size = _cur_kwargs.pop(LazyAttr.SHAPE, None)
+            data = resample(data, cumulative_xform, sp_size, _cur_kwargs)
             next_matrix = affine_from_pending(p)
             cumulative_xform = combine_transforms(cumulative_xform, next_matrix)
         cur_kwargs.update(new_kwargs)
 
-    data = resample(data, cumulative_xform, cur_kwargs)
+    cur_kwargs.update(override_kwargs)
+    sp_size = cur_kwargs.pop(LazyAttr.SHAPE, None)
+    data = resample(data, cumulative_xform, sp_size, cur_kwargs)
     if isinstance(data, MetaTensor):
-        data.clear_pending_operations()
-        data.affine = data.affine @ to_affine_nd(3, cumulative_xform)
         for p in pending:
             data.push_applied_operation(p)
diff --git a/monai/transforms/lazy/utils.py b/monai/transforms/lazy/utils.py
index e03314d655..1672695ed2 100644
--- a/monai/transforms/lazy/utils.py
+++ b/monai/transforms/lazy/utils.py
@@ -105,21 +105,30 @@ def is_compatible_apply_kwargs(kwargs_1, kwargs_2):
     return True
 
 
-def resample(data: torch.Tensor, matrix: NdarrayOrTensor, kwargs: dict | None = None):
+def resample(data: torch.Tensor, matrix: NdarrayOrTensor, spatial_size, kwargs: dict | None = None):
     """
-    This is a minimal implementation of resample that always uses Affine.
+    This is a minimal implementation of resample that always uses SpatialResample.
+    `kwargs` supports "lazy_dtype", "lazy_padding_mode", "lazy_interpolation_mode", "lazy_align_corners".
+
+    See Also:
+        :py:class:`monai.transforms.SpatialResample`
     """
     if not Affine.is_affine_shaped(matrix):
         raise NotImplementedError("calling dense grid resample API not implemented")
     kwargs = {} if kwargs is None else kwargs
     init_kwargs = {
-        "spatial_size": kwargs.pop(LazyAttr.SHAPE, data.shape)[1:],
         "dtype": kwargs.pop(LazyAttr.DTYPE, data.dtype),
+        "align_corners": kwargs.pop(LazyAttr.ALIGN_CORNERS, None),
     }
+    img = convert_to_tensor(data=data, track_meta=monai.data.get_track_meta())
+    init_affine = monai.data.to_affine_nd(len(matrix) - 1, img.affine)
     call_kwargs = {
+        "spatial_size": img.peek_pending_shape() if spatial_size is None else spatial_size,
+        "dst_affine": init_affine @ monai.utils.convert_to_dst_type(matrix, init_affine)[0],
         "mode": kwargs.pop(LazyAttr.INTERP_MODE, None),
         "padding_mode": kwargs.pop(LazyAttr.PADDING_MODE, None),
     }
-    resampler = monai.transforms.Affine(affine=matrix, image_only=True, **init_kwargs)
+    resampler = monai.transforms.SpatialResample(**init_kwargs)
+    # resampler.lazy_evaluation = False
     with resampler.trace_transform(False):  # don't track this transform in `data`
-        return resampler(img=data, **call_kwargs)
+        return resampler(img=img, **call_kwargs)
diff --git a/monai/utils/__init__.py b/monai/utils/__init__.py
index 49daefcdda..601a5f10ae 100644
--- a/monai/utils/__init__.py
+++ b/monai/utils/__init__.py
@@ -70,6 +70,7 @@
     first,
     get_seed,
     has_option,
+    is_immutable,
     is_module_ver_at_least,
     is_scalar,
     is_scalar_tensor,
diff --git a/monai/utils/enums.py b/monai/utils/enums.py
index d1ac19f4b4..7a4aaaece7 100644
--- a/monai/utils/enums.py
+++ b/monai/utils/enums.py
@@ -310,6 +310,8 @@ class TraceKeys(StrEnum):
     DO_TRANSFORM: str = "do_transforms"
     KEY_SUFFIX: str = "_transforms"
     NONE: str = "none"
+    TRACING: str = "tracing"
+    LAZY_EVALUATION: str = "lazy_evaluation"
 
 
 class CommonKeys(StrEnum):
@@ -623,3 +625,4 @@ class LazyAttr(StrEnum):
     PADDING_MODE = "lazy_padding_mode"
     INTERP_MODE = "lazy_interpolation_mode"
     DTYPE = "lazy_dtype"
+    ALIGN_CORNERS = "lazy_align_corners"
diff --git a/monai/utils/misc.py b/monai/utils/misc.py
index 554cc1b278..f2086f22d5 100644
--- a/monai/utils/misc.py
+++ b/monai/utils/misc.py
@@ -36,6 +36,7 @@
     "star_zip_with",
     "first",
     "issequenceiterable",
+    "is_immutable",
     "ensure_tuple",
     "ensure_tuple_size",
     "ensure_tuple_rep",
@@ -116,6 +117,15 @@ def issequenceiterable(obj: Any) -> bool:
     return isinstance(obj, Iterable) and not isinstance(obj, (str, bytes))
 
 
+def is_immutable(obj: Any) -> bool:
+    """
+    Determine if the object is an immutable object.
+
+    see also https://github.com/python/cpython/blob/3.11/Lib/copy.py#L109
+    """
+    return isinstance(obj, (type(None), int, float, bool, complex, str, tuple, bytes, type, range, slice))
+
+
 def ensure_tuple(vals: Any, wrap_array: bool = False) -> tuple[Any, ...]:
     """
     Returns a tuple of `vals`.
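A quick sanity check of the new helper (the tuple case matches `copy.deepcopy`'s atomic-type fast path referenced in the docstring); illustrative only:

    from monai.utils import is_immutable

    assert is_immutable(None) and is_immutable(3.14) and is_immutable("abc")
    assert is_immutable(("a", 1)) and is_immutable(range(3))
    assert not is_immutable([1, 2]) and not is_immutable({"k": "v"})
    # note: a tuple counts as immutable even when it contains mutable items,
    # so copy_items will share such tuples rather than copy them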
diff --git a/tests/test_apply.py b/tests/test_apply.py
index 8974360381..cf74721267 100644
--- a/tests/test_apply.py
+++ b/tests/test_apply.py
@@ -32,7 +32,7 @@ def single_2d_transform_cases():
         (torch.as_tensor(get_arange_img((32, 32))), [create_rotate(2, np.pi / 2)], (1, 32, 32)),
         (
             torch.as_tensor(get_arange_img((16, 16))),
-            [{LazyAttr.AFFINE: create_rotate(2, np.pi / 2), LazyAttr.SHAPE: (1, 45, 45)}],
+            [{LazyAttr.AFFINE: create_rotate(2, np.pi / 2), LazyAttr.SHAPE: (45, 45)}],
             (1, 45, 45),
         ),
     ]
@@ -51,6 +51,8 @@ def _test_apply_metatensor_impl(self, tensor, pending_transforms, expected_shape
         else:
             for p in pending_transforms:
                 tensor_.push_pending_operation(p)
+                if not isinstance(p, dict):
+                    return
         result, transforms = apply_transforms(tensor_)
         self.assertEqual(result.shape, expected_shape)
 
diff --git a/tests/test_resample.py b/tests/test_resample.py
index 3ebdd23e02..8b2ffea194 100644
--- a/tests/test_resample.py
+++ b/tests/test_resample.py
@@ -28,13 +28,13 @@ def rotate_90_2d():
     return t
 
 
-RESAMPLE_FUNCTION_CASES = [(get_arange_img((3, 3)), rotate_90_2d(), [[2, 5, 8], [1, 4, 7], [0, 3, 6]])]
+RESAMPLE_FUNCTION_CASES = [(get_arange_img((3, 3)), rotate_90_2d(), [[0, 3, 6], [0, 3, 6], [0, 3, 6]])]
 
 
 class TestResampleFunction(unittest.TestCase):
     @parameterized.expand(RESAMPLE_FUNCTION_CASES)
     def test_resample_function_impl(self, img, matrix, expected):
-        out = resample(convert_to_tensor(img), matrix)
+        out = resample(convert_to_tensor(img), matrix, img.shape[1:])
         assert_allclose(out[0], expected, type_test=False)
 
diff --git a/tests/test_traceable_transform.py b/tests/test_traceable_transform.py
index cf3da7139a..42906c84d2 100644
--- a/tests/test_traceable_transform.py
+++ b/tests/test_traceable_transform.py
@@ -30,6 +30,8 @@ class TestTraceable(unittest.TestCase):
     def test_default(self):
         expected_key = "_transforms"
         a = _TraceTest()
+        for x in a.transform_info_keys():
+            self.assertTrue(x in a.get_transform_info())
         self.assertEqual(a.trace_key(), expected_key)
 
         data = {"image": "test"}
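An end-to-end sketch of the updated `apply_transforms` entry point, following the pattern exercised in tests/test_apply.py above (hypothetical values; the function now returns a pair and consumes the pending stack itself):

    import numpy as np
    import torch
    from monai.data import MetaTensor
    from monai.transforms import create_rotate
    from monai.transforms.lazy.functional import apply_transforms
    from monai.utils import LazyAttr

    img = MetaTensor(torch.arange(16.0).reshape(1, 4, 4))
    img.push_pending_operation({LazyAttr.AFFINE: create_rotate(2, np.pi / 2), LazyAttr.SHAPE: (4, 4)})
    # mode/padding_mode/dtype/align_corners override the per-operation lazy kwargs
    out, applied = apply_transforms(img, mode="nearest", padding_mode="zeros")
    assert out.shape == (1, 4, 4)
    assert not out.pending_operations  # pending ops were cleared and re-recorded as applied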