diff --git a/monai/networks/utils.py b/monai/networks/utils.py index 23676d3d06..d5c0629c05 100644 --- a/monai/networks/utils.py +++ b/monai/networks/utils.py @@ -22,6 +22,7 @@ from copy import deepcopy from typing import Any +import numpy as np import torch import torch.nn as nn @@ -29,7 +30,7 @@ from monai.config import PathLike from monai.utils.misc import ensure_tuple, save_obj, set_determinism from monai.utils.module import look_up_option, pytorch_after -from monai.utils.type_conversion import convert_to_tensor +from monai.utils.type_conversion import convert_to_dst_type, convert_to_tensor __all__ = [ "one_hot", @@ -185,7 +186,7 @@ def predict_segmentation(logits: torch.Tensor, mutually_exclusive: bool = False, def normalize_transform( shape, - device: torch.device | None = None, + device: torch.device | str | None = None, dtype: torch.dtype | None = None, align_corners: bool = False, zero_centered: bool = False, @@ -264,8 +265,8 @@ def to_norm_affine( raise ValueError(f"affine suggests {sr}D, got src={len(src_size)}D, dst={len(dst_size)}D.") src_xform = normalize_transform(src_size, affine.device, affine.dtype, align_corners, zero_centered) - dst_xform = normalize_transform(dst_size, affine.device, affine.dtype, align_corners, zero_centered) - return src_xform @ affine @ torch.inverse(dst_xform) + dst_xform = normalize_transform(dst_size, "cpu", affine.dtype, align_corners, zero_centered) + return src_xform @ affine @ convert_to_dst_type(np.linalg.inv(dst_xform.numpy()), dst=affine)[0] # monai#5983 def normal_init( diff --git a/monai/transforms/spatial/array.py b/monai/transforms/spatial/array.py index 02ee7a8c50..bda32b5fc8 100644 --- a/monai/transforms/spatial/array.py +++ b/monai/transforms/spatial/array.py @@ -65,7 +65,6 @@ fall_back_tuple, issequenceiterable, optional_import, - pytorch_after, ) from monai.utils.deprecate_utils import deprecated_arg from monai.utils.enums import GridPatchSort, PytorchPadMode, TraceKeys, TransformBackends, WSIPatchKeys @@ -272,14 +271,12 @@ def __call__( ) try: - _s = convert_to_tensor(src_affine_, track_meta=False, device=torch.device("cpu")) - _d = convert_to_tensor(dst_affine, track_meta=False, device=torch.device("cpu")) - xform = ( - torch.linalg.solve(_s, _d) if pytorch_after(1, 8, 0) else torch.solve(_d, _s).solution # type: ignore - ) + _s = convert_to_numpy(src_affine_) + _d = convert_to_numpy(dst_affine) + xform = np.linalg.solve(_s, _d) # monai#5983 except (np.linalg.LinAlgError, RuntimeError) as e: - raise ValueError("src affine is not invertible.") from e - xform = to_affine_nd(spatial_rank, xform).to(device=img.device, dtype=_dtype) + raise ValueError(f"src affine is not invertible {_s}, {_d}.") from e + xform = convert_to_tensor(to_affine_nd(spatial_rank, xform)).to(device=img.device, dtype=_dtype) # no resampling if it's identity transform if allclose(xform, torch.eye(len(xform)), atol=AFFINE_TOL) and allclose(spatial_size, in_spatial_size): return self._post_process( @@ -293,12 +290,12 @@ def __call__( xform_shape = [-1] + in_spatial_size img = img.reshape(xform_shape) # type: ignore if isinstance(mode, int): - dst_xform_1 = normalize_transform(spatial_size, xform.device, xform.dtype, True, True)[0] # to (-1, 1) + dst_xform_1 = normalize_transform(spatial_size, "cpu", xform.dtype, True, True)[0].numpy() # to (-1, 1) if not align_corners: - norm = create_scale(spatial_rank, [(max(d, 2) - 1) / d for d in spatial_size], xform.device, "torch") - dst_xform_1 = norm.to(xform.dtype) @ dst_xform_1 # type: ignore # scaling (num_step - 1) / num_step - dst_xform_d = normalize_transform(spatial_size, xform.device, xform.dtype, align_corners, False)[0] - xform = xform @ torch.inverse(dst_xform_d) @ dst_xform_1 + norm = create_scale(spatial_rank, [(max(d, 2) - 1) / d for d in spatial_size]) + dst_xform_1 = norm.astype(float) @ dst_xform_1 # type: ignore # scaling (num_step - 1) / num_step + dst_xform_d = normalize_transform(spatial_size, "cpu", xform.dtype, align_corners, False)[0].numpy() + xform @= convert_to_dst_type(np.linalg.solve(dst_xform_d, dst_xform_1), xform)[0] affine_xform = Affine( affine=xform, spatial_size=spatial_size, normalized=True, image_only=True, dtype=_dtype # type: ignore ) @@ -1084,7 +1081,7 @@ def inverse_transform(self, data: torch.Tensor, transform) -> torch.Tensor: padding_mode = transform[TraceKeys.EXTRA_INFO]["padding_mode"] align_corners = transform[TraceKeys.EXTRA_INFO]["align_corners"] dtype = transform[TraceKeys.EXTRA_INFO]["dtype"] - inv_rot_mat = linalg_inv(fwd_rot_mat) + inv_rot_mat = linalg_inv(convert_to_numpy(fwd_rot_mat)) xform = AffineTransform( normalized=False, @@ -2281,7 +2278,7 @@ def inverse(self, data: torch.Tensor) -> torch.Tensor: fwd_affine = transform[TraceKeys.EXTRA_INFO]["affine"] mode = transform[TraceKeys.EXTRA_INFO]["mode"] padding_mode = transform[TraceKeys.EXTRA_INFO]["padding_mode"] - inv_affine = linalg_inv(fwd_affine) + inv_affine = linalg_inv(convert_to_numpy(fwd_affine)) inv_affine = convert_to_dst_type(inv_affine, data, dtype=inv_affine.dtype)[0] affine_grid = AffineGrid(affine=inv_affine) @@ -2520,7 +2517,7 @@ def inverse(self, data: torch.Tensor) -> torch.Tensor: fwd_affine = transform[TraceKeys.EXTRA_INFO]["affine"] mode = transform[TraceKeys.EXTRA_INFO]["mode"] padding_mode = transform[TraceKeys.EXTRA_INFO]["padding_mode"] - inv_affine = linalg_inv(fwd_affine) + inv_affine = linalg_inv(convert_to_numpy(fwd_affine)) inv_affine = convert_to_dst_type(inv_affine, data, dtype=inv_affine.dtype)[0] affine_grid = AffineGrid(affine=inv_affine) grid, _ = affine_grid(orig_size) diff --git a/monai/transforms/utils.py b/monai/transforms/utils.py index 6560899318..2e24463720 100644 --- a/monai/transforms/utils.py +++ b/monai/transforms/utils.py @@ -828,7 +828,7 @@ def _create_shear(spatial_dims: int, coefs: Sequence[float] | float, eye_func=np def create_scale( spatial_dims: int, scaling_factor: Sequence[float] | float, - device: torch.device | None = None, + device: torch.device | str | None = None, backend=TransformBackends.NUMPY, ) -> NdarrayOrTensor: """ diff --git a/tests/testing_data/integration_answers.py b/tests/testing_data/integration_answers.py index f4a5483f83..b1f115e1d3 100644 --- a/tests/testing_data/integration_answers.py +++ b/tests/testing_data/integration_answers.py @@ -14,7 +14,7 @@ import numpy as np EXPECTED_ANSWERS = [ - { # test answers for PyTorch 1.6 + { # test answers for PyTorch 1.12.1 "integration_classification_2d": { "losses": [0.776835828070428, 0.1615355300011149, 0.07492854832938523, 0.04591309238865877], "best_metric": 0.9999184380485994, @@ -22,56 +22,56 @@ }, "integration_segmentation_3d": { "losses": [ - 0.5367561340332031, - 0.478084459900856, - 0.4581540793180466, - 0.44623913466930387, - 0.42341493666172025, - 0.42569945752620697, + 0.5428894340991974, + 0.47331981360912323, + 0.4482289582490921, + 0.4452722787857056, + 0.4289989799261093, + 0.4359133839607239, ], - "best_metric": 0.9295084029436111, - "infer_metric": 0.9296411260962486, + "best_metric": 0.933259129524231, + "infer_metric": 0.9332860708236694, "output_sums": [ - 0.14302121377204619, - 0.15321686701244813, - 0.15267064069005093, - 0.1408481434833016, - 0.18862719991649474, - 0.16992848513054068, - 0.1479306037291329, - 0.1691071594535633, - 0.15804366588267224, - 0.18019304183940157, - 0.1635089455927468, - 0.16851606024285842, - 0.1454348651039073, - 0.11584957890961554, - 0.16255468027312903, - 0.20118089432240313, - 0.176187783307603, - 0.1004243279488101, - 0.19385348502657657, - 0.2030768555124136, - 0.196251372926592, - 0.20823046240222043, - 0.1631389353339986, - 0.13299661219478043, - 0.14917081129077908, - 0.14383374638201593, - 0.23050183928776746, - 0.1614747942341212, - 0.14913436515470202, - 0.10443081170610946, - 0.11978674347415241, - 0.13126176432899028, - 0.11570832453348577, - 0.15306806147195887, - 0.163673089782912, - 0.19394971756732426, - 0.22197501007172804, - 0.1812147930033603, - 0.19051659118682873, - 0.0774867922747158, + 0.142167581604417, + 0.15195543400875847, + 0.1512754523215521, + 0.13962938779108452, + 0.18835719348918614, + 0.16943498693483486, + 0.1465709827477569, + 0.16806483607477135, + 0.1568844609697224, + 0.17911090857818554, + 0.16252098157181355, + 0.16806016936625395, + 0.14430124467305516, + 0.11316135548315168, + 0.16183771025615476, + 0.2009426314066978, + 0.1760258010156966, + 0.09700864497950844, + 0.1938495370314683, + 0.20319147575335647, + 0.19629641404249798, + 0.20852344793102826, + 0.16185073630020633, + 0.13184196857669161, + 0.1480959525354053, + 0.14232924377085415, + 0.23177739882790951, + 0.16094610375534632, + 0.14832771888168225, + 0.10259365443625812, + 0.11850632233099603, + 0.1294100326098242, + 0.11364228279017609, + 0.15181947897584674, + 0.16319358155815072, + 0.1940284526521386, + 0.22306137879066443, + 0.18083137638759522, + 0.1903135237574692, + 0.07402317520619131, ], }, "integration_workflows": { @@ -165,7 +165,7 @@ ], }, }, - { # test answers for PyTorch 1.7 + { # test answers for PyTorch 1.8 "integration_classification_2d": { "losses": [0.777176220515731, 0.16019743723664315, 0.07480076164197011, 0.045643698364780966], "best_metric": 0.9999418774120775, @@ -173,56 +173,56 @@ }, "integration_segmentation_3d": { "losses": [ - 0.5427072256803512, - 0.46434969305992124, - 0.45358552038669586, - 0.4363856494426727, - 0.42080804109573366, - 0.42058534920215607, + 0.5326887160539627, + 0.4685510128736496, + 0.46245276033878324, + 0.4411882758140564, + 0.4198471873998642, + 0.43021280467510226, ], - "best_metric": 0.9292903542518616, - "infer_metric": 0.9306288316845894, + "best_metric": 0.931993305683136, + "infer_metric": 0.9326668977737427, "output_sums": [ - 0.14192493409895743, - 0.15182314591386872, - 0.15143080738742032, - 0.13972497034181824, - 0.18790884439406313, - 0.16933812661492562, - 0.14664343345928132, - 0.1678599094806423, - 0.1568852615222309, - 0.17882538307200632, - 0.16226220644853354, - 0.16756325103417588, - 0.1449974856885373, - 0.1160602083671129, - 0.1614830941632057, - 0.20060717335382267, - 0.17543495742507476, - 0.10308107883493946, - 0.19289222718691168, - 0.20225689438356148, - 0.19587806881756237, - 0.20773073456322155, - 0.16193015294299506, - 0.13181961683097554, - 0.14850995284454005, - 0.14238637655756, - 0.2307113922277095, - 0.1608335768948913, - 0.1480752874532259, - 0.1038477413165911, - 0.11880665574424197, - 0.13084873656303445, - 0.1141965805147642, - 0.1531586543003841, - 0.16275008603701097, - 0.19320476187766733, - 0.2217811250932611, - 0.18027048819200148, - 0.18958803602663193, - 0.08653716931250294, + 0.1418775228871769, + 0.15188869120317386, + 0.15140863737688195, + 0.1396146850007127, + 0.18784343811575696, + 0.16909487431163164, + 0.14649608249452073, + 0.1677767130878611, + 0.1568122289811143, + 0.17874181729735056, + 0.16213703658980205, + 0.16754335171970686, + 0.14444824920997243, + 0.11432402622850306, + 0.16143210936221247, + 0.20055289634107482, + 0.17543571757219317, + 0.09920729163334538, + 0.19297325815057875, + 0.2023200127892273, + 0.1956677579845722, + 0.20774045016425718, + 0.16193278944159428, + 0.13174198906539808, + 0.14830508550670007, + 0.14241105864278342, + 0.23090631643085724, + 0.16056153813499532, + 0.1480353269419819, + 0.10318719171632634, + 0.11867462580989198, + 0.12997011485830187, + 0.11401220332210203, + 0.15242746700662088, + 0.1628489107974574, + 0.19327235354175412, + 0.22184902863377548, + 0.18028049625972334, + 0.18958059106892552, + 0.07884601267057013, ], }, "integration_workflows": { @@ -316,67 +316,6 @@ ], }, }, - { # test answers for PyTorch 21.04, cuda 11.3 - "integration_classification_2d": { - "losses": [0.7772567988770782, 0.16357883198815545, 0.0748426011840629, 0.045560025710873545], - "best_metric": 0.9999362036681547, - "infer_prop": [1030, 898, 981, 1033, 960, 1046], - }, - "integration_segmentation_3d": { - "losses": [ - 0.5462346076965332, - 0.4699550330638885, - 0.4407052755355835, - 0.4473582059144974, - 0.4345871120691299, - 0.4268435090780258, - ], - "best_metric": 0.9325245052576066, - "infer_metric": 0.9326683700084686, - "output_sums": [ - 0.14224469870198278, - 0.15221021012369151, - 0.15124158255724182, - 0.13988812880932433, - 0.18869885039284465, - 0.16944664085835437, - 0.14679946398855015, - 0.1681337815374021, - 0.1572538225010156, - 0.179386563044054, - 0.162734465243387, - 0.16831902111202945, - 0.1447043535420074, - 0.11343210557896033, - 0.16199135405262954, - 0.20095180481987404, - 0.17613484080473857, - 0.09717457016552708, - 0.1940439758638305, - 0.2033698355271389, - 0.19628583555443793, - 0.20852096425983455, - 0.16202004771083997, - 0.13206408917949392, - 0.14840973098125526, - 0.14237425379050472, - 0.23165483128059614, - 0.16098621485325398, - 0.14831028015056963, - 0.10317099380415945, - 0.118716576251689, - 0.13002315213569166, - 0.11436407827087304, - 0.1522274707636008, - 0.16314910792851098, - 0.1941135852761834, - 0.22309890968242424, - 0.18111804948625987, - 0.19043976068601465, - 0.07442812452084423, - ], - }, - }, { # test answers for PyTorch 1.9 "integration_workflows": { "output_sums_2": [ @@ -493,153 +432,6 @@ "infer_metric": 0.9316383600234985, }, }, - { # test answers for PyTorch 21.10 - "integration_classification_2d": { - "losses": [0.7806222991199251, 0.16259610306495315, 0.07529311385124353, 0.04640352608529246], - "best_metric": 0.9999369155431564, - "infer_prop": [1030, 898, 981, 1033, 960, 1046], - }, - "integration_segmentation_3d": { - "losses": [ - 0.5462362408638001, - 0.4913381844758987, - 0.4526856362819672, - 0.43404580652713776, - 0.42532919645309447, - 0.4160102754831314, - ], - "best_metric": 0.9357608556747437, - "infer_metric": 0.9359462857246399, - "output_sums": [ - 0.14133183650702907, - 0.15129517085134564, - 0.15039408698301698, - 0.1388800895551786, - 0.18765019147239637, - 0.16847158867677473, - 0.14567945622102715, - 0.16728557092807228, - 0.15601444057659314, - 0.17816339678760573, - 0.1616256801482474, - 0.16733042976922818, - 0.14342795433701588, - 0.1122946416901734, - 0.16105778942392063, - 0.20017543167070598, - 0.17512204704647916, - 0.09592956823274325, - 0.19316383411238341, - 0.2022308530579937, - 0.19527218778022315, - 0.2075871950564991, - 0.16083565516485876, - 0.13111518931029637, - 0.1473909261474288, - 0.14161210629657228, - 0.23102446985179093, - 0.15980667305916593, - 0.14760356792082058, - 0.1018092235719272, - 0.11792260857122504, - 0.1285278390386459, - 0.11275165891441473, - 0.15101653432548032, - 0.16236351926994622, - 0.1932631773335222, - 0.2221395787381994, - 0.18003549292918666, - 0.18940543270178078, - 0.07430261166443994, - ], - }, - "integration_workflows": { - "output_sums": [ - 0.14211511611938477, - 0.1516571044921875, - 0.1381092071533203, - 0.13403034210205078, - 0.18480682373046875, - 0.16382598876953125, - 0.14140796661376953, - 0.1665945053100586, - 0.15700864791870117, - 0.17697620391845703, - 0.16163396835327148, - 0.16488313674926758, - 0.1442713737487793, - 0.11060476303100586, - 0.16111087799072266, - 0.19617986679077148, - 0.1744403839111328, - 0.052786827087402344, - 0.19046974182128906, - 0.19913578033447266, - 0.19527721405029297, - 0.2032318115234375, - 0.16050148010253906, - 0.13228464126586914, - 0.1512293815612793, - 0.1372208595275879, - 0.22692251205444336, - 0.16164922714233398, - 0.14729642868041992, - 0.10398292541503906, - 0.1195836067199707, - 0.13096046447753906, - 0.11221647262573242, - 0.1521167755126953, - 0.1599421501159668, - 0.1898345947265625, - 0.21675777435302734, - 0.1777491569519043, - 0.18526840209960938, - 0.035144805908203125, - ], - "output_sums_2": [ - 0.14200592041015625, - 0.15146303176879883, - 0.13796186447143555, - 0.1339101791381836, - 0.18489742279052734, - 0.1637406349182129, - 0.14113903045654297, - 0.16657161712646484, - 0.15676355361938477, - 0.17683839797973633, - 0.1614980697631836, - 0.16493558883666992, - 0.14408016204833984, - 0.11035394668579102, - 0.1610560417175293, - 0.1962742805480957, - 0.17439842224121094, - 0.05285835266113281, - 0.19057941436767578, - 0.19914865493774414, - 0.19533538818359375, - 0.20333576202392578, - 0.16032838821411133, - 0.13197898864746094, - 0.1510462760925293, - 0.13703680038452148, - 0.2270984649658203, - 0.16144943237304688, - 0.1472611427307129, - 0.10393238067626953, - 0.11940813064575195, - 0.1307811737060547, - 0.11203241348266602, - 0.15186500549316406, - 0.15992307662963867, - 0.18991422653198242, - 0.21689796447753906, - 0.1777033805847168, - 0.18547868728637695, - 0.035192012786865234, - ], - }, - }, ]