diff --git a/monai/transforms/intensity/dictionary.py b/monai/transforms/intensity/dictionary.py index beb210c645..23f72c5677 100644 --- a/monai/transforms/intensity/dictionary.py +++ b/monai/transforms/intensity/dictionary.py @@ -245,6 +245,8 @@ class ShiftIntensityd(MapTransform): Dictionary-based wrapper of :py:class:`monai.transforms.ShiftIntensity`. """ + backend = ShiftIntensity.backend + def __init__( self, keys: KeysCollection, @@ -283,7 +285,7 @@ def __init__( self.meta_key_postfix = ensure_tuple_rep(meta_key_postfix, len(self.keys)) self.shifter = ShiftIntensity(offset) - def __call__(self, data) -> Dict[Hashable, np.ndarray]: + def __call__(self, data) -> Dict[Hashable, NdarrayOrTensor]: d = dict(data) for key, factor_key, meta_key, meta_key_postfix in self.key_iterator( d, self.factor_key, self.meta_keys, self.meta_key_postfix @@ -300,6 +302,8 @@ class RandShiftIntensityd(RandomizableTransform, MapTransform): Dictionary-based version :py:class:`monai.transforms.RandShiftIntensity`. """ + backend = ShiftIntensity.backend + def __init__( self, keys: KeysCollection, @@ -355,7 +359,7 @@ def randomize(self, data: Optional[Any] = None) -> None: self._offset = self.R.uniform(low=self.offsets[0], high=self.offsets[1]) super().randomize(None) - def __call__(self, data) -> Dict[Hashable, np.ndarray]: + def __call__(self, data) -> Dict[Hashable, NdarrayOrTensor]: d = dict(data) self.randomize() if not self._do_transform: diff --git a/tests/test_rand_shift_intensityd.py b/tests/test_rand_shift_intensityd.py index 71cfd8fc50..6766236146 100644 --- a/tests/test_rand_shift_intensityd.py +++ b/tests/test_rand_shift_intensityd.py @@ -14,18 +14,19 @@ import numpy as np from monai.transforms import IntensityStatsd, RandShiftIntensityd -from tests.utils import NumpyImageTestCase2D +from tests.utils import TEST_NDARRAYS, NumpyImageTestCase2D, assert_allclose class TestRandShiftIntensityd(NumpyImageTestCase2D): def test_value(self): - key = "img" - shifter = RandShiftIntensityd(keys=[key], offsets=1.0, prob=1.0) - shifter.set_random_state(seed=0) - result = shifter({key: self.imt}) - np.random.seed(0) - expected = self.imt + np.random.uniform(low=-1.0, high=1.0) - np.testing.assert_allclose(result[key], expected) + for p in TEST_NDARRAYS: + key = "img" + shifter = RandShiftIntensityd(keys=[key], offsets=1.0, prob=1.0) + shifter.set_random_state(seed=0) + result = shifter({key: p(self.imt)}) + np.random.seed(0) + expected = self.imt + np.random.uniform(low=-1.0, high=1.0) + assert_allclose(result[key], expected) def test_factor(self): key = "img" diff --git a/tests/test_shift_intensityd.py b/tests/test_shift_intensityd.py index 71cfffc9c5..0396857781 100644 --- a/tests/test_shift_intensityd.py +++ b/tests/test_shift_intensityd.py @@ -14,16 +14,17 @@ import numpy as np from monai.transforms import IntensityStatsd, ShiftIntensityd -from tests.utils import NumpyImageTestCase2D +from tests.utils import TEST_NDARRAYS, NumpyImageTestCase2D, assert_allclose class TestShiftIntensityd(NumpyImageTestCase2D): def test_value(self): key = "img" - shifter = ShiftIntensityd(keys=[key], offset=1.0) - result = shifter({key: self.imt}) - expected = self.imt + 1.0 - np.testing.assert_allclose(result[key], expected) + for p in TEST_NDARRAYS: + shifter = ShiftIntensityd(keys=[key], offset=1.0) + result = shifter({key: p(self.imt)}) + expected = self.imt + 1.0 + assert_allclose(result[key], expected) def test_factor(self): key = "img"