From c601e45aaf9686af8e3f1f819338c461553fd1e1 Mon Sep 17 00:00:00 2001 From: Wenqi Li Date: Thu, 9 Feb 2023 17:08:26 +0000 Subject: [PATCH 1/9] update to use np.linalg Signed-off-by: Wenqi Li --- monai/networks/utils.py | 9 +++++---- monai/transforms/spatial/array.py | 23 ++++++++++------------- monai/transforms/utils.py | 2 +- 3 files changed, 16 insertions(+), 18 deletions(-) diff --git a/monai/networks/utils.py b/monai/networks/utils.py index 23676d3d06..8b81d03535 100644 --- a/monai/networks/utils.py +++ b/monai/networks/utils.py @@ -22,6 +22,7 @@ from copy import deepcopy from typing import Any +import numpy as np import torch import torch.nn as nn @@ -29,7 +30,7 @@ from monai.config import PathLike from monai.utils.misc import ensure_tuple, save_obj, set_determinism from monai.utils.module import look_up_option, pytorch_after -from monai.utils.type_conversion import convert_to_tensor +from monai.utils.type_conversion import convert_to_dst_type, convert_to_tensor __all__ = [ "one_hot", @@ -185,7 +186,7 @@ def predict_segmentation(logits: torch.Tensor, mutually_exclusive: bool = False, def normalize_transform( shape, - device: torch.device | None = None, + device: torch.device | str | None = None, dtype: torch.dtype | None = None, align_corners: bool = False, zero_centered: bool = False, @@ -264,8 +265,8 @@ def to_norm_affine( raise ValueError(f"affine suggests {sr}D, got src={len(src_size)}D, dst={len(dst_size)}D.") src_xform = normalize_transform(src_size, affine.device, affine.dtype, align_corners, zero_centered) - dst_xform = normalize_transform(dst_size, affine.device, affine.dtype, align_corners, zero_centered) - return src_xform @ affine @ torch.inverse(dst_xform) + dst_xform = normalize_transform(dst_size, "cpu", affine.dtype, align_corners, zero_centered) + return src_xform @ affine @ convert_to_dst_type(np.linalg.inv(dst_xform.numpy()), dst=affine)[0] def normal_init( diff --git a/monai/transforms/spatial/array.py b/monai/transforms/spatial/array.py index e26a15e1dc..e5d73d980e 100644 --- a/monai/transforms/spatial/array.py +++ b/monai/transforms/spatial/array.py @@ -65,7 +65,6 @@ fall_back_tuple, issequenceiterable, optional_import, - pytorch_after, ) from monai.utils.deprecate_utils import deprecated_arg from monai.utils.enums import GridPatchSort, PytorchPadMode, TraceKeys, TransformBackends, WSIPatchKeys @@ -272,14 +271,12 @@ def __call__( ) try: - _s = convert_to_tensor(src_affine_, track_meta=False, device=torch.device("cpu")) - _d = convert_to_tensor(dst_affine, track_meta=False, device=torch.device("cpu")) - xform = ( - torch.linalg.solve(_s, _d) if pytorch_after(1, 8, 0) else torch.solve(_d, _s).solution # type: ignore - ) + _s = convert_to_numpy(src_affine_) + _d = convert_to_numpy(dst_affine) + xform = np.linalg.solve(_s, _d) except (np.linalg.LinAlgError, RuntimeError) as e: - raise ValueError("src affine is not invertible.") from e - xform = to_affine_nd(spatial_rank, xform).to(device=img.device, dtype=_dtype) + raise ValueError(f"src affine is not invertible {_s}, {_d}.") from e + xform = convert_to_tensor(to_affine_nd(spatial_rank, xform)).to(device=img.device, dtype=_dtype) # no resampling if it's identity transform if allclose(xform, torch.eye(len(xform)), atol=AFFINE_TOL) and allclose(spatial_size, in_spatial_size): return self._post_process( @@ -293,12 +290,12 @@ def __call__( xform_shape = [-1] + in_spatial_size img = img.reshape(xform_shape) # type: ignore if isinstance(mode, int): - dst_xform_1 = normalize_transform(spatial_size, xform.device, xform.dtype, True, True)[0] # to (-1, 1) + dst_xform_1 = normalize_transform(spatial_size, "cpu", xform.dtype, True, True)[0].numpy() # to (-1, 1) if not align_corners: - norm = create_scale(spatial_rank, [(max(d, 2) - 1) / d for d in spatial_size], xform.device, "torch") - dst_xform_1 = norm.to(xform.dtype) @ dst_xform_1 # type: ignore # scaling (num_step - 1) / num_step - dst_xform_d = normalize_transform(spatial_size, xform.device, xform.dtype, align_corners, False)[0] - xform = xform @ torch.inverse(dst_xform_d) @ dst_xform_1 + norm = create_scale(spatial_rank, [(max(d, 2) - 1) / d for d in spatial_size]) + dst_xform_1 = norm.astype(float) @ dst_xform_1 # type: ignore # scaling (num_step - 1) / num_step + dst_xform_d = normalize_transform(spatial_size, "cpu", xform.dtype, align_corners, False)[0].numpy() + xform @= convert_to_dst_type(np.linalg.solve(dst_xform_d, dst_xform_1), xform)[0] affine_xform = Affine( affine=xform, spatial_size=spatial_size, normalized=True, image_only=True, dtype=_dtype # type: ignore ) diff --git a/monai/transforms/utils.py b/monai/transforms/utils.py index 6560899318..2e24463720 100644 --- a/monai/transforms/utils.py +++ b/monai/transforms/utils.py @@ -828,7 +828,7 @@ def _create_shear(spatial_dims: int, coefs: Sequence[float] | float, eye_func=np def create_scale( spatial_dims: int, scaling_factor: Sequence[float] | float, - device: torch.device | None = None, + device: torch.device | str | None = None, backend=TransformBackends.NUMPY, ) -> NdarrayOrTensor: """ From 5ac5fa99f9f9b3ff66d2dc1ba3df1062de4f0813 Mon Sep 17 00:00:00 2001 From: Wenqi Li Date: Thu, 9 Feb 2023 17:20:59 +0000 Subject: [PATCH 2/9] update Signed-off-by: Wenqi Li --- monai/transforms/spatial/array.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/monai/transforms/spatial/array.py b/monai/transforms/spatial/array.py index e5d73d980e..f29b55e8ad 100644 --- a/monai/transforms/spatial/array.py +++ b/monai/transforms/spatial/array.py @@ -1081,7 +1081,7 @@ def inverse_transform(self, data: torch.Tensor, transform) -> torch.Tensor: padding_mode = transform[TraceKeys.EXTRA_INFO]["padding_mode"] align_corners = transform[TraceKeys.EXTRA_INFO]["align_corners"] dtype = transform[TraceKeys.EXTRA_INFO]["dtype"] - inv_rot_mat = linalg_inv(fwd_rot_mat) + inv_rot_mat = linalg_inv(convert_to_numpy(fwd_rot_mat)) xform = AffineTransform( normalized=False, @@ -2278,7 +2278,7 @@ def inverse(self, data: torch.Tensor) -> torch.Tensor: fwd_affine = transform[TraceKeys.EXTRA_INFO]["affine"] mode = transform[TraceKeys.EXTRA_INFO]["mode"] padding_mode = transform[TraceKeys.EXTRA_INFO]["padding_mode"] - inv_affine = linalg_inv(fwd_affine) + inv_affine = linalg_inv(convert_to_numpy(fwd_affine)) inv_affine = convert_to_dst_type(inv_affine, data, dtype=inv_affine.dtype)[0] affine_grid = AffineGrid(affine=inv_affine) @@ -2517,7 +2517,7 @@ def inverse(self, data: torch.Tensor) -> torch.Tensor: fwd_affine = transform[TraceKeys.EXTRA_INFO]["affine"] mode = transform[TraceKeys.EXTRA_INFO]["mode"] padding_mode = transform[TraceKeys.EXTRA_INFO]["padding_mode"] - inv_affine = linalg_inv(fwd_affine) + inv_affine = linalg_inv(convert_to_numpy(fwd_affine)) inv_affine = convert_to_dst_type(inv_affine, data, dtype=inv_affine.dtype)[0] affine_grid = AffineGrid(affine=inv_affine) grid, _ = affine_grid(orig_size) From 91dc1157c003c158dbb75ebc6e4ffc9b8590ab6a Mon Sep 17 00:00:00 2001 From: Wenqi Li Date: Thu, 9 Feb 2023 19:53:42 +0000 Subject: [PATCH 3/9] update integration results Signed-off-by: Wenqi Li --- tests/testing_data/integration_answers.py | 98 +++++++++++------------ 1 file changed, 49 insertions(+), 49 deletions(-) diff --git a/tests/testing_data/integration_answers.py b/tests/testing_data/integration_answers.py index f4a5483f83..989f286b23 100644 --- a/tests/testing_data/integration_answers.py +++ b/tests/testing_data/integration_answers.py @@ -14,7 +14,7 @@ import numpy as np EXPECTED_ANSWERS = [ - { # test answers for PyTorch 1.6 + { # test answers for PyTorch 1.13 "integration_classification_2d": { "losses": [0.776835828070428, 0.1615355300011149, 0.07492854832938523, 0.04591309238865877], "best_metric": 0.9999184380485994, @@ -22,56 +22,56 @@ }, "integration_segmentation_3d": { "losses": [ - 0.5367561340332031, - 0.478084459900856, - 0.4581540793180466, - 0.44623913466930387, - 0.42341493666172025, - 0.42569945752620697, + 0.5326887160539627, + 0.4685510128736496, + 0.46245276033878324, + 0.4411882758140564, + 0.4198471873998642, + 0.43021280467510226, ], - "best_metric": 0.9295084029436111, - "infer_metric": 0.9296411260962486, + "best_metric": 0.931993305683136, + "infer_metric": 0.9326668977737427, "output_sums": [ - 0.14302121377204619, - 0.15321686701244813, - 0.15267064069005093, - 0.1408481434833016, - 0.18862719991649474, - 0.16992848513054068, - 0.1479306037291329, - 0.1691071594535633, - 0.15804366588267224, - 0.18019304183940157, - 0.1635089455927468, - 0.16851606024285842, - 0.1454348651039073, - 0.11584957890961554, - 0.16255468027312903, - 0.20118089432240313, - 0.176187783307603, - 0.1004243279488101, - 0.19385348502657657, - 0.2030768555124136, - 0.196251372926592, - 0.20823046240222043, - 0.1631389353339986, - 0.13299661219478043, - 0.14917081129077908, - 0.14383374638201593, - 0.23050183928776746, - 0.1614747942341212, - 0.14913436515470202, - 0.10443081170610946, - 0.11978674347415241, - 0.13126176432899028, - 0.11570832453348577, - 0.15306806147195887, - 0.163673089782912, - 0.19394971756732426, - 0.22197501007172804, - 0.1812147930033603, - 0.19051659118682873, - 0.0774867922747158, + 0.1418775228871769, + 0.15188869120317386, + 0.15140863737688195, + 0.1396146850007127, + 0.18784343811575696, + 0.16909487431163164, + 0.14649608249452073, + 0.1677767130878611, + 0.1568122289811143, + 0.17874181729735056, + 0.16213703658980205, + 0.16754335171970686, + 0.14444824920997243, + 0.11432402622850306, + 0.16143210936221247, + 0.20055289634107482, + 0.17543571757219317, + 0.09920729163334538, + 0.19297325815057875, + 0.2023200127892273, + 0.1956677579845722, + 0.20774045016425718, + 0.16193278944159428, + 0.13174198906539808, + 0.14830508550670007, + 0.14241105864278342, + 0.23090631643085724, + 0.16056153813499532, + 0.1480353269419819, + 0.10318719171632634, + 0.11867462580989198, + 0.12997011485830187, + 0.11401220332210203, + 0.15242746700662088, + 0.1628489107974574, + 0.19327235354175412, + 0.22184902863377548, + 0.18028049625972334, + 0.18958059106892552, + 0.07884601267057013, ], }, "integration_workflows": { From c72d5378f7b7c8623d3175476e22f099411f8088 Mon Sep 17 00:00:00 2001 From: Wenqi Li Date: Thu, 9 Feb 2023 20:07:58 +0000 Subject: [PATCH 4/9] update answers Signed-off-by: Wenqi Li --- tests/testing_data/integration_answers.py | 194 +++++++++++----------- 1 file changed, 97 insertions(+), 97 deletions(-) diff --git a/tests/testing_data/integration_answers.py b/tests/testing_data/integration_answers.py index 989f286b23..b3314557b3 100644 --- a/tests/testing_data/integration_answers.py +++ b/tests/testing_data/integration_answers.py @@ -22,56 +22,56 @@ }, "integration_segmentation_3d": { "losses": [ - 0.5326887160539627, - 0.4685510128736496, - 0.46245276033878324, - 0.4411882758140564, - 0.4198471873998642, - 0.43021280467510226, + 0.5428894340991974, + 0.47331981360912323, + 0.4482289582490921, + 0.4452722787857056, + 0.4289989799261093, + 0.4359133839607239, ], - "best_metric": 0.931993305683136, - "infer_metric": 0.9326668977737427, + "best_metric": 0.933259129524231, + "infer_metric": 0.9332860708236694, "output_sums": [ - 0.1418775228871769, - 0.15188869120317386, - 0.15140863737688195, - 0.1396146850007127, - 0.18784343811575696, - 0.16909487431163164, - 0.14649608249452073, - 0.1677767130878611, - 0.1568122289811143, - 0.17874181729735056, - 0.16213703658980205, - 0.16754335171970686, - 0.14444824920997243, - 0.11432402622850306, - 0.16143210936221247, - 0.20055289634107482, - 0.17543571757219317, - 0.09920729163334538, - 0.19297325815057875, - 0.2023200127892273, - 0.1956677579845722, - 0.20774045016425718, - 0.16193278944159428, - 0.13174198906539808, - 0.14830508550670007, - 0.14241105864278342, - 0.23090631643085724, - 0.16056153813499532, - 0.1480353269419819, - 0.10318719171632634, - 0.11867462580989198, - 0.12997011485830187, - 0.11401220332210203, - 0.15242746700662088, - 0.1628489107974574, - 0.19327235354175412, - 0.22184902863377548, - 0.18028049625972334, - 0.18958059106892552, - 0.07884601267057013, + 0.142167581604417, + 0.15195543400875847, + 0.1512754523215521, + 0.13962938779108452, + 0.18835719348918614, + 0.16943498693483486, + 0.1465709827477569, + 0.16806483607477135, + 0.1568844609697224, + 0.17911090857818554, + 0.16252098157181355, + 0.16806016936625395, + 0.14430124467305516, + 0.11316135548315168, + 0.16183771025615476, + 0.2009426314066978, + 0.1760258010156966, + 0.09700864497950844, + 0.1938495370314683, + 0.20319147575335647, + 0.19629641404249798, + 0.20852344793102826, + 0.16185073630020633, + 0.13184196857669161, + 0.1480959525354053, + 0.14232924377085415, + 0.23177739882790951, + 0.16094610375534632, + 0.14832771888168225, + 0.10259365443625812, + 0.11850632233099603, + 0.1294100326098242, + 0.11364228279017609, + 0.15181947897584674, + 0.16319358155815072, + 0.1940284526521386, + 0.22306137879066443, + 0.18083137638759522, + 0.1903135237574692, + 0.07402317520619131, ], }, "integration_workflows": { @@ -165,7 +165,7 @@ ], }, }, - { # test answers for PyTorch 1.7 + { # test answers for PyTorch 1.8 "integration_classification_2d": { "losses": [0.777176220515731, 0.16019743723664315, 0.07480076164197011, 0.045643698364780966], "best_metric": 0.9999418774120775, @@ -173,56 +173,56 @@ }, "integration_segmentation_3d": { "losses": [ - 0.5427072256803512, - 0.46434969305992124, - 0.45358552038669586, - 0.4363856494426727, - 0.42080804109573366, - 0.42058534920215607, + 0.5326887160539627, + 0.4685510128736496, + 0.46245276033878324, + 0.4411882758140564, + 0.4198471873998642, + 0.43021280467510226, ], - "best_metric": 0.9292903542518616, - "infer_metric": 0.9306288316845894, + "best_metric": 0.931993305683136, + "infer_metric": 0.9326668977737427, "output_sums": [ - 0.14192493409895743, - 0.15182314591386872, - 0.15143080738742032, - 0.13972497034181824, - 0.18790884439406313, - 0.16933812661492562, - 0.14664343345928132, - 0.1678599094806423, - 0.1568852615222309, - 0.17882538307200632, - 0.16226220644853354, - 0.16756325103417588, - 0.1449974856885373, - 0.1160602083671129, - 0.1614830941632057, - 0.20060717335382267, - 0.17543495742507476, - 0.10308107883493946, - 0.19289222718691168, - 0.20225689438356148, - 0.19587806881756237, - 0.20773073456322155, - 0.16193015294299506, - 0.13181961683097554, - 0.14850995284454005, - 0.14238637655756, - 0.2307113922277095, - 0.1608335768948913, - 0.1480752874532259, - 0.1038477413165911, - 0.11880665574424197, - 0.13084873656303445, - 0.1141965805147642, - 0.1531586543003841, - 0.16275008603701097, - 0.19320476187766733, - 0.2217811250932611, - 0.18027048819200148, - 0.18958803602663193, - 0.08653716931250294, + 0.1418775228871769, + 0.15188869120317386, + 0.15140863737688195, + 0.1396146850007127, + 0.18784343811575696, + 0.16909487431163164, + 0.14649608249452073, + 0.1677767130878611, + 0.1568122289811143, + 0.17874181729735056, + 0.16213703658980205, + 0.16754335171970686, + 0.14444824920997243, + 0.11432402622850306, + 0.16143210936221247, + 0.20055289634107482, + 0.17543571757219317, + 0.09920729163334538, + 0.19297325815057875, + 0.2023200127892273, + 0.1956677579845722, + 0.20774045016425718, + 0.16193278944159428, + 0.13174198906539808, + 0.14830508550670007, + 0.14241105864278342, + 0.23090631643085724, + 0.16056153813499532, + 0.1480353269419819, + 0.10318719171632634, + 0.11867462580989198, + 0.12997011485830187, + 0.11401220332210203, + 0.15242746700662088, + 0.1628489107974574, + 0.19327235354175412, + 0.22184902863377548, + 0.18028049625972334, + 0.18958059106892552, + 0.07884601267057013, ], }, "integration_workflows": { From ecfec2e174e46352607411ce3c85e2b2a7d43465 Mon Sep 17 00:00:00 2001 From: Wenqi Li Date: Thu, 9 Feb 2023 23:41:08 +0000 Subject: [PATCH 5/9] update results Signed-off-by: Wenqi Li --- tests/testing_data/integration_answers.py | 98 +++++++++++------------ 1 file changed, 49 insertions(+), 49 deletions(-) diff --git a/tests/testing_data/integration_answers.py b/tests/testing_data/integration_answers.py index b3314557b3..57d89f8e7e 100644 --- a/tests/testing_data/integration_answers.py +++ b/tests/testing_data/integration_answers.py @@ -14,7 +14,7 @@ import numpy as np EXPECTED_ANSWERS = [ - { # test answers for PyTorch 1.13 + { # test answers for PyTorch 1.12.1 "integration_classification_2d": { "losses": [0.776835828070428, 0.1615355300011149, 0.07492854832938523, 0.04591309238865877], "best_metric": 0.9999184380485994, @@ -501,56 +501,56 @@ }, "integration_segmentation_3d": { "losses": [ - 0.5462362408638001, - 0.4913381844758987, - 0.4526856362819672, - 0.43404580652713776, - 0.42532919645309447, - 0.4160102754831314, + 0.5373653262853623, + 0.46776085495948794, + 0.4422474503517151, + 0.43667820692062376, + 0.42639826238155365, + 0.4158218264579773, ], - "best_metric": 0.9357608556747437, - "infer_metric": 0.9359462857246399, + "best_metric": 0.9357852935791016, + "infer_metric": 0.9358890652656555, "output_sums": [ - 0.14133183650702907, - 0.15129517085134564, - 0.15039408698301698, - 0.1388800895551786, - 0.18765019147239637, - 0.16847158867677473, - 0.14567945622102715, - 0.16728557092807228, - 0.15601444057659314, - 0.17816339678760573, - 0.1616256801482474, - 0.16733042976922818, - 0.14342795433701588, - 0.1122946416901734, - 0.16105778942392063, - 0.20017543167070598, - 0.17512204704647916, - 0.09592956823274325, - 0.19316383411238341, - 0.2022308530579937, - 0.19527218778022315, - 0.2075871950564991, - 0.16083565516485876, - 0.13111518931029637, - 0.1473909261474288, - 0.14161210629657228, - 0.23102446985179093, - 0.15980667305916593, - 0.14760356792082058, - 0.1018092235719272, - 0.11792260857122504, - 0.1285278390386459, - 0.11275165891441473, - 0.15101653432548032, - 0.16236351926994622, - 0.1932631773335222, - 0.2221395787381994, - 0.18003549292918666, - 0.18940543270178078, - 0.07430261166443994, + 0.14134230251963734, + 0.15126225587084188, + 0.15033401825118003, + 0.1389259850822021, + 0.18754569424488515, + 0.16839386100677756, + 0.14565994645049316, + 0.1673404545700305, + 0.1561511946878991, + 0.17825771988631423, + 0.1616607574532002, + 0.16742913895628342, + 0.14354699757474138, + 0.11215070364672397, + 0.16112241518382064, + 0.20001951769273596, + 0.17526580315958823, + 0.09564779134319003, + 0.19300729425711433, + 0.20226013883846938, + 0.1952803784613225, + 0.207563783379273, + 0.16082750039188845, + 0.13121728634981528, + 0.14741783973523187, + 0.14157844891046317, + 0.23102353186599955, + 0.15982195501286317, + 0.14750224809851548, + 0.10177519678431225, + 0.11784387764466563, + 0.12852018780730834, + 0.11300143976680752, + 0.1508621728586496, + 0.1623522601916851, + 0.19320168095077178, + 0.222086024709285, + 0.1800784736260849, + 0.18942329376838685, + 0.07354564965439693, ], }, "integration_workflows": { From 11305398ecd9c6b4163b5ebf801835fdd082df3a Mon Sep 17 00:00:00 2001 From: Wenqi Li <831580+wyli@users.noreply.github.com> Date: Fri, 10 Feb 2023 06:10:00 +0000 Subject: [PATCH 6/9] 5762 pprint head and tail bundle script (#5969) Signed-off-by: Wenqi Li Fixes #5762 ### Description limiting the number of printing lines ### Types of changes - [x] Non-breaking change (fix or new feature that would not break existing functionality). - [ ] Breaking change (fix or new feature that would cause existing functionality to change). - [x] New tests added to cover the changes. - [ ] Integration tests passed locally by running `./runtests.sh -f -u --net --coverage`. - [x] Quick tests passed locally by running `./runtests.sh --quick --unittests --disttests`. - [ ] In-line docstrings updated. - [ ] Documentation updated, tested `make html` command in the `docs/` folder. --------- Signed-off-by: Wenqi Li --- monai/bundle/scripts.py | 6 +++--- monai/utils/__init__.py | 1 + monai/utils/misc.py | 16 ++++++++++++++++ tests/test_bundle_utils.py | 12 ++++++++++++ 4 files changed, 32 insertions(+), 3 deletions(-) diff --git a/monai/bundle/scripts.py b/monai/bundle/scripts.py index d7a46994a1..0a4c3139b1 100644 --- a/monai/bundle/scripts.py +++ b/monai/bundle/scripts.py @@ -14,7 +14,6 @@ import ast import json import os -import pprint import re import time import warnings @@ -37,7 +36,7 @@ from monai.data import load_net_with_metadata, save_net_with_metadata from monai.networks import convert_to_torchscript, copy_model_state, get_state_dict, save_state from monai.utils import check_parent_dir, get_equivalent_dtype, min_version, optional_import -from monai.utils.misc import ensure_tuple +from monai.utils.misc import ensure_tuple, pprint_edges validate, _ = optional_import("jsonschema", name="validate") ValidationError, _ = optional_import("jsonschema.exceptions", name="ValidationError") @@ -48,6 +47,7 @@ # set BUNDLE_DOWNLOAD_SRC="ngc" to use NGC source in default for bundle download download_source = os.environ.get("BUNDLE_DOWNLOAD_SRC", "github") +PPRINT_CONFIG_N = 5 def _update_args(args: str | dict | None = None, ignore_none: bool = True, **kwargs: Any) -> dict: @@ -88,7 +88,7 @@ def _pop_args(src: dict, *args: Any, **kwargs: Any) -> tuple: def _log_input_summary(tag: str, args: dict) -> None: logger.info(f"--- input summary of monai.bundle.scripts.{tag} ---") for name, val in args.items(): - logger.info(f"> {name}: {pprint.pformat(val)}") + logger.info(f"> {name}: {pprint_edges(val, PPRINT_CONFIG_N)}") logger.info("---\n\n") diff --git a/monai/utils/__init__.py b/monai/utils/__init__.py index 601a5f10ae..318b16c47c 100644 --- a/monai/utils/__init__.py +++ b/monai/utils/__init__.py @@ -77,6 +77,7 @@ issequenceiterable, list_to_dict, path_to_uri, + pprint_edges, progress_bar, sample_slices, save_obj, diff --git a/monai/utils/misc.py b/monai/utils/misc.py index 583acd3a54..05ef1cb4c7 100644 --- a/monai/utils/misc.py +++ b/monai/utils/misc.py @@ -14,6 +14,7 @@ import inspect import itertools import os +import pprint import random import shutil import tempfile @@ -60,6 +61,7 @@ "save_obj", "label_union", "path_to_uri", + "pprint_edges", ] _seed = None @@ -626,3 +628,17 @@ def path_to_uri(path: PathLike) -> str: """ return Path(path).absolute().as_uri() + + +def pprint_edges(val: Any, n_lines: int = 20) -> str: + """ + Pretty print the head and tail ``n_lines`` of ``val``, and omit the middle part if the part has more than 3 lines. + + Returns: the formatted string. + """ + val_str = pprint.pformat(val).splitlines(True) + n_lines = max(n_lines, 1) + if len(val_str) > n_lines * 2 + 3: + hidden_n = len(val_str) - n_lines * 2 + val_str = val_str[:n_lines] + [f"\n ... omitted {hidden_n} line(s)\n\n"] + val_str[-n_lines:] + return "".join(val_str) diff --git a/tests/test_bundle_utils.py b/tests/test_bundle_utils.py index 9d28903f2f..d92f6e517f 100644 --- a/tests/test_bundle_utils.py +++ b/tests/test_bundle_utils.py @@ -20,6 +20,7 @@ from monai.bundle.utils import load_bundle_config from monai.networks.nets import UNet +from monai.utils import pprint_edges from tests.utils import command_line_tests, skip_if_windows metadata = """ @@ -117,5 +118,16 @@ def test_load_config_ts(self): self.assertEqual(p["test_dict"]["b"], "c") +class TestPPrintEdges(unittest.TestCase): + def test_str(self): + self.assertEqual(pprint_edges("", 0), "''") + self.assertEqual(pprint_edges({"a": 1, "b": 2}, 0), "{'a': 1, 'b': 2}") + self.assertEqual( + pprint_edges([{"a": 1, "b": 2}] * 20, 1), + "[{'a': 1, 'b': 2},\n\n ... omitted 18 line(s)\n\n {'a': 1, 'b': 2}]", + ) + self.assertEqual(pprint_edges([{"a": 1, "b": 2}] * 8, 4), pprint_edges([{"a": 1, "b": 2}] * 8, 3)) + + if __name__ == "__main__": unittest.main() From fac684837139db281cd293663ce9f3ba96a3bbb3 Mon Sep 17 00:00:00 2001 From: vfdev Date: Fri, 10 Feb 2023 10:00:37 +0100 Subject: [PATCH 7/9] Added callable options for iteration_log and epoch_log in StatsHandler (#5965) Fixes #5964 ### Description Added callable options for iteration_log and epoch_log in StatsHandler. Ref: https://github.com/Project-MONAI/MONAI/discussions/5958#discussioncomment-4912997 ### Types of changes - [x] Non-breaking change (fix or new feature that would not break existing functionality). - [ ] Breaking change (fix or new feature that would cause existing functionality to change). - [x] New tests added to cover the changes. - [ ] Integration tests passed locally by running `./runtests.sh -f -u --net --coverage`. - [ ] Quick tests passed locally by running `./runtests.sh --quick --unittests --disttests`. - [x] In-line docstrings updated. - [ ] Documentation updated, tested `make html` command in the `docs/` folder. --------- Signed-off-by: vfdev-5 Signed-off-by: Wenqi Li Signed-off-by: vfdev Co-authored-by: Wenqi Li Co-authored-by: Eric Kerfoot <17726042+ericspod@users.noreply.github.com> Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: Wenqi Li <831580+wyli@users.noreply.github.com> --- monai/handlers/stats_handler.py | 24 +++++++++++++++----- tests/test_handler_stats.py | 39 ++++++++++++++++++++++++++------- 2 files changed, 49 insertions(+), 14 deletions(-) diff --git a/monai/handlers/stats_handler.py b/monai/handlers/stats_handler.py index 58917e666b..8471a87e8e 100644 --- a/monai/handlers/stats_handler.py +++ b/monai/handlers/stats_handler.py @@ -66,8 +66,8 @@ class StatsHandler: def __init__( self, - iteration_log: bool = True, - epoch_log: bool = True, + iteration_log: bool | Callable[[Engine, int], bool] = True, + epoch_log: bool | Callable[[Engine, int], bool] = True, epoch_print_logger: Callable[[Engine], Any] | None = None, iteration_print_logger: Callable[[Engine], Any] | None = None, output_transform: Callable = lambda x: x[0], @@ -80,8 +80,14 @@ def __init__( """ Args: - iteration_log: whether to log data when iteration completed, default to `True`. - epoch_log: whether to log data when epoch completed, default to `True`. + iteration_log: whether to log data when iteration completed, default to `True`. ``iteration_log`` can + be also a function and it will be interpreted as an event filter + (see https://pytorch.org/ignite/generated/ignite.engine.events.Events.html for details). + Event filter function accepts as input engine and event value (iteration) and should return True/False. + Event filtering can be helpful to customize iteration logging frequency. + epoch_log: whether to log data when epoch completed, default to `True`. ``epoch_log`` can be + also a function and it will be interpreted as an event filter. See ``iteration_log`` argument for more + details. epoch_print_logger: customized callable printer for epoch level logging. Must accept parameter "engine", use default printer if None. iteration_print_logger: customized callable printer for iteration level logging. @@ -135,9 +141,15 @@ def attach(self, engine: Engine) -> None: " please call `logging.basicConfig(stream=sys.stdout, level=logging.INFO)` to enable it." ) if self.iteration_log and not engine.has_event_handler(self.iteration_completed, Events.ITERATION_COMPLETED): - engine.add_event_handler(Events.ITERATION_COMPLETED, self.iteration_completed) + event = Events.ITERATION_COMPLETED + if callable(self.iteration_log): # substitute event with new one using filter callable + event = event(event_filter=self.iteration_log) + engine.add_event_handler(event, self.iteration_completed) if self.epoch_log and not engine.has_event_handler(self.epoch_completed, Events.EPOCH_COMPLETED): - engine.add_event_handler(Events.EPOCH_COMPLETED, self.epoch_completed) + event = Events.EPOCH_COMPLETED + if callable(self.epoch_log): # substitute event with new one using filter callable + event = event(event_filter=self.epoch_log) + engine.add_event_handler(event, self.epoch_completed) if not engine.has_event_handler(self.exception_raised, Events.EXCEPTION_RAISED): engine.add_event_handler(Events.EXCEPTION_RAISED, self.exception_raised) diff --git a/tests/test_handler_stats.py b/tests/test_handler_stats.py index cb93f93a29..84477f9221 100644 --- a/tests/test_handler_stats.py +++ b/tests/test_handler_stats.py @@ -20,12 +20,23 @@ import torch from ignite.engine import Engine, Events +from parameterized import parameterized from monai.handlers import StatsHandler +def get_event_filter(e): + def event_filter(_, event): + if event in e: + return True + return False + + return event_filter + + class TestHandlerStats(unittest.TestCase): - def test_metrics_print(self): + @parameterized.expand([[True], [get_event_filter([1, 2])]]) + def test_metrics_print(self, epoch_log): log_stream = StringIO() log_handler = logging.StreamHandler(log_stream) log_handler.setLevel(logging.INFO) @@ -48,10 +59,11 @@ def _update_metric(engine): logger = logging.getLogger(key_to_handler) logger.setLevel(logging.INFO) logger.addHandler(log_handler) - stats_handler = StatsHandler(iteration_log=False, epoch_log=True, name=key_to_handler) + stats_handler = StatsHandler(iteration_log=False, epoch_log=epoch_log, name=key_to_handler) stats_handler.attach(engine) - engine.run(range(3), max_epochs=2) + max_epochs = 4 + engine.run(range(3), max_epochs=max_epochs) # check logging output output_str = log_stream.getvalue() @@ -61,9 +73,13 @@ def _update_metric(engine): for line in output_str.split("\n"): if has_key_word.match(line): content_count += 1 - self.assertTrue(content_count > 0) + if epoch_log is True: + self.assertTrue(content_count == max_epochs) + else: + self.assertTrue(content_count == 2) # 2 = len([1, 2]) from event_filter - def test_loss_print(self): + @parameterized.expand([[True], [get_event_filter([1, 3])]]) + def test_loss_print(self, iteration_log): log_stream = StringIO() log_handler = logging.StreamHandler(log_stream) log_handler.setLevel(logging.INFO) @@ -80,10 +96,14 @@ def _train_func(engine, batch): logger = logging.getLogger(key_to_handler) logger.setLevel(logging.INFO) logger.addHandler(log_handler) - stats_handler = StatsHandler(iteration_log=True, epoch_log=False, name=key_to_handler, tag_name=key_to_print) + stats_handler = StatsHandler( + iteration_log=iteration_log, epoch_log=False, name=key_to_handler, tag_name=key_to_print + ) stats_handler.attach(engine) - engine.run(range(3), max_epochs=2) + num_iters = 3 + max_epochs = 2 + engine.run(range(num_iters), max_epochs=max_epochs) # check logging output output_str = log_stream.getvalue() @@ -93,7 +113,10 @@ def _train_func(engine, batch): for line in output_str.split("\n"): if has_key_word.match(line): content_count += 1 - self.assertTrue(content_count > 0) + if iteration_log is True: + self.assertTrue(content_count == num_iters * max_epochs) + else: + self.assertTrue(content_count == 2) # 2 = len([1, 3]) from event_filter def test_loss_dict(self): log_stream = StringIO() From 565f93c31e81311776615c4c7cae5dd227d8a478 Mon Sep 17 00:00:00 2001 From: Wenqi Li Date: Mon, 13 Feb 2023 15:57:40 +0000 Subject: [PATCH 8/9] remove unused ansers Signed-off-by: Wenqi Li --- tests/testing_data/integration_answers.py | 208 ---------------------- 1 file changed, 208 deletions(-) diff --git a/tests/testing_data/integration_answers.py b/tests/testing_data/integration_answers.py index 57d89f8e7e..b1f115e1d3 100644 --- a/tests/testing_data/integration_answers.py +++ b/tests/testing_data/integration_answers.py @@ -316,67 +316,6 @@ ], }, }, - { # test answers for PyTorch 21.04, cuda 11.3 - "integration_classification_2d": { - "losses": [0.7772567988770782, 0.16357883198815545, 0.0748426011840629, 0.045560025710873545], - "best_metric": 0.9999362036681547, - "infer_prop": [1030, 898, 981, 1033, 960, 1046], - }, - "integration_segmentation_3d": { - "losses": [ - 0.5462346076965332, - 0.4699550330638885, - 0.4407052755355835, - 0.4473582059144974, - 0.4345871120691299, - 0.4268435090780258, - ], - "best_metric": 0.9325245052576066, - "infer_metric": 0.9326683700084686, - "output_sums": [ - 0.14224469870198278, - 0.15221021012369151, - 0.15124158255724182, - 0.13988812880932433, - 0.18869885039284465, - 0.16944664085835437, - 0.14679946398855015, - 0.1681337815374021, - 0.1572538225010156, - 0.179386563044054, - 0.162734465243387, - 0.16831902111202945, - 0.1447043535420074, - 0.11343210557896033, - 0.16199135405262954, - 0.20095180481987404, - 0.17613484080473857, - 0.09717457016552708, - 0.1940439758638305, - 0.2033698355271389, - 0.19628583555443793, - 0.20852096425983455, - 0.16202004771083997, - 0.13206408917949392, - 0.14840973098125526, - 0.14237425379050472, - 0.23165483128059614, - 0.16098621485325398, - 0.14831028015056963, - 0.10317099380415945, - 0.118716576251689, - 0.13002315213569166, - 0.11436407827087304, - 0.1522274707636008, - 0.16314910792851098, - 0.1941135852761834, - 0.22309890968242424, - 0.18111804948625987, - 0.19043976068601465, - 0.07442812452084423, - ], - }, - }, { # test answers for PyTorch 1.9 "integration_workflows": { "output_sums_2": [ @@ -493,153 +432,6 @@ "infer_metric": 0.9316383600234985, }, }, - { # test answers for PyTorch 21.10 - "integration_classification_2d": { - "losses": [0.7806222991199251, 0.16259610306495315, 0.07529311385124353, 0.04640352608529246], - "best_metric": 0.9999369155431564, - "infer_prop": [1030, 898, 981, 1033, 960, 1046], - }, - "integration_segmentation_3d": { - "losses": [ - 0.5373653262853623, - 0.46776085495948794, - 0.4422474503517151, - 0.43667820692062376, - 0.42639826238155365, - 0.4158218264579773, - ], - "best_metric": 0.9357852935791016, - "infer_metric": 0.9358890652656555, - "output_sums": [ - 0.14134230251963734, - 0.15126225587084188, - 0.15033401825118003, - 0.1389259850822021, - 0.18754569424488515, - 0.16839386100677756, - 0.14565994645049316, - 0.1673404545700305, - 0.1561511946878991, - 0.17825771988631423, - 0.1616607574532002, - 0.16742913895628342, - 0.14354699757474138, - 0.11215070364672397, - 0.16112241518382064, - 0.20001951769273596, - 0.17526580315958823, - 0.09564779134319003, - 0.19300729425711433, - 0.20226013883846938, - 0.1952803784613225, - 0.207563783379273, - 0.16082750039188845, - 0.13121728634981528, - 0.14741783973523187, - 0.14157844891046317, - 0.23102353186599955, - 0.15982195501286317, - 0.14750224809851548, - 0.10177519678431225, - 0.11784387764466563, - 0.12852018780730834, - 0.11300143976680752, - 0.1508621728586496, - 0.1623522601916851, - 0.19320168095077178, - 0.222086024709285, - 0.1800784736260849, - 0.18942329376838685, - 0.07354564965439693, - ], - }, - "integration_workflows": { - "output_sums": [ - 0.14211511611938477, - 0.1516571044921875, - 0.1381092071533203, - 0.13403034210205078, - 0.18480682373046875, - 0.16382598876953125, - 0.14140796661376953, - 0.1665945053100586, - 0.15700864791870117, - 0.17697620391845703, - 0.16163396835327148, - 0.16488313674926758, - 0.1442713737487793, - 0.11060476303100586, - 0.16111087799072266, - 0.19617986679077148, - 0.1744403839111328, - 0.052786827087402344, - 0.19046974182128906, - 0.19913578033447266, - 0.19527721405029297, - 0.2032318115234375, - 0.16050148010253906, - 0.13228464126586914, - 0.1512293815612793, - 0.1372208595275879, - 0.22692251205444336, - 0.16164922714233398, - 0.14729642868041992, - 0.10398292541503906, - 0.1195836067199707, - 0.13096046447753906, - 0.11221647262573242, - 0.1521167755126953, - 0.1599421501159668, - 0.1898345947265625, - 0.21675777435302734, - 0.1777491569519043, - 0.18526840209960938, - 0.035144805908203125, - ], - "output_sums_2": [ - 0.14200592041015625, - 0.15146303176879883, - 0.13796186447143555, - 0.1339101791381836, - 0.18489742279052734, - 0.1637406349182129, - 0.14113903045654297, - 0.16657161712646484, - 0.15676355361938477, - 0.17683839797973633, - 0.1614980697631836, - 0.16493558883666992, - 0.14408016204833984, - 0.11035394668579102, - 0.1610560417175293, - 0.1962742805480957, - 0.17439842224121094, - 0.05285835266113281, - 0.19057941436767578, - 0.19914865493774414, - 0.19533538818359375, - 0.20333576202392578, - 0.16032838821411133, - 0.13197898864746094, - 0.1510462760925293, - 0.13703680038452148, - 0.2270984649658203, - 0.16144943237304688, - 0.1472611427307129, - 0.10393238067626953, - 0.11940813064575195, - 0.1307811737060547, - 0.11203241348266602, - 0.15186500549316406, - 0.15992307662963867, - 0.18991422653198242, - 0.21689796447753906, - 0.1777033805847168, - 0.18547868728637695, - 0.035192012786865234, - ], - }, - }, ] From af3e1d39696388d9ca4121bfa0db3b93126a0f3d Mon Sep 17 00:00:00 2001 From: Wenqi Li Date: Mon, 13 Feb 2023 16:11:15 +0000 Subject: [PATCH 9/9] update based on comments Signed-off-by: Wenqi Li --- monai/networks/utils.py | 2 +- monai/transforms/spatial/array.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/monai/networks/utils.py b/monai/networks/utils.py index 8b81d03535..d5c0629c05 100644 --- a/monai/networks/utils.py +++ b/monai/networks/utils.py @@ -266,7 +266,7 @@ def to_norm_affine( src_xform = normalize_transform(src_size, affine.device, affine.dtype, align_corners, zero_centered) dst_xform = normalize_transform(dst_size, "cpu", affine.dtype, align_corners, zero_centered) - return src_xform @ affine @ convert_to_dst_type(np.linalg.inv(dst_xform.numpy()), dst=affine)[0] + return src_xform @ affine @ convert_to_dst_type(np.linalg.inv(dst_xform.numpy()), dst=affine)[0] # monai#5983 def normal_init( diff --git a/monai/transforms/spatial/array.py b/monai/transforms/spatial/array.py index e69e319005..bda32b5fc8 100644 --- a/monai/transforms/spatial/array.py +++ b/monai/transforms/spatial/array.py @@ -273,7 +273,7 @@ def __call__( try: _s = convert_to_numpy(src_affine_) _d = convert_to_numpy(dst_affine) - xform = np.linalg.solve(_s, _d) + xform = np.linalg.solve(_s, _d) # monai#5983 except (np.linalg.LinAlgError, RuntimeError) as e: raise ValueError(f"src affine is not invertible {_s}, {_d}.") from e xform = convert_to_tensor(to_affine_nd(spatial_rank, xform)).to(device=img.device, dtype=_dtype)