2 changes: 2 additions & 0 deletions .github/workflows/cron.yml
@@ -3,6 +3,8 @@ name: crons
on:
schedule:
- cron: "0 2 * * *" # at 02:00 UTC
# Allows you to run this workflow manually from the Actions tab
workflow_dispatch:

jobs:
cron-gpu:
3 changes: 2 additions & 1 deletion .github/workflows/integration.yml
@@ -45,7 +45,8 @@ jobs:
python -c $'import torch\na,b=torch.zeros(1,device="cuda:0"),torch.zeros(1,device="cuda:1");\nwhile True:print(a,b)' > /dev/null &
python -c "import torch; print(torch.__version__); print('{} of GPUs available'.format(torch.cuda.device_count()))"
python -c 'import torch; print(torch.rand(5,3, device=torch.device("cuda:0")))'
BUILD_MONAI=1 ./runtests.sh --unittests --net
BUILD_MONAI=1 ./runtests.sh --net
BUILD_MONAI=1 ./runtests.sh --unittests
if pgrep python; then pkill python; fi
shell: bash
- name: Add reaction
23 changes: 17 additions & 6 deletions monai/data/inverse_batch_transform.py
@@ -50,21 +50,32 @@ def _transform(self, index: int) -> Dict[Hashable, np.ndarray]:


class BatchInverseTransform(Transform):
"""Perform inverse on a batch of data. This is useful if you have inferred a batch of images and want to invert them all."""
"""
Perform inverse on a batch of data. This is useful if you have inferred a batch of images and want to invert
them all.
"""

def __init__(
self, transform: InvertibleTransform, loader: TorchDataLoader, collate_fn: Optional[Callable] = no_collation
self,
transform: InvertibleTransform,
loader: TorchDataLoader,
collate_fn: Optional[Callable] = no_collation,
num_workers: Optional[int] = 0,
) -> None:
"""
Args:
transform: a callable data transform on input data.
loader: data loader used to generate the batch of data.
collate_fn: how to collate data after inverse transformations. Default won't do any collation, so the output will be a
list of size batch size.
loader: data loader used to run `transforms` and generate the batch of data.
collate_fn: how to collate data after inverse transformations.
the default won't do any collation, so the output will be a list whose length equals the batch size.
num_workers: number of workers to use when running the data loader for the inverse transforms,
default to 0 because only one iteration is run and multi-processing may even be slower.
if the transforms are really slow, set num_workers to enable multi-processing.
if set to `None`, use the `num_workers` of the transform data loader.
"""
self.transform = transform
self.batch_size = loader.batch_size
self.num_workers = loader.num_workers
self.num_workers = loader.num_workers if num_workers is None else num_workers
self.collate_fn = collate_fn
self.pad_collation_used = loader.collate_fn == pad_list_data_collate
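
The sketch below shows how `BatchInverseTransform` might be used after batched inference, assuming a simple dictionary pipeline; `images` and `model` are placeholders and the transform chain is illustrative only, not part of this change.

```python
# Hedged sketch: invert a batch of model outputs with BatchInverseTransform.
# `images` and `model` are placeholders; the transform chain is an assumption for illustration.
from monai.data import BatchInverseTransform, DataLoader, Dataset
from monai.transforms import AddChanneld, Compose, Spacingd

transforms = Compose([AddChanneld(keys="image"), Spacingd(keys="image", pixdim=(2.0, 2.0, 2.0))])
dataset = Dataset(data=[{"image": img} for img in images], transform=transforms)
loader = DataLoader(dataset, batch_size=4, num_workers=2)

# num_workers=0 (default) inverts in the main process; num_workers=None reuses loader.num_workers
inverter = BatchInverseTransform(transform=transforms, loader=loader, num_workers=0)

for batch in loader:
    preds = model(batch["image"]).detach().cpu()  # predictions, still in the transformed space
    batch["image"] = preds                        # reuse the recorded "image" transform trace for the predictions
    inverted = inverter(batch)                    # default collate_fn: a list of per-sample inverted dicts
```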

20 changes: 12 additions & 8 deletions monai/engines/evaluator.py
@@ -223,17 +223,18 @@ def _iteration(self, engine: Engine, batchdata: Dict[str, torch.Tensor]) -> Dict
inputs, targets, args, kwargs = batch

# put iteration outputs into engine.state
engine.state.output = output = {Keys.IMAGE: inputs, Keys.LABEL: targets}
engine.state.output = {Keys.IMAGE: inputs, Keys.LABEL: targets}
# execute forward computation
with self.mode(self.network):
if self.amp:
with torch.cuda.amp.autocast():
output[Keys.PRED] = self.inferer(inputs, self.network, *args, **kwargs)
engine.state.output[Keys.PRED] = self.inferer(inputs, self.network, *args, **kwargs)
else:
output[Keys.PRED] = self.inferer(inputs, self.network, *args, **kwargs)
engine.state.output[Keys.PRED] = self.inferer(inputs, self.network, *args, **kwargs)
engine.fire_event(IterationEvents.FORWARD_COMPLETED)
engine.fire_event(IterationEvents.MODEL_COMPLETED)

return output
return engine.state.output


class EnsembleEvaluator(Evaluator):
@@ -344,14 +345,17 @@ def _iteration(self, engine: Engine, batchdata: Dict[str, torch.Tensor]) -> Dict
inputs, targets, args, kwargs = batch

# put iteration outputs into engine.state
engine.state.output = output = {Keys.IMAGE: inputs, Keys.LABEL: targets}
engine.state.output = {Keys.IMAGE: inputs, Keys.LABEL: targets}
for idx, network in enumerate(self.networks):
with self.mode(network):
if self.amp:
with torch.cuda.amp.autocast():
output.update({self.pred_keys[idx]: self.inferer(inputs, network, *args, **kwargs)})
engine.state.output.update(
{self.pred_keys[idx]: self.inferer(inputs, network, *args, **kwargs)}
)
else:
output.update({self.pred_keys[idx]: self.inferer(inputs, network, *args, **kwargs)})
engine.state.output.update({self.pred_keys[idx]: self.inferer(inputs, network, *args, **kwargs)})
engine.fire_event(IterationEvents.FORWARD_COMPLETED)
engine.fire_event(IterationEvents.MODEL_COMPLETED)

return output
return engine.state.output
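
Because the iteration results are now written straight into `engine.state.output`, handlers fired at the custom iteration events can read the latest values from there. A minimal hedged sketch, assuming `evaluator` is a `SupervisedEvaluator` (which registers `IterationEvents` internally):

```python
# Hedged sketch: inspect the prediction placed in engine.state.output by _iteration.
# `evaluator` is assumed to be a monai.engines.SupervisedEvaluator instance.
from monai.engines.utils import CommonKeys as Keys, IterationEvents

def log_pred_shape(engine):
    # Keys.PRED is populated before FORWARD_COMPLETED is fired
    pred = engine.state.output[Keys.PRED]
    print(f"iteration {engine.state.iteration}: prediction shape {tuple(pred.shape)}")

evaluator.add_event_handler(IterationEvents.FORWARD_COMPLETED, log_pred_shape)
```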
14 changes: 7 additions & 7 deletions monai/engines/trainer.py
@@ -157,31 +157,31 @@ def _iteration(self, engine: Engine, batchdata: Dict[str, torch.Tensor]):
else:
inputs, targets, args, kwargs = batch
# put iteration outputs into engine.state
engine.state.output = output = {Keys.IMAGE: inputs, Keys.LABEL: targets}
engine.state.output = {Keys.IMAGE: inputs, Keys.LABEL: targets}

def _compute_pred_loss():
output[Keys.PRED] = self.inferer(inputs, self.network, *args, **kwargs)
engine.state.output[Keys.PRED] = self.inferer(inputs, self.network, *args, **kwargs)
engine.fire_event(IterationEvents.FORWARD_COMPLETED)
output[Keys.LOSS] = self.loss_function(output[Keys.PRED], targets).mean()
engine.state.output[Keys.LOSS] = self.loss_function(engine.state.output[Keys.PRED], targets).mean()
engine.fire_event(IterationEvents.LOSS_COMPLETED)

self.network.train()
self.optimizer.zero_grad()
if self.amp and self.scaler is not None:
with torch.cuda.amp.autocast():
_compute_pred_loss()
self.scaler.scale(output[Keys.LOSS]).backward()
self.scaler.scale(engine.state.output[Keys.LOSS]).backward()
engine.fire_event(IterationEvents.BACKWARD_COMPLETED)
self.scaler.step(self.optimizer)
self.scaler.update()
else:
_compute_pred_loss()
output[Keys.LOSS].backward()
engine.state.output[Keys.LOSS].backward()
engine.fire_event(IterationEvents.BACKWARD_COMPLETED)
self.optimizer.step()
engine.fire_event(IterationEvents.OPTIMIZER_COMPLETED)
engine.fire_event(IterationEvents.MODEL_COMPLETED)

return output
return engine.state.output


class GanTrainer(Trainer):
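The same pattern applies on the training side: the loss lives under `Keys.LOSS` in `engine.state.output`, so a handler attached at `LOSS_COMPLETED` can read it directly. A hedged sketch, assuming `trainer` is a `SupervisedTrainer`:

```python
# Hedged sketch: log the per-iteration loss from engine.state.output.
# `trainer` is assumed to be a monai.engines.SupervisedTrainer instance.
from monai.engines.utils import CommonKeys as Keys, IterationEvents

@trainer.on(IterationEvents.LOSS_COMPLETED)
def print_loss(engine):
    print(f"iteration {engine.state.iteration}: loss = {engine.state.output[Keys.LOSS].item():.4f}")
```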
3 changes: 2 additions & 1 deletion monai/engines/utils.py
@@ -38,13 +38,14 @@ class IterationEvents(EventEnum):
`FORWARD_COMPLETED` is the Event when `network(image, label)` completed.
`LOSS_COMPLETED` is the Event when `loss(pred, label)` completed.
`BACKWARD_COMPLETED` is the Event when `loss.backward()` completed.
`MODEL_COMPLETED` is the Event when all the model-related operations have completed.

"""

FORWARD_COMPLETED = "forward_completed"
LOSS_COMPLETED = "loss_completed"
BACKWARD_COMPLETED = "backward_completed"
OPTIMIZER_COMPLETED = "optimizer_completed"
MODEL_COMPLETED = "model_completed"


class GanKeys:
4 changes: 2 additions & 2 deletions monai/engines/workflow.py
@@ -16,7 +16,7 @@
from torch.utils.data import DataLoader
from torch.utils.data.distributed import DistributedSampler

from monai.engines.utils import default_prepare_batch
from monai.engines.utils import IterationEvents, default_prepare_batch
from monai.transforms import apply_transform
from monai.utils import ensure_tuple, exact_version, optional_import

@@ -160,7 +160,7 @@ def _register_post_transforms(self, posttrans: Callable):

"""

@self.on(Events.ITERATION_COMPLETED)
@self.on(IterationEvents.MODEL_COMPLETED)
def run_post_transform(engine: Engine) -> None:
engine.state.output = apply_transform(posttrans, engine.state.output)
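
With post transforms now triggered at `IterationEvents.MODEL_COMPLETED` rather than `Events.ITERATION_COMPLETED`, user code can hook additional logic into the same point. A hedged sketch, assuming `evaluator` is a MONAI workflow engine (which registers `IterationEvents` internally) and using an illustrative activation transform:

```python
# Hedged sketch: run an extra dictionary transform once all model-related operations are done.
# `evaluator` is assumed to be a MONAI workflow engine such as SupervisedEvaluator.
from monai.engines.utils import IterationEvents
from monai.transforms import Activationsd

extra_post = Activationsd(keys="pred", sigmoid=True)

@evaluator.on(IterationEvents.MODEL_COMPLETED)
def apply_extra_post_transform(engine):
    # engine.state.output is the per-iteration dictionary, so dictionary transforms apply directly
    engine.state.output = extra_post(engine.state.output)
```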

4 changes: 2 additions & 2 deletions monai/handlers/earlystop_handler.py
@@ -38,10 +38,10 @@ class EarlyStopHandler:
cumulative_delta: if True, `min_delta` defines an increase since the last `patience` reset, otherwise,
it defines an increase after the last event, default to False.
epoch_level: check early stopping for every epoch or every iteration of the attached engine,
`True` is epoch level, `False` is iteration level, defaut to epoch level.
`True` is epoch level, `False` is iteration level, default to epoch level.

Note:
If in distributed training and uses loss value of every iteration to detect earlystopping,
If in distributed training and uses loss value of every iteration to detect early stopping,
the values may be different in different ranks.
User may attach this handler to validator engine to detect validation metrics and stop the training,
in this case, the `score_function` is executed on validator engine and `trainer` is the trainer engine.
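
A hedged usage sketch based on the arguments described above; the metric key "val_mean_dice" and the engines `trainer`/`evaluator` are assumptions for illustration:

```python
# Hedged sketch: stop training when the validation metric stops improving.
# `trainer` and `evaluator` are assumed MONAI engines; "val_mean_dice" is an illustrative metric key.
from monai.handlers import EarlyStopHandler

early_stopper = EarlyStopHandler(
    patience=10,                 # epochs without improvement before stopping (epoch_level=True)
    score_function=lambda engine: engine.state.metrics["val_mean_dice"],
    trainer=trainer,             # the engine whose run is terminated
    min_delta=0.0,
    epoch_level=True,
)
early_stopper.attach(evaluator)  # attach to the validator so the score reflects validation metrics
```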
14 changes: 11 additions & 3 deletions monai/handlers/iteration_metric.py
@@ -73,10 +73,18 @@ def update(self, output: Sequence[torch.Tensor]) -> None:
"""
if len(output) != 2:
raise ValueError(f"output must have length 2, got {len(output)}.")

y_pred, y = output
score = self.metric_fn(y_pred, y)
if isinstance(score, (tuple, list)):
score = score[0]

def _compute(y_pred, y):
score = self.metric_fn(y_pred, y)
return score[0] if isinstance(score, (tuple, list)) else score

if isinstance(y_pred, (list, tuple)) and isinstance(y, (list, tuple)):
# if a list of channel-first data, add batch dim and compute metric, then concat the scores
score = torch.cat([_compute(p_.unsqueeze(0), y_.unsqueeze(0)) for p_, y_ in zip(y_pred, y)], dim=0)
else:
score = _compute(y_pred, y)
self._scores.append(score.to(self._device))
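
The new branch above handles decollated input, i.e. a list or tuple of channel-first tensors: each sample gets a batch dimension, the metric runs per sample, and the per-sample scores are concatenated. A hedged, self-contained sketch of the same pattern with a stand-in metric function:

```python
# Hedged sketch of the per-sample metric pattern; metric_fn is a stand-in for any
# batch-first metric that returns one score per batch item.
import torch

def metric_fn(y_pred, y):
    # toy metric: per-item mean absolute error
    return (y_pred - y).abs().flatten(1).mean(dim=1)

# decollated data: channel-first samples, possibly with different spatial sizes
y_pred = [torch.rand(1, 32, 32), torch.rand(1, 48, 48)]
y = [torch.rand(1, 32, 32), torch.rand(1, 48, 48)]

# add a batch dim per sample, compute, then concatenate along the batch dimension
scores = torch.cat([metric_fn(p.unsqueeze(0), g.unsqueeze(0)) for p, g in zip(y_pred, y)], dim=0)
print(scores.shape)  # torch.Size([2])
```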

def compute(self) -> Any:
88 changes: 51 additions & 37 deletions monai/handlers/transform_inverter.py
@@ -10,15 +10,15 @@
# limitations under the License.

import warnings
from typing import TYPE_CHECKING, Callable, Optional
from typing import TYPE_CHECKING, Callable, Optional, Sequence, Union

from torch.utils.data import DataLoader as TorchDataLoader

from monai.data import BatchInverseTransform
from monai.data.utils import no_collation
from monai.engines.utils import CommonKeys
from monai.transforms import InvertibleTransform, allow_missing_keys_mode, convert_inverse_interp_mode
from monai.utils import InverseKeys, exact_version, optional_import
from monai.engines.utils import CommonKeys, IterationEvents
from monai.transforms import InvertibleTransform, ToTensor, allow_missing_keys_mode, convert_inverse_interp_mode
from monai.utils import InverseKeys, ensure_tuple, ensure_tuple_rep, exact_version, optional_import

Events, _ = optional_import("ignite.engine", "0.4.4", exact_version, "Events")
if TYPE_CHECKING:
@@ -29,68 +29,82 @@

class TransformInverter:
"""
Ignite handler to automatically invert all the pre-transforms that support `inverse`.
It takes `engine.state.output` as the input data and uses the transforms infomation from `engine.state.batch`.

Ignite handler to automatically invert `transforms`.
It takes `engine.state.output` as the input data and uses the transforms information from `engine.state.batch`.
The outputs are stored in `engine.state.output` with the `output_keys`.
"""

def __init__(
self,
transform: InvertibleTransform,
loader: TorchDataLoader,
output_keys: Union[str, Sequence[str]] = CommonKeys.PRED,
batch_keys: Union[str, Sequence[str]] = CommonKeys.IMAGE,
collate_fn: Optional[Callable] = no_collation,
batch_key: str = CommonKeys.IMAGE,
output_key: str = CommonKeys.PRED,
postfix: str = "inverted",
nearest_interp: bool = True,
postfix: str = "_inverted",
nearest_interp: Union[bool, Sequence[bool]] = True,
num_workers: Optional[int] = 0,
) -> None:
"""
Args:
transform: a callable data transform on input data.
loader: data loader used to generate the batch of data.
loader: data loader used to run transforms and generate the batch of data.
collate_fn: how to collate data after inverse transformations.
default won't do any collation, so the output will be a list of size batch size.
batch_key: the key of input data in `ignite.engine.batch`. will get the applied transforms
for this input data, then invert them for the model output, default to "image".
output_key: the key of model output in `ignite.engine.output`, invert transforms on it.
postfix: will save the inverted result into `ignite.engine.output` with key `{ouput_key}_{postfix}`.
nearest_interp: whether to use `nearest` interpolation mode when inverting spatial transforms,
default to `True`. if `False`, use the same interpolation mode as the original transform.
output_keys: the key(s) of the expected data in `ignite.engine.output` to invert transforms on.
it can also be a list of keys, in which case transforms are inverted for each of them. default to "pred".
batch_keys: the key(s) of the input data in `ignite.engine.batch`. the transforms applied to this input data
are looked up and then inverted for the corresponding `output_keys` data.
it can also be a list of keys, each matching an entry in `output_keys`. default to "image".
postfix: the inverted result is saved into `ignite.engine.output` with key `{output_key}{postfix}`.
nearest_interp: whether to use `nearest` interpolation mode when inverting the spatial transforms,
default to `True`. if `False`, use the same interpolation mode as the original transform.
it can also be a list of bool, each matching an entry in `output_keys`.
num_workers: number of workers to use when running the data loader for the inverse transforms,
default to 0 because only one iteration is run and multi-processing may even be slower.
if set to `None`, use the `num_workers` of the input transform data loader.

"""
self.transform = transform
self.inverter = BatchInverseTransform(transform=transform, loader=loader, collate_fn=collate_fn)
self.batch_key = batch_key
self.output_key = output_key
self.inverter = BatchInverseTransform(
transform=transform,
loader=loader,
collate_fn=collate_fn,
num_workers=num_workers,
)
self.output_keys = ensure_tuple(output_keys)
self.batch_keys = ensure_tuple_rep(batch_keys, len(self.output_keys))
self.postfix = postfix
self.nearest_interp = nearest_interp
self.nearest_interp = ensure_tuple_rep(nearest_interp, len(self.output_keys))
self._totensor = ToTensor()

def attach(self, engine: Engine) -> None:
"""
Args:
engine: Ignite Engine, it can be a trainer, validator or evaluator.
"""
engine.add_event_handler(Events.ITERATION_COMPLETED, self)
engine.add_event_handler(IterationEvents.MODEL_COMPLETED, self)

def __call__(self, engine: Engine) -> None:
"""
Args:
engine: Ignite Engine, it can be a trainer, validator or evaluator.
"""
transform_key = self.batch_key + InverseKeys.KEY_SUFFIX
if transform_key not in engine.state.batch:
warnings.warn("all the pre-transforms are not InvertibleTransform or no need to invert.")
return
for output_key, batch_key, nearest_interp in zip(self.output_keys, self.batch_keys, self.nearest_interp):
transform_key = batch_key + InverseKeys.KEY_SUFFIX
if transform_key not in engine.state.batch:
warnings.warn(f"all the transforms on `{batch_key}` are not InvertibleTransform.")
continue

transform_info = engine.state.batch[transform_key]
if self.nearest_interp:
convert_inverse_interp_mode(trans_info=transform_info, mode="nearest", align_corners=None)
transform_info = engine.state.batch[transform_key]
if nearest_interp:
convert_inverse_interp_mode(trans_info=transform_info, mode="nearest", align_corners=None)

segs_dict = {
self.batch_key: engine.state.output[self.output_key].detach().cpu(),
transform_key: transform_info,
}
segs_dict = {
batch_key: engine.state.output[output_key].detach().cpu(),
transform_key: transform_info,
}

with allow_missing_keys_mode(self.transform): # type: ignore
inverted_key = f"{self.output_key}_{self.postfix}"
engine.state.output[inverted_key] = [i[self.batch_key] for i in self.inverter(segs_dict)]
with allow_missing_keys_mode(self.transform): # type: ignore
inverted_key = f"{output_key}{self.postfix}"
engine.state.output[inverted_key] = [self._totensor(i[batch_key]) for i in self.inverter(segs_dict)]
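
A hedged sketch of wiring the updated handler into an evaluator; `val_transforms`, `val_loader` and `evaluator` are assumed to be defined elsewhere and the key names are illustrative:

```python
# Hedged sketch: invert the validation pre-transforms on the model predictions.
from monai.handlers import TransformInverter

inverter = TransformInverter(
    transform=val_transforms,  # the invertible pre-transform chain used by val_loader
    loader=val_loader,
    output_keys="pred",        # invert transforms for engine.state.output["pred"]
    batch_keys="image",        # reuse the transform trace recorded for the "image" key
    postfix="_inverted",       # results stored under engine.state.output["pred_inverted"]
    nearest_interp=True,
    num_workers=0,
)
inverter.attach(evaluator)     # fires at IterationEvents.MODEL_COMPLETED
```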
2 changes: 1 addition & 1 deletion monai/networks/blocks/localnet_block.py
@@ -260,7 +260,7 @@ def forward(self, x, mid) -> torch.Tensor:
Args:
x: feature to be up-sampled, in shape (batch, ``in_channels``, insize_1, insize_2, [insize_3])
mid: mid-level feature saved during down-sampling,
in shape (batch, ``out_channels``, midsize_1, midsize_2, [midnsize_3])
in shape (batch, ``out_channels``, midsize_1, midsize_2, [midsize_3])

Raises:
ValueError: when ``midsize != insize * 2``
2 changes: 1 addition & 1 deletion monai/networks/nets/dynunet.py
@@ -91,7 +91,7 @@ class DynUNet(nn.Module):
(1, 2, 8, 6). The last two will be interpolated into (1, 2, 32, 24), and the stacked tensor
will has the shape (1, 3, 2, 8, 6).
When calculating the loss, you can use torch.unbind to get all feature maps can compute the loss
one by one with the groud truth, then do a weighted average for all losses to achieve the final loss.
one by one with the ground truth, then do a weighted average for all losses to achieve the final loss.
(To be added: a corresponding tutorial link)

deep_supr_num: number of feature maps that will output during deep supervision head. The
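The deep supervision description above (unbind the stacked feature maps and combine the per-map losses with a weighted average) can be sketched as follows; the weights and the Dice loss are illustrative assumptions:

```python
# Hedged sketch of the weighted deep-supervision loss described in the DynUNet docstring.
import torch
from monai.losses import DiceLoss

loss_fn = DiceLoss(sigmoid=True)

def deep_supervision_loss(stacked_pred, ground_truth, weights=(0.6, 0.3, 0.1)):
    # stacked_pred: (batch, num_maps, channels, *spatial); split along the map dimension
    feature_maps = torch.unbind(stacked_pred, dim=1)
    losses = [loss_fn(p, ground_truth) for p in feature_maps]
    return sum(w * l for w, l in zip(weights, losses))
```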
10 changes: 5 additions & 5 deletions monai/networks/nets/efficientnet.py
@@ -503,19 +503,19 @@ def __init__(
)

# get network parameters
weight_coeff, depth_coeff, image_size, drpout_rate, drpconnect_rate = efficientnet_params[model_name]
weight_coeff, depth_coeff, image_size, dropout_rate, dropconnect_rate = efficientnet_params[model_name]

# create model and initialize random weights
model = super(EfficientNetBN, self).__init__(
super(EfficientNetBN, self).__init__(
blocks_args_str=blocks_args_str,
spatial_dims=spatial_dims,
in_channels=in_channels,
num_classes=num_classes,
width_coefficient=weight_coeff,
depth_coefficient=depth_coeff,
dropout_rate=drpout_rate,
dropout_rate=dropout_rate,
image_size=image_size,
drop_connect_rate=drpconnect_rate,
drop_connect_rate=dropconnect_rate,
)

# attempt to load pretrained
@@ -827,7 +827,7 @@ def _decode_block_string(block_string: str):
or (len(options["s"]) == 3 and options["s"][0] == options["s"][1] and options["s"][0] == options["s"][2])
)
if not stride_check:
raise ValueError("invalid stride option recieved")
raise ValueError("invalid stride option received")

return BlockArgs(
num_repeat=int(options["r"]),