diff --git a/monai/utils/type_conversion.py b/monai/utils/type_conversion.py index 14300eeca0..b0ce187e38 100644 --- a/monai/utils/type_conversion.py +++ b/monai/utils/type_conversion.py @@ -171,7 +171,14 @@ def convert_data_type( Returns: modified data, orig_type, orig_device """ - orig_type = type(data) + orig_type: Any + if isinstance(data, torch.Tensor): + orig_type = torch.Tensor + elif isinstance(data, np.ndarray): + orig_type = np.ndarray + else: + orig_type = type(data) + orig_device = data.device if isinstance(data, torch.Tensor) else None output_type = output_type or orig_type