From 0543926312b42a4989ebd4bc88918de6ba3d7224 Mon Sep 17 00:00:00 2001 From: Ishan Dutta Date: Sat, 18 Nov 2023 01:06:18 +0530 Subject: [PATCH] :memo: [array] Add examples for EnsureType and CastToType Signed-off-by: Ishan Dutta --- monai/transforms/utility/array.py | 35 ++++++++++++++++++++++++++++++- 1 file changed, 34 insertions(+), 1 deletion(-) diff --git a/monai/transforms/utility/array.py b/monai/transforms/utility/array.py index 9aad12ef90..caf02d7b00 100644 --- a/monai/transforms/utility/array.py +++ b/monai/transforms/utility/array.py @@ -333,6 +333,23 @@ class CastToType(Transform): """ Cast the Numpy data to specified numpy data type, or cast the PyTorch Tensor to specified PyTorch data type. + + Example: + >>> import numpy as np + >>> import torch + >>> transform = CastToType(dtype=np.float32) + + >>> # Example with a numpy array + >>> img_np = np.array([0, 127, 255], dtype=np.uint8) + >>> img_np_casted = transform(img_np) + >>> img_np_casted + array([ 0. , 127. , 255. ], dtype=float32) + + >>> # Example with a PyTorch tensor + >>> img_tensor = torch.tensor([0, 127, 255], dtype=torch.uint8) + >>> img_tensor_casted = transform(img_tensor) + >>> img_tensor_casted + tensor([ 0., 127., 255.]) # dtype is float32 """ backend = [TransformBackends.TORCH, TransformBackends.NUMPY] @@ -413,10 +430,26 @@ class EnsureType(Transform): dtype: target data content type to convert, for example: np.float32, torch.float, etc. device: for Tensor data type, specify the target device. wrap_sequence: if `False`, then lists will recursively call this function, default to `True`. - E.g., if `False`, `[1, 2]` -> `[tensor(1), tensor(2)]`, if `True`, then `[1, 2]` -> `tensor([1, 2])`. track_meta: if `True` convert to ``MetaTensor``, otherwise to Pytorch ``Tensor``, if ``None`` behave according to return value of py:func:`monai.data.meta_obj.get_track_meta`. + Example with wrap_sequence=True: + >>> import numpy as np + >>> import torch + >>> transform = EnsureType(data_type="tensor", wrap_sequence=True) + >>> # Converting a list to a tensor + >>> data_list = [1, 2., 3] + >>> tensor_data = transform(data_list) + >>> tensor_data + tensor([1., 2., 3.]) # All elements have dtype float32 + + Example with wrap_sequence=False: + >>> transform = EnsureType(data_type="tensor", wrap_sequence=False) + >>> # Converting each element in a list to individual tensors + >>> data_list = [1, 2, 3] + >>> tensors_list = transform(data_list) + >>> tensors_list + [tensor(1), tensor(2.), tensor(3)] # Only second element is float32 rest are int64 """ backend = [TransformBackends.TORCH, TransformBackends.NUMPY]