Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 6 additions & 8 deletions monai/transforms/utility/array.py
Original file line number Diff line number Diff line change
Expand Up @@ -354,19 +354,17 @@ class EnsureType(Transform):

Args:
data_type: target data type to convert, should be "tensor" or "numpy".
device: for Tensor data type, specify the target device.

"""

backend = [TransformBackends.TORCH, TransformBackends.NUMPY]

def __init__(self, data_type: str = "tensor") -> None:
data_type = data_type.lower()
if data_type not in ("tensor", "numpy"):
raise ValueError("`data type` must be 'tensor' or 'numpy'.")

self.data_type = data_type
def __init__(self, data_type: str = "tensor", device: Optional[torch.device] = None) -> None:
self.data_type = look_up_option(data_type.lower(), {"tensor", "numpy"})
self.device = device

def __call__(self, data: NdarrayOrTensor) -> NdarrayOrTensor:
def __call__(self, data: NdarrayOrTensor):
"""
Args:
data: input data can be PyTorch Tensor, numpy array, list, dictionary, int, float, bool, str, etc.
Expand All @@ -375,7 +373,7 @@ def __call__(self, data: NdarrayOrTensor) -> NdarrayOrTensor:
if applicable.

"""
return convert_to_tensor(data) if self.data_type == "tensor" else convert_to_numpy(data) # type: ignore
return convert_to_tensor(data, device=self.device) if self.data_type == "tensor" else convert_to_numpy(data)


class ToNumpy(Transform):
Expand Down
11 changes: 9 additions & 2 deletions monai/transforms/utility/dictionary.py
Original file line number Diff line number Diff line change
Expand Up @@ -486,16 +486,23 @@ class EnsureTyped(MapTransform, InvertibleTransform):

backend = EnsureType.backend

def __init__(self, keys: KeysCollection, data_type: str = "tensor", allow_missing_keys: bool = False) -> None:
def __init__(
self,
keys: KeysCollection,
data_type: str = "tensor",
device: Optional[torch.device] = None,
allow_missing_keys: bool = False,
) -> None:
"""
Args:
keys: keys of the corresponding items to be transformed.
See also: :py:class:`monai.transforms.compose.MapTransform`
data_type: target data type to convert, should be "tensor" or "numpy".
device: for Tensor data type, specify the target device.
allow_missing_keys: don't raise exception if key is missing.
"""
super().__init__(keys, allow_missing_keys)
self.converter = EnsureType(data_type=data_type)
self.converter = EnsureType(data_type=data_type, device=device)

def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Dict[Hashable, NdarrayOrTensor]:
d = dict(data)
Expand Down
2 changes: 1 addition & 1 deletion tests/test_ensure_type.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ def test_single_input(self):
test_datas.append(test_datas[-1].cuda())
for test_data in test_datas:
for dtype in ("tensor", "numpy"):
result = EnsureType(data_type=dtype)(test_data)
result = EnsureType(data_type=dtype, device="cpu")(test_data)
self.assertTrue(isinstance(result, torch.Tensor if dtype == "tensor" else np.ndarray))
if isinstance(test_data, bool):
self.assertFalse(result)
Expand Down
2 changes: 1 addition & 1 deletion tests/test_ensure_typed.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,7 @@ def test_dict(self):
"extra": None,
}
for dtype in ("tensor", "numpy"):
result = EnsureTyped(keys="data", data_type=dtype)({"data": test_data})["data"]
result = EnsureTyped(keys="data", data_type=dtype, device="cpu")({"data": test_data})["data"]
self.assertTrue(isinstance(result, dict))
self.assertTrue(isinstance(result["img"], torch.Tensor if dtype == "tensor" else np.ndarray))
torch.testing.assert_allclose(result["img"], torch.as_tensor([1.0, 2.0]))
Expand Down