Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions monai/transforms/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -500,6 +500,7 @@
get_extreme_points,
get_largest_connected_component_mask,
get_number_image_type_conversions,
get_transform_backends,
img_bounds,
in_bounds,
is_empty,
Expand Down
10 changes: 5 additions & 5 deletions monai/transforms/intensity/array.py
Original file line number Diff line number Diff line change
Expand Up @@ -131,7 +131,7 @@ class RandRicianNoise(RandomizableTransform):
uniformly from 0 to std.
"""

backends = [TransformBackends.TORCH, TransformBackends.NUMPY]
backend = [TransformBackends.TORCH, TransformBackends.NUMPY]

def __init__(
self,
Expand Down Expand Up @@ -197,7 +197,7 @@ class ShiftIntensity(Transform):
offset: offset value to shift the intensity of image.
"""

backends = [TransformBackends.TORCH, TransformBackends.NUMPY]
backend = [TransformBackends.TORCH, TransformBackends.NUMPY]

def __init__(self, offset: float) -> None:
self.offset = offset
Expand All @@ -219,7 +219,7 @@ class RandShiftIntensity(RandomizableTransform):
Randomly shift intensity with randomly picked offset.
"""

backends = [TransformBackends.TORCH, TransformBackends.NUMPY]
backend = [TransformBackends.TORCH, TransformBackends.NUMPY]

def __init__(self, offsets: Union[Tuple[float, float], float], prob: float = 0.1) -> None:
"""
Expand Down Expand Up @@ -273,7 +273,7 @@ class StdShiftIntensity(Transform):
dtype: output data type, defaults to float32.
"""

backends = [TransformBackends.TORCH, TransformBackends.NUMPY]
backend = [TransformBackends.TORCH, TransformBackends.NUMPY]

def __init__(
self, factor: float, nonzero: bool = False, channel_wise: bool = False, dtype: DtypeLike = np.float32
Expand Down Expand Up @@ -318,7 +318,7 @@ class RandStdShiftIntensity(RandomizableTransform):
by: ``v = v + factor * std(v)`` where the `factor` is randomly picked.
"""

backends = [TransformBackends.TORCH, TransformBackends.NUMPY]
backend = [TransformBackends.TORCH, TransformBackends.NUMPY]

def __init__(
self,
Expand Down
2 changes: 1 addition & 1 deletion monai/transforms/utility/array.py
Original file line number Diff line number Diff line change
Expand Up @@ -291,7 +291,7 @@ class CastToType(Transform):
specified PyTorch data type.
"""

backends = [TransformBackends.TORCH, TransformBackends.NUMPY]
backend = [TransformBackends.TORCH, TransformBackends.NUMPY]

def __init__(self, dtype=np.float32) -> None:
"""
Expand Down
96 changes: 63 additions & 33 deletions monai/transforms/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@
min_version,
optional_import,
)
from monai.utils.enums import TransformBackends
from monai.utils.type_conversion import convert_data_type

measure, _ = optional_import("skimage.measure", "0.14.2", min_version)
Expand Down Expand Up @@ -81,6 +82,7 @@
"zero_margins",
"equalize_hist",
"get_number_image_type_conversions",
"get_transform_backends",
"print_transform_backends",
]

Expand Down Expand Up @@ -1158,22 +1160,17 @@ def _get_data(obj, key):
return num_conversions


def print_transform_backends():
"""Prints a list of backends of all MONAI transforms."""

class Colours:
red = "91"
green = "92"
yellow = "93"

def print_colour(t, colour):
print(f"\033[{colour}m{t}\033[00m")
def get_transform_backends():
"""Get the backends of all MONAI transforms.

tr_total = 0
tr_t_or_np = 0
tr_t = 0
tr_np = 0
tr_uncategorised = 0
Returns:
Dictionary, where each key is a transform, and its
corresponding values are a boolean list, stating
whether that transform supports (1) `torch.Tensor`,
and (2) `np.ndarray` as input without needing to
convert.
"""
backends = {}
unique_transforms = []
for n, obj in getmembers(monai.transforms):
# skip aliases
Expand All @@ -1194,21 +1191,54 @@ def print_colour(t, colour):
"InverteD",
]:
continue
tr_total += 1
if obj.backend == ["torch", "numpy"]:
tr_t_or_np += 1
print_colour(f"TorchOrNumpy: {n}", Colours.green)
elif obj.backend == ["torch"]:
tr_t += 1
print_colour(f"Torch: {n}", Colours.green)
elif obj.backend == ["numpy"]:
tr_np += 1
print_colour(f"Numpy: {n}", Colours.yellow)
else:
tr_uncategorised += 1
print_colour(f"Uncategorised: {n}", Colours.red)
print("Total number of transforms:", tr_total)
print_colour(f"Number transforms allowing both torch and numpy: {tr_t_or_np}", Colours.green)
print_colour(f"Number of TorchTransform: {tr_t}", Colours.green)
print_colour(f"Number of NumpyTransform: {tr_np}", Colours.yellow)
print_colour(f"Number of uncategorised: {tr_uncategorised}", Colours.red)

backends[n] = [
TransformBackends.TORCH in obj.backend,
TransformBackends.NUMPY in obj.backend,
]
return backends


def print_transform_backends():
"""Prints a list of backends of all MONAI transforms."""

class Colors:
none = ""
red = "91"
green = "92"
yellow = "93"

def print_color(t, color):
print(f"\033[{color}m{t}\033[00m")

def print_table_column(name, torch, numpy, color=Colors.none):
print_color("{:<50} {:<8} {:<8}".format(name, torch, numpy), color)

backends = get_transform_backends()
n_total = len(backends)
n_t_or_np, n_t, n_np, n_uncategorized = 0, 0, 0, 0
print_table_column("Transform", "Torch?", "Numpy?")
for k, v in backends.items():
if all(v):
color = Colors.green
n_t_or_np += 1
elif v[0]:
color = Colors.green
n_t += 1
elif v[1]:
color = Colors.yellow
n_np += 1
else:
color = Colors.red
n_uncategorized += 1
print_table_column(k, *v, color)

print("Total number of transforms:", n_total)
print_color(f"Number transforms allowing both torch and numpy: {n_t_or_np}", Colors.green)
print_color(f"Number of TorchTransform: {n_t}", Colors.green)
print_color(f"Number of NumpyTransform: {n_np}", Colors.yellow)
print_color(f"Number of uncategorised: {n_uncategorized}", Colors.red)


if __name__ == "__main__":
print_transform_backends()
8 changes: 6 additions & 2 deletions tests/test_print_transform_backends.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,13 +11,17 @@

import unittest

from monai.transforms.utils import print_transform_backends
from monai.transforms.utils import get_transform_backends, print_transform_backends


class TestPrintTransformBackends(unittest.TestCase):
def test_get_number_of_conversions(self):
tr_t_or_np, *_ = get_transform_backends()
self.assertGreater(len(tr_t_or_np), 0)
print_transform_backends()


if __name__ == "__main__":
unittest.main()
# unittest.main()
a = TestPrintTransformBackends()
a.test_get_number_of_conversions()