diff --git a/monai/transforms/utility/array.py b/monai/transforms/utility/array.py index add47e27ca..7f94b50044 100644 --- a/monai/transforms/utility/array.py +++ b/monai/transforms/utility/array.py @@ -354,19 +354,17 @@ class EnsureType(Transform): Args: data_type: target data type to convert, should be "tensor" or "numpy". + device: for Tensor data type, specify the target device. """ backend = [TransformBackends.TORCH, TransformBackends.NUMPY] - def __init__(self, data_type: str = "tensor") -> None: - data_type = data_type.lower() - if data_type not in ("tensor", "numpy"): - raise ValueError("`data type` must be 'tensor' or 'numpy'.") - - self.data_type = data_type + def __init__(self, data_type: str = "tensor", device: Optional[torch.device] = None) -> None: + self.data_type = look_up_option(data_type.lower(), {"tensor", "numpy"}) + self.device = device - def __call__(self, data: NdarrayOrTensor) -> NdarrayOrTensor: + def __call__(self, data: NdarrayOrTensor): """ Args: data: input data can be PyTorch Tensor, numpy array, list, dictionary, int, float, bool, str, etc. @@ -375,7 +373,7 @@ def __call__(self, data: NdarrayOrTensor) -> NdarrayOrTensor: if applicable. """ - return convert_to_tensor(data) if self.data_type == "tensor" else convert_to_numpy(data) # type: ignore + return convert_to_tensor(data, device=self.device) if self.data_type == "tensor" else convert_to_numpy(data) class ToNumpy(Transform): diff --git a/monai/transforms/utility/dictionary.py b/monai/transforms/utility/dictionary.py index e9bcce93b0..2fdf20c1fe 100644 --- a/monai/transforms/utility/dictionary.py +++ b/monai/transforms/utility/dictionary.py @@ -486,16 +486,23 @@ class EnsureTyped(MapTransform, InvertibleTransform): backend = EnsureType.backend - def __init__(self, keys: KeysCollection, data_type: str = "tensor", allow_missing_keys: bool = False) -> None: + def __init__( + self, + keys: KeysCollection, + data_type: str = "tensor", + device: Optional[torch.device] = None, + allow_missing_keys: bool = False, + ) -> None: """ Args: keys: keys of the corresponding items to be transformed. See also: :py:class:`monai.transforms.compose.MapTransform` data_type: target data type to convert, should be "tensor" or "numpy". + device: for Tensor data type, specify the target device. allow_missing_keys: don't raise exception if key is missing. """ super().__init__(keys, allow_missing_keys) - self.converter = EnsureType(data_type=data_type) + self.converter = EnsureType(data_type=data_type, device=device) def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Dict[Hashable, NdarrayOrTensor]: d = dict(data) diff --git a/tests/test_ensure_type.py b/tests/test_ensure_type.py index 8feb96ed37..86bc3db703 100644 --- a/tests/test_ensure_type.py +++ b/tests/test_ensure_type.py @@ -36,7 +36,7 @@ def test_single_input(self): test_datas.append(test_datas[-1].cuda()) for test_data in test_datas: for dtype in ("tensor", "numpy"): - result = EnsureType(data_type=dtype)(test_data) + result = EnsureType(data_type=dtype, device="cpu")(test_data) self.assertTrue(isinstance(result, torch.Tensor if dtype == "tensor" else np.ndarray)) if isinstance(test_data, bool): self.assertFalse(result) diff --git a/tests/test_ensure_typed.py b/tests/test_ensure_typed.py index 96f482afc2..e4c72d37e2 100644 --- a/tests/test_ensure_typed.py +++ b/tests/test_ensure_typed.py @@ -75,7 +75,7 @@ def test_dict(self): "extra": None, } for dtype in ("tensor", "numpy"): - result = EnsureTyped(keys="data", data_type=dtype)({"data": test_data})["data"] + result = EnsureTyped(keys="data", data_type=dtype, device="cpu")({"data": test_data})["data"] self.assertTrue(isinstance(result, dict)) self.assertTrue(isinstance(result["img"], torch.Tensor if dtype == "tensor" else np.ndarray)) torch.testing.assert_allclose(result["img"], torch.as_tensor([1.0, 2.0]))