diff --git a/monai/transforms/__init__.py b/monai/transforms/__init__.py index 4203724a7d..5267af4048 100644 --- a/monai/transforms/__init__.py +++ b/monai/transforms/__init__.py @@ -500,6 +500,7 @@ get_extreme_points, get_largest_connected_component_mask, get_number_image_type_conversions, + get_transform_backends, img_bounds, in_bounds, is_empty, diff --git a/monai/transforms/intensity/array.py b/monai/transforms/intensity/array.py index 46c512c96c..b36c7adf96 100644 --- a/monai/transforms/intensity/array.py +++ b/monai/transforms/intensity/array.py @@ -131,7 +131,7 @@ class RandRicianNoise(RandomizableTransform): uniformly from 0 to std. """ - backends = [TransformBackends.TORCH, TransformBackends.NUMPY] + backend = [TransformBackends.TORCH, TransformBackends.NUMPY] def __init__( self, @@ -197,7 +197,7 @@ class ShiftIntensity(Transform): offset: offset value to shift the intensity of image. """ - backends = [TransformBackends.TORCH, TransformBackends.NUMPY] + backend = [TransformBackends.TORCH, TransformBackends.NUMPY] def __init__(self, offset: float) -> None: self.offset = offset @@ -219,7 +219,7 @@ class RandShiftIntensity(RandomizableTransform): Randomly shift intensity with randomly picked offset. """ - backends = [TransformBackends.TORCH, TransformBackends.NUMPY] + backend = [TransformBackends.TORCH, TransformBackends.NUMPY] def __init__(self, offsets: Union[Tuple[float, float], float], prob: float = 0.1) -> None: """ @@ -273,7 +273,7 @@ class StdShiftIntensity(Transform): dtype: output data type, defaults to float32. """ - backends = [TransformBackends.TORCH, TransformBackends.NUMPY] + backend = [TransformBackends.TORCH, TransformBackends.NUMPY] def __init__( self, factor: float, nonzero: bool = False, channel_wise: bool = False, dtype: DtypeLike = np.float32 @@ -318,7 +318,7 @@ class RandStdShiftIntensity(RandomizableTransform): by: ``v = v + factor * std(v)`` where the `factor` is randomly picked. """ - backends = [TransformBackends.TORCH, TransformBackends.NUMPY] + backend = [TransformBackends.TORCH, TransformBackends.NUMPY] def __init__( self, diff --git a/monai/transforms/utility/array.py b/monai/transforms/utility/array.py index 3ef6413090..d56bca0d8d 100644 --- a/monai/transforms/utility/array.py +++ b/monai/transforms/utility/array.py @@ -291,7 +291,7 @@ class CastToType(Transform): specified PyTorch data type. """ - backends = [TransformBackends.TORCH, TransformBackends.NUMPY] + backend = [TransformBackends.TORCH, TransformBackends.NUMPY] def __init__(self, dtype=np.float32) -> None: """ diff --git a/monai/transforms/utils.py b/monai/transforms/utils.py index e3e61b6c97..30aa5e7b99 100644 --- a/monai/transforms/utils.py +++ b/monai/transforms/utils.py @@ -38,6 +38,7 @@ min_version, optional_import, ) +from monai.utils.enums import TransformBackends from monai.utils.type_conversion import convert_data_type measure, _ = optional_import("skimage.measure", "0.14.2", min_version) @@ -81,6 +82,7 @@ "zero_margins", "equalize_hist", "get_number_image_type_conversions", + "get_transform_backends", "print_transform_backends", ] @@ -1158,22 +1160,17 @@ def _get_data(obj, key): return num_conversions -def print_transform_backends(): - """Prints a list of backends of all MONAI transforms.""" - - class Colours: - red = "91" - green = "92" - yellow = "93" - - def print_colour(t, colour): - print(f"\033[{colour}m{t}\033[00m") +def get_transform_backends(): + """Get the backends of all MONAI transforms. - tr_total = 0 - tr_t_or_np = 0 - tr_t = 0 - tr_np = 0 - tr_uncategorised = 0 + Returns: + Dictionary, where each key is a transform, and its + corresponding values are a boolean list, stating + whether that transform supports (1) `torch.Tensor`, + and (2) `np.ndarray` as input without needing to + convert. + """ + backends = {} unique_transforms = [] for n, obj in getmembers(monai.transforms): # skip aliases @@ -1194,21 +1191,54 @@ def print_colour(t, colour): "InverteD", ]: continue - tr_total += 1 - if obj.backend == ["torch", "numpy"]: - tr_t_or_np += 1 - print_colour(f"TorchOrNumpy: {n}", Colours.green) - elif obj.backend == ["torch"]: - tr_t += 1 - print_colour(f"Torch: {n}", Colours.green) - elif obj.backend == ["numpy"]: - tr_np += 1 - print_colour(f"Numpy: {n}", Colours.yellow) - else: - tr_uncategorised += 1 - print_colour(f"Uncategorised: {n}", Colours.red) - print("Total number of transforms:", tr_total) - print_colour(f"Number transforms allowing both torch and numpy: {tr_t_or_np}", Colours.green) - print_colour(f"Number of TorchTransform: {tr_t}", Colours.green) - print_colour(f"Number of NumpyTransform: {tr_np}", Colours.yellow) - print_colour(f"Number of uncategorised: {tr_uncategorised}", Colours.red) + + backends[n] = [ + TransformBackends.TORCH in obj.backend, + TransformBackends.NUMPY in obj.backend, + ] + return backends + + +def print_transform_backends(): + """Prints a list of backends of all MONAI transforms.""" + + class Colors: + none = "" + red = "91" + green = "92" + yellow = "93" + + def print_color(t, color): + print(f"\033[{color}m{t}\033[00m") + + def print_table_column(name, torch, numpy, color=Colors.none): + print_color("{:<50} {:<8} {:<8}".format(name, torch, numpy), color) + + backends = get_transform_backends() + n_total = len(backends) + n_t_or_np, n_t, n_np, n_uncategorized = 0, 0, 0, 0 + print_table_column("Transform", "Torch?", "Numpy?") + for k, v in backends.items(): + if all(v): + color = Colors.green + n_t_or_np += 1 + elif v[0]: + color = Colors.green + n_t += 1 + elif v[1]: + color = Colors.yellow + n_np += 1 + else: + color = Colors.red + n_uncategorized += 1 + print_table_column(k, *v, color) + + print("Total number of transforms:", n_total) + print_color(f"Number transforms allowing both torch and numpy: {n_t_or_np}", Colors.green) + print_color(f"Number of TorchTransform: {n_t}", Colors.green) + print_color(f"Number of NumpyTransform: {n_np}", Colors.yellow) + print_color(f"Number of uncategorised: {n_uncategorized}", Colors.red) + + +if __name__ == "__main__": + print_transform_backends() diff --git a/tests/test_print_transform_backends.py b/tests/test_print_transform_backends.py index 09828f0a27..4164687f01 100644 --- a/tests/test_print_transform_backends.py +++ b/tests/test_print_transform_backends.py @@ -11,13 +11,17 @@ import unittest -from monai.transforms.utils import print_transform_backends +from monai.transforms.utils import get_transform_backends, print_transform_backends class TestPrintTransformBackends(unittest.TestCase): def test_get_number_of_conversions(self): + tr_t_or_np, *_ = get_transform_backends() + self.assertGreater(len(tr_t_or_np), 0) print_transform_backends() if __name__ == "__main__": - unittest.main() + # unittest.main() + a = TestPrintTransformBackends() + a.test_get_number_of_conversions()