diff --git a/monai/transforms/post/array.py b/monai/transforms/post/array.py index da9b23ce57..2e733c4f6c 100644 --- a/monai/transforms/post/array.py +++ b/monai/transforms/post/array.py @@ -211,7 +211,8 @@ def __call__( raise ValueError("`to_onehot=True/False` is deprecated, please use `to_onehot=num_classes` instead.") img = convert_to_tensor(img, track_meta=get_track_meta()) img_t, *_ = convert_data_type(img, torch.Tensor) - if argmax or self.argmax: + argmax = self.argmax if argmax is None else argmax + if argmax: img_t = torch.argmax(img_t, dim=self.kwargs.get("dim", 0), keepdim=self.kwargs.get("keepdim", True)) to_onehot = self.to_onehot if to_onehot is None else to_onehot