diff --git a/.github/workflows/cron.yml b/.github/workflows/cron.yml index 273eec3763..37fa05743e 100644 --- a/.github/workflows/cron.yml +++ b/.github/workflows/cron.yml @@ -3,6 +3,8 @@ name: crons on: schedule: - cron: "0 2 * * *" # at 02:00 UTC + # Allows you to run this workflow manually from the Actions tab + workflow_dispatch: jobs: cron-gpu: diff --git a/.github/workflows/integration.yml b/.github/workflows/integration.yml index 227e0b3b71..e94930591e 100644 --- a/.github/workflows/integration.yml +++ b/.github/workflows/integration.yml @@ -45,7 +45,8 @@ jobs: python -c $'import torch\na,b=torch.zeros(1,device="cuda:0"),torch.zeros(1,device="cuda:1");\nwhile True:print(a,b)' > /dev/null & python -c "import torch; print(torch.__version__); print('{} of GPUs available'.format(torch.cuda.device_count()))" python -c 'import torch; print(torch.rand(5,3, device=torch.device("cuda:0")))' - BUILD_MONAI=1 ./runtests.sh --unittests --net + BUILD_MONAI=1 ./runtests.sh --net + BUILD_MONAI=1 ./runtests.sh --unittests if pgrep python; then pkill python; fi shell: bash - name: Add reaction diff --git a/monai/data/inverse_batch_transform.py b/monai/data/inverse_batch_transform.py index edfaee3758..3035a1910d 100644 --- a/monai/data/inverse_batch_transform.py +++ b/monai/data/inverse_batch_transform.py @@ -50,21 +50,32 @@ def _transform(self, index: int) -> Dict[Hashable, np.ndarray]: class BatchInverseTransform(Transform): - """Perform inverse on a batch of data. This is useful if you have inferred a batch of images and want to invert them all.""" + """ + Perform inverse on a batch of data. This is useful if you have inferred a batch of images and want to invert + them all. + """ def __init__( - self, transform: InvertibleTransform, loader: TorchDataLoader, collate_fn: Optional[Callable] = no_collation + self, + transform: InvertibleTransform, + loader: TorchDataLoader, + collate_fn: Optional[Callable] = no_collation, + num_workers: Optional[int] = 0, ) -> None: """ Args: transform: a callable data transform on input data. - loader: data loader used to generate the batch of data. - collate_fn: how to collate data after inverse transformations. Default won't do any collation, so the output will be a - list of size batch size. + loader: data loader used to run `transforms` and generate the batch of data. + collate_fn: how to collate data after inverse transformations. + default won't do any collation, so the output will be a list of size batch size. + num_workers: number of workers when run data loader for inverse transforms, + default to 0 as only run 1 iteration and multi-processing may be even slower. + if the transforms are really slow, set num_workers for multi-processing. + if set to `None`, use the `num_workers` of the transform data loader. """ self.transform = transform self.batch_size = loader.batch_size - self.num_workers = loader.num_workers + self.num_workers = loader.num_workers if num_workers is None else num_workers self.collate_fn = collate_fn self.pad_collation_used = loader.collate_fn == pad_list_data_collate diff --git a/monai/engines/evaluator.py b/monai/engines/evaluator.py index bfa69c0bdd..e1fecb745d 100644 --- a/monai/engines/evaluator.py +++ b/monai/engines/evaluator.py @@ -223,17 +223,18 @@ def _iteration(self, engine: Engine, batchdata: Dict[str, torch.Tensor]) -> Dict inputs, targets, args, kwargs = batch # put iteration outputs into engine.state - engine.state.output = output = {Keys.IMAGE: inputs, Keys.LABEL: targets} + engine.state.output = {Keys.IMAGE: inputs, Keys.LABEL: targets} # execute forward computation with self.mode(self.network): if self.amp: with torch.cuda.amp.autocast(): - output[Keys.PRED] = self.inferer(inputs, self.network, *args, **kwargs) + engine.state.output[Keys.PRED] = self.inferer(inputs, self.network, *args, **kwargs) else: - output[Keys.PRED] = self.inferer(inputs, self.network, *args, **kwargs) + engine.state.output[Keys.PRED] = self.inferer(inputs, self.network, *args, **kwargs) engine.fire_event(IterationEvents.FORWARD_COMPLETED) + engine.fire_event(IterationEvents.MODEL_COMPLETED) - return output + return engine.state.output class EnsembleEvaluator(Evaluator): @@ -344,14 +345,17 @@ def _iteration(self, engine: Engine, batchdata: Dict[str, torch.Tensor]) -> Dict inputs, targets, args, kwargs = batch # put iteration outputs into engine.state - engine.state.output = output = {Keys.IMAGE: inputs, Keys.LABEL: targets} + engine.state.output = {Keys.IMAGE: inputs, Keys.LABEL: targets} for idx, network in enumerate(self.networks): with self.mode(network): if self.amp: with torch.cuda.amp.autocast(): - output.update({self.pred_keys[idx]: self.inferer(inputs, network, *args, **kwargs)}) + engine.state.output.update( + {self.pred_keys[idx]: self.inferer(inputs, network, *args, **kwargs)} + ) else: - output.update({self.pred_keys[idx]: self.inferer(inputs, network, *args, **kwargs)}) + engine.state.output.update({self.pred_keys[idx]: self.inferer(inputs, network, *args, **kwargs)}) engine.fire_event(IterationEvents.FORWARD_COMPLETED) + engine.fire_event(IterationEvents.MODEL_COMPLETED) - return output + return engine.state.output diff --git a/monai/engines/trainer.py b/monai/engines/trainer.py index f14ee7e91f..e9e31a1b16 100644 --- a/monai/engines/trainer.py +++ b/monai/engines/trainer.py @@ -157,12 +157,12 @@ def _iteration(self, engine: Engine, batchdata: Dict[str, torch.Tensor]): else: inputs, targets, args, kwargs = batch # put iteration outputs into engine.state - engine.state.output = output = {Keys.IMAGE: inputs, Keys.LABEL: targets} + engine.state.output = {Keys.IMAGE: inputs, Keys.LABEL: targets} def _compute_pred_loss(): - output[Keys.PRED] = self.inferer(inputs, self.network, *args, **kwargs) + engine.state.output[Keys.PRED] = self.inferer(inputs, self.network, *args, **kwargs) engine.fire_event(IterationEvents.FORWARD_COMPLETED) - output[Keys.LOSS] = self.loss_function(output[Keys.PRED], targets).mean() + engine.state.output[Keys.LOSS] = self.loss_function(engine.state.output[Keys.PRED], targets).mean() engine.fire_event(IterationEvents.LOSS_COMPLETED) self.network.train() @@ -170,18 +170,18 @@ def _compute_pred_loss(): if self.amp and self.scaler is not None: with torch.cuda.amp.autocast(): _compute_pred_loss() - self.scaler.scale(output[Keys.LOSS]).backward() + self.scaler.scale(engine.state.output[Keys.LOSS]).backward() engine.fire_event(IterationEvents.BACKWARD_COMPLETED) self.scaler.step(self.optimizer) self.scaler.update() else: _compute_pred_loss() - output[Keys.LOSS].backward() + engine.state.output[Keys.LOSS].backward() engine.fire_event(IterationEvents.BACKWARD_COMPLETED) self.optimizer.step() - engine.fire_event(IterationEvents.OPTIMIZER_COMPLETED) + engine.fire_event(IterationEvents.MODEL_COMPLETED) - return output + return engine.state.output class GanTrainer(Trainer): diff --git a/monai/engines/utils.py b/monai/engines/utils.py index 04237d0f4a..d16ab3cfbb 100644 --- a/monai/engines/utils.py +++ b/monai/engines/utils.py @@ -38,13 +38,14 @@ class IterationEvents(EventEnum): `FORWARD_COMPLETED` is the Event when `network(image, label)` completed. `LOSS_COMPLETED` is the Event when `loss(pred, label)` completed. `BACKWARD_COMPLETED` is the Event when `loss.backward()` completed. + `MODEL_COMPLETED` is the Event when all the model related operations completed. """ FORWARD_COMPLETED = "forward_completed" LOSS_COMPLETED = "loss_completed" BACKWARD_COMPLETED = "backward_completed" - OPTIMIZER_COMPLETED = "optimizer_completed" + MODEL_COMPLETED = "model_completed" class GanKeys: diff --git a/monai/engines/workflow.py b/monai/engines/workflow.py index 50a9f41368..4018dabc40 100644 --- a/monai/engines/workflow.py +++ b/monai/engines/workflow.py @@ -16,7 +16,7 @@ from torch.utils.data import DataLoader from torch.utils.data.distributed import DistributedSampler -from monai.engines.utils import default_prepare_batch +from monai.engines.utils import IterationEvents, default_prepare_batch from monai.transforms import apply_transform from monai.utils import ensure_tuple, exact_version, optional_import @@ -160,7 +160,7 @@ def _register_post_transforms(self, posttrans: Callable): """ - @self.on(Events.ITERATION_COMPLETED) + @self.on(IterationEvents.MODEL_COMPLETED) def run_post_transform(engine: Engine) -> None: engine.state.output = apply_transform(posttrans, engine.state.output) diff --git a/monai/handlers/earlystop_handler.py b/monai/handlers/earlystop_handler.py index 99e072b81f..0d140a9994 100644 --- a/monai/handlers/earlystop_handler.py +++ b/monai/handlers/earlystop_handler.py @@ -38,10 +38,10 @@ class EarlyStopHandler: cumulative_delta: if True, `min_delta` defines an increase since the last `patience` reset, otherwise, it defines an increase after the last event, default to False. epoch_level: check early stopping for every epoch or every iteration of the attached engine, - `True` is epoch level, `False` is iteration level, defaut to epoch level. + `True` is epoch level, `False` is iteration level, default to epoch level. Note: - If in distributed training and uses loss value of every iteration to detect earlystopping, + If in distributed training and uses loss value of every iteration to detect early stopping, the values may be different in different ranks. User may attach this handler to validator engine to detect validation metrics and stop the training, in this case, the `score_function` is executed on validator engine and `trainer` is the trainer engine. diff --git a/monai/handlers/iteration_metric.py b/monai/handlers/iteration_metric.py index f49c799a21..42f9828afd 100644 --- a/monai/handlers/iteration_metric.py +++ b/monai/handlers/iteration_metric.py @@ -73,10 +73,18 @@ def update(self, output: Sequence[torch.Tensor]) -> None: """ if len(output) != 2: raise ValueError(f"output must have length 2, got {len(output)}.") + y_pred, y = output - score = self.metric_fn(y_pred, y) - if isinstance(score, (tuple, list)): - score = score[0] + + def _compute(y_pred, y): + score = self.metric_fn(y_pred, y) + return score[0] if isinstance(score, (tuple, list)) else score + + if isinstance(y_pred, (list, tuple)) and isinstance(y, (list, tuple)): + # if a list of channel-first data, add batch dim and compute metric, then concat the scores + score = torch.cat([_compute(p_.unsqueeze(0), y_.unsqueeze(0)) for p_, y_ in zip(y_pred, y)], dim=0) + else: + score = _compute(y_pred, y) self._scores.append(score.to(self._device)) def compute(self) -> Any: diff --git a/monai/handlers/transform_inverter.py b/monai/handlers/transform_inverter.py index 68201e44be..64f5c37d78 100644 --- a/monai/handlers/transform_inverter.py +++ b/monai/handlers/transform_inverter.py @@ -10,15 +10,15 @@ # limitations under the License. import warnings -from typing import TYPE_CHECKING, Callable, Optional +from typing import TYPE_CHECKING, Callable, Optional, Sequence, Union from torch.utils.data import DataLoader as TorchDataLoader from monai.data import BatchInverseTransform from monai.data.utils import no_collation -from monai.engines.utils import CommonKeys -from monai.transforms import InvertibleTransform, allow_missing_keys_mode, convert_inverse_interp_mode -from monai.utils import InverseKeys, exact_version, optional_import +from monai.engines.utils import CommonKeys, IterationEvents +from monai.transforms import InvertibleTransform, ToTensor, allow_missing_keys_mode, convert_inverse_interp_mode +from monai.utils import InverseKeys, ensure_tuple, ensure_tuple_rep, exact_version, optional_import Events, _ = optional_import("ignite.engine", "0.4.4", exact_version, "Events") if TYPE_CHECKING: @@ -29,68 +29,82 @@ class TransformInverter: """ - Ignite handler to automatically invert all the pre-transforms that support `inverse`. - It takes `engine.state.output` as the input data and uses the transforms infomation from `engine.state.batch`. - + Ignite handler to automatically invert `transforms`. + It takes `engine.state.output` as the input data and uses the transforms information from `engine.state.batch`. + The outputs are stored in `engine.state.output` with the `output_keys`. """ def __init__( self, transform: InvertibleTransform, loader: TorchDataLoader, + output_keys: Union[str, Sequence[str]] = CommonKeys.PRED, + batch_keys: Union[str, Sequence[str]] = CommonKeys.IMAGE, collate_fn: Optional[Callable] = no_collation, - batch_key: str = CommonKeys.IMAGE, - output_key: str = CommonKeys.PRED, - postfix: str = "inverted", - nearest_interp: bool = True, + postfix: str = "_inverted", + nearest_interp: Union[bool, Sequence[bool]] = True, + num_workers: Optional[int] = 0, ) -> None: """ Args: transform: a callable data transform on input data. - loader: data loader used to generate the batch of data. + loader: data loader used to run transforms and generate the batch of data. collate_fn: how to collate data after inverse transformations. default won't do any collation, so the output will be a list of size batch size. - batch_key: the key of input data in `ignite.engine.batch`. will get the applied transforms - for this input data, then invert them for the model output, default to "image". - output_key: the key of model output in `ignite.engine.output`, invert transforms on it. - postfix: will save the inverted result into `ignite.engine.output` with key `{ouput_key}_{postfix}`. - nearest_interp: whether to use `nearest` interpolation mode when inverting spatial transforms, - default to `True`. if `False`, use the same interpolation mode as the original transform. + output_keys: the key of expected data in `ignite.engine.output`, invert transforms on it. + it also can be a list of keys, will invert transform for each of them. Default to "pred". + batch_keys: the key of input data in `ignite.engine.batch`. will get the applied transforms + for this input data, then invert them for the expected data with `output_keys`. + It can also be a list of keys, each matches to the `output_keys` data. default to "image". + postfix: will save the inverted result into `ignite.engine.output` with key `{output_key}{postfix}`. + nearest_interp: whether to use `nearest` interpolation mode when inverting the spatial transforms, + default to `True`. If `False`, use the same interpolation mode as the original transform. + it also can be a list of bool, each matches to the `output_keys` data. + num_workers: number of workers when run data loader for inverse transforms, + default to 0 as only run one iteration and multi-processing may be even slower. + Set to `None`, to use the `num_workers` of the input transform data loader. """ self.transform = transform - self.inverter = BatchInverseTransform(transform=transform, loader=loader, collate_fn=collate_fn) - self.batch_key = batch_key - self.output_key = output_key + self.inverter = BatchInverseTransform( + transform=transform, + loader=loader, + collate_fn=collate_fn, + num_workers=num_workers, + ) + self.output_keys = ensure_tuple(output_keys) + self.batch_keys = ensure_tuple_rep(batch_keys, len(self.output_keys)) self.postfix = postfix - self.nearest_interp = nearest_interp + self.nearest_interp = ensure_tuple_rep(nearest_interp, len(self.output_keys)) + self._totensor = ToTensor() def attach(self, engine: Engine) -> None: """ Args: engine: Ignite Engine, it can be a trainer, validator or evaluator. """ - engine.add_event_handler(Events.ITERATION_COMPLETED, self) + engine.add_event_handler(IterationEvents.MODEL_COMPLETED, self) def __call__(self, engine: Engine) -> None: """ Args: engine: Ignite Engine, it can be a trainer, validator or evaluator. """ - transform_key = self.batch_key + InverseKeys.KEY_SUFFIX - if transform_key not in engine.state.batch: - warnings.warn("all the pre-transforms are not InvertibleTransform or no need to invert.") - return + for output_key, batch_key, nearest_interp in zip(self.output_keys, self.batch_keys, self.nearest_interp): + transform_key = batch_key + InverseKeys.KEY_SUFFIX + if transform_key not in engine.state.batch: + warnings.warn(f"all the transforms on `{batch_key}` are not InvertibleTransform.") + continue - transform_info = engine.state.batch[transform_key] - if self.nearest_interp: - convert_inverse_interp_mode(trans_info=transform_info, mode="nearest", align_corners=None) + transform_info = engine.state.batch[transform_key] + if nearest_interp: + convert_inverse_interp_mode(trans_info=transform_info, mode="nearest", align_corners=None) - segs_dict = { - self.batch_key: engine.state.output[self.output_key].detach().cpu(), - transform_key: transform_info, - } + segs_dict = { + batch_key: engine.state.output[output_key].detach().cpu(), + transform_key: transform_info, + } - with allow_missing_keys_mode(self.transform): # type: ignore - inverted_key = f"{self.output_key}_{self.postfix}" - engine.state.output[inverted_key] = [i[self.batch_key] for i in self.inverter(segs_dict)] + with allow_missing_keys_mode(self.transform): # type: ignore + inverted_key = f"{output_key}{self.postfix}" + engine.state.output[inverted_key] = [self._totensor(i[batch_key]) for i in self.inverter(segs_dict)] diff --git a/monai/networks/blocks/localnet_block.py b/monai/networks/blocks/localnet_block.py index cc90e6ed1d..3997d42436 100644 --- a/monai/networks/blocks/localnet_block.py +++ b/monai/networks/blocks/localnet_block.py @@ -260,7 +260,7 @@ def forward(self, x, mid) -> torch.Tensor: Args: x: feature to be up-sampled, in shape (batch, ``in_channels``, insize_1, insize_2, [insize_3]) mid: mid-level feature saved during down-sampling, - in shape (batch, ``out_channels``, midsize_1, midsize_2, [midnsize_3]) + in shape (batch, ``out_channels``, midsize_1, midsize_2, [midsize_3]) Raises: ValueError: when ``midsize != insize * 2`` diff --git a/monai/networks/nets/dynunet.py b/monai/networks/nets/dynunet.py index 7d0b3bff79..a69814f61c 100644 --- a/monai/networks/nets/dynunet.py +++ b/monai/networks/nets/dynunet.py @@ -91,7 +91,7 @@ class DynUNet(nn.Module): (1, 2, 8, 6). The last two will be interpolated into (1, 2, 32, 24), and the stacked tensor will has the shape (1, 3, 2, 8, 6). When calculating the loss, you can use torch.unbind to get all feature maps can compute the loss - one by one with the groud truth, then do a weighted average for all losses to achieve the final loss. + one by one with the ground truth, then do a weighted average for all losses to achieve the final loss. (To be added: a corresponding tutorial link) deep_supr_num: number of feature maps that will output during deep supervision head. The diff --git a/monai/networks/nets/efficientnet.py b/monai/networks/nets/efficientnet.py index d8754e3f78..3fe9ff35d8 100644 --- a/monai/networks/nets/efficientnet.py +++ b/monai/networks/nets/efficientnet.py @@ -503,19 +503,19 @@ def __init__( ) # get network parameters - weight_coeff, depth_coeff, image_size, drpout_rate, drpconnect_rate = efficientnet_params[model_name] + weight_coeff, depth_coeff, image_size, dropout_rate, dropconnect_rate = efficientnet_params[model_name] # create model and initialize random weights - model = super(EfficientNetBN, self).__init__( + super(EfficientNetBN, self).__init__( blocks_args_str=blocks_args_str, spatial_dims=spatial_dims, in_channels=in_channels, num_classes=num_classes, width_coefficient=weight_coeff, depth_coefficient=depth_coeff, - dropout_rate=drpout_rate, + dropout_rate=dropout_rate, image_size=image_size, - drop_connect_rate=drpconnect_rate, + drop_connect_rate=dropconnect_rate, ) # attempt to load pretrained @@ -827,7 +827,7 @@ def _decode_block_string(block_string: str): or (len(options["s"]) == 3 and options["s"][0] == options["s"][1] and options["s"][0] == options["s"][2]) ) if not stride_check: - raise ValueError("invalid stride option recieved") + raise ValueError("invalid stride option received") return BlockArgs( num_repeat=int(options["r"]), diff --git a/monai/transforms/utility/array.py b/monai/transforms/utility/array.py index 6903b2628d..8f060eed13 100644 --- a/monai/transforms/utility/array.py +++ b/monai/transforms/utility/array.py @@ -160,7 +160,7 @@ def __call__(self, img: np.ndarray, meta_dict: Optional[Dict] = None): Apply the transform to `img`. """ if not isinstance(meta_dict, dict): - raise ValueError("meta_dict must be a dictionay data.") + raise ValueError("meta_dict must be a dictionary data.") channel_dim = meta_dict.get("original_channel_dim", None) diff --git a/monai/transforms/utils.py b/monai/transforms/utils.py index b73a899153..c08b786e98 100644 --- a/monai/transforms/utils.py +++ b/monai/transforms/utils.py @@ -706,7 +706,7 @@ def map_spatial_axes( The default `None` will convert to all the spatial axes of the image. If axis is negative it counts from the last to the first axis. If axis is a tuple of ints. - channel_first: the image data is channel first or channel last, defaut to channel first. + channel_first: the image data is channel first or channel last, default to channel first. """ if spatial_axes is None: @@ -772,7 +772,9 @@ def allow_missing_keys_mode(transform: Union[MapTransform, Compose, Tuple[MapTra def convert_inverse_interp_mode(trans_info: List, mode: str = "nearest", align_corners: Optional[bool] = None): """ Change the interpolation mode when inverting spatial transforms, default to "nearest". - It can support both single data or batch data. + This function modifies trans_info's `InverseKeys.EXTRA_INFO`. + + See also: :py:class:`monai.transform.inverse.InvertibleTransform` Args: trans_info: transforms inverse information list, contains context of every invertible transform. diff --git a/monai/visualize/occlusion_sensitivity.py b/monai/visualize/occlusion_sensitivity.py index ee9a967da1..46b9115c7a 100644 --- a/monai/visualize/occlusion_sensitivity.py +++ b/monai/visualize/occlusion_sensitivity.py @@ -152,7 +152,7 @@ def __init__( upsampler: Optional[Callable] = default_upsampler, verbose: bool = True, ) -> None: - """Occlusion sensitivitiy constructor. + """Occlusion sensitivity constructor. Args: nn_module: Classification model to use for inference diff --git a/tests/test_handler_mean_dice.py b/tests/test_handler_mean_dice.py index d15b549d86..648ffe91ae 100644 --- a/tests/test_handler_mean_dice.py +++ b/tests/test_handler_mean_dice.py @@ -39,8 +39,8 @@ def _val_func(engine, batch): y = torch.Tensor([[[0], [1]], [[0], [1]]]) dice_metric.update([y_pred, y]) - y_pred = torch.Tensor([[[0], [1]], [[1], [0]]]) - y = torch.Tensor([[[0], [1]], [[1], [0]]]) + y_pred = [torch.Tensor([[0], [1]]), torch.Tensor([[1], [0]])] + y = [torch.Tensor([[0], [1]]), torch.Tensor([[1], [0]])] dice_metric.update([y_pred, y]) avg_dice = dice_metric.compute() diff --git a/tests/test_handler_transform_inverter.py b/tests/test_handler_transform_inverter.py index 87414319cf..087839a75e 100644 --- a/tests/test_handler_transform_inverter.py +++ b/tests/test_handler_transform_inverter.py @@ -17,6 +17,7 @@ from ignite.engine import Engine from monai.data import CacheDataset, DataLoader, create_test_image_3d +from monai.engines.utils import IterationEvents from monai.handlers import TransformInverter from monai.transforms import ( AddChanneld, @@ -47,13 +48,13 @@ def test_invert(self): [ LoadImaged(KEYS), AddChanneld(KEYS), - ScaleIntensityd(KEYS, minv=1, maxv=10), + ScaleIntensityd("image", minv=1, maxv=10), RandFlipd(KEYS, prob=0.5, spatial_axis=[1, 2]), RandAxisFlipd(KEYS, prob=0.5), RandRotate90d(KEYS, spatial_axes=(1, 2)), RandZoomd(KEYS, prob=0.5, min_zoom=0.5, max_zoom=1.1, keep_size=True), RandRotated(KEYS, prob=0.5, range_x=np.pi, mode="bilinear", align_corners=True), - RandAffined(KEYS, prob=0.5, rotate_range=np.pi), + RandAffined(KEYS, prob=0.5, rotate_range=np.pi, mode="nearest"), ResizeWithPadOrCropd(KEYS, 100), ToTensord(KEYS), CastToTyped(KEYS, dtype=torch.uint8), @@ -70,19 +71,38 @@ def test_invert(self): # set up engine def _train_func(engine, batch): self.assertTupleEqual(batch["image"].shape[1:], (1, 100, 100, 100)) - return batch + engine.state.output = batch + engine.fire_event(IterationEvents.MODEL_COMPLETED) + return engine.state.output engine = Engine(_train_func) + engine.register_events(*IterationEvents) # set up testing handler - TransformInverter(transform=transform, loader=loader, output_key="image", nearest_interp=True).attach(engine) + TransformInverter( + transform=transform, + loader=loader, + output_keys=["image", "label"], + batch_keys="label", + nearest_interp=True, + num_workers=0 if sys.platform == "darwin" or torch.cuda.is_available() else 2, + ).attach(engine) engine.run(loader, max_epochs=1) set_determinism(seed=None) self.assertTupleEqual(engine.state.output["image"].shape, (2, 1, 100, 100, 100)) - for i in engine.state.output["image_inverted"]: - np.testing.assert_allclose(i.astype(np.uint8).astype(np.float32), i, rtol=1e-4) + self.assertTupleEqual(engine.state.output["label"].shape, (2, 1, 100, 100, 100)) + for i in engine.state.output["image_inverted"] + engine.state.output["label_inverted"]: + torch.testing.assert_allclose(i.to(torch.uint8).to(torch.float), i.to(torch.float)) self.assertTupleEqual(i.shape, (1, 100, 101, 107)) + # check labels match + reverted = engine.state.output["label_inverted"][-1].detach().cpu().numpy()[0].astype(np.int32) + original = LoadImaged(KEYS)(data[-1])["label"] + n_good = np.sum(np.isclose(reverted, original, atol=1e-3)) + reverted_name = engine.state.output["label_meta_dict"]["filename_or_obj"][-1] + original_name = data[-1]["label"] + self.assertEqual(reverted_name, original_name) + self.assertTrue((reverted.size - n_good) in (0, 23641), "diff. in two possible values") if __name__ == "__main__": diff --git a/tests/test_integration_workflows.py b/tests/test_integration_workflows.py index db7580bf86..00d097b2b6 100644 --- a/tests/test_integration_workflows.py +++ b/tests/test_integration_workflows.py @@ -160,7 +160,7 @@ def attach(self, engine): engine.add_event_handler(IterationEvents.FORWARD_COMPLETED, self._forward_completed) engine.add_event_handler(IterationEvents.LOSS_COMPLETED, self._loss_completed) engine.add_event_handler(IterationEvents.BACKWARD_COMPLETED, self._backward_completed) - engine.add_event_handler(IterationEvents.OPTIMIZER_COMPLETED, self._optimizer_completed) + engine.add_event_handler(IterationEvents.MODEL_COMPLETED, self._model_completed) def _forward_completed(self, engine): pass @@ -171,7 +171,7 @@ def _loss_completed(self, engine): def _backward_completed(self, engine): pass - def _optimizer_completed(self, engine): + def _model_completed(self, engine): pass train_handlers = [