From 8b2a5f5ffbdf4acab968926b13ac158c10617278 Mon Sep 17 00:00:00 2001 From: Nic Ma Date: Wed, 17 Nov 2021 19:46:23 +0800 Subject: [PATCH 1/5] [DLMED] simiplify AsDiscrete transform Signed-off-by: Nic Ma --- monai/transforms/post/array.py | 63 +++++++++------------------------- 1 file changed, 17 insertions(+), 46 deletions(-) diff --git a/monai/transforms/post/array.py b/monai/transforms/post/array.py index 7cbb6aad44..444980478b 100644 --- a/monai/transforms/post/array.py +++ b/monai/transforms/post/array.py @@ -122,56 +122,36 @@ class AsDiscrete(Transform): Args: argmax: whether to execute argmax function on input data before transform. Defaults to ``False``. - to_onehot: whether to convert input data into the one-hot format. - Defaults to ``False``. - num_classes: the number of classes to convert to One-Hot format. + to_onehot: if not None, convert input data into the one-hot format with specified number of classes. + Defaults to ``None``. + threshold_values: if not None, threshold the float values to int number 0 or 1 with specified theashold. Defaults to ``None``. - threshold_values: whether threshold the float value to int number 0 or 1. - Defaults to ``False``. - logit_thresh: the threshold value for thresholding operation.. - Defaults to ``0.5``. rounding: if not None, round the data according to the specified option, available options: ["torchrounding"]. - .. deprecated:: 0.6.0 - ``n_classes`` is deprecated, use ``num_classes`` instead. - """ backend = [TransformBackends.TORCH] - @deprecated_arg("n_classes", since="0.6") def __init__( self, argmax: bool = False, - to_onehot: bool = False, - num_classes: Optional[int] = None, - threshold_values: bool = False, - logit_thresh: float = 0.5, + to_onehot: Optional[int] = None, + threshold_values: Optional[float] = None, rounding: Optional[str] = None, - n_classes: Optional[int] = None, ) -> None: - # in case the new num_classes is default but you still call deprecated n_classes - if n_classes is not None and num_classes is None: - num_classes = n_classes self.argmax = argmax self.to_onehot = to_onehot - self.num_classes = num_classes self.threshold_values = threshold_values - self.logit_thresh = logit_thresh self.rounding = rounding - @deprecated_arg("n_classes", since="0.6") def __call__( self, img: NdarrayOrTensor, argmax: Optional[bool] = None, - to_onehot: Optional[bool] = None, - num_classes: Optional[int] = None, - threshold_values: Optional[bool] = None, - logit_thresh: Optional[float] = None, + to_onehot: Optional[int] = None, + threshold_values: Optional[float] = None, rounding: Optional[str] = None, - n_classes: Optional[int] = None, ) -> NdarrayOrTensor: """ Args: @@ -179,37 +159,28 @@ def __call__( will automatically add it. argmax: whether to execute argmax function on input data before transform. Defaults to ``self.argmax``. - to_onehot: whether to convert input data into the one-hot format. + to_onehot: if not None, convert input data into the one-hot format with specified number of classes. Defaults to ``self.to_onehot``. - num_classes: the number of classes to convert to One-Hot format. - Defaults to ``self.num_classes``. - threshold_values: whether threshold the float value to int number 0 or 1. + threshold_values: if not None, threshold the float values to int number 0 or 1 with specified theashold. Defaults to ``self.threshold_values``. - logit_thresh: the threshold value for thresholding operation.. - Defaults to ``self.logit_thresh``. rounding: if not None, round the data according to the specified option, available options: ["torchrounding"]. - .. deprecated:: 0.6.0 - ``n_classes`` is deprecated, use ``num_classes`` instead. - """ - # in case the new num_classes is default but you still call deprecated n_classes - if n_classes is not None and num_classes is None: - num_classes = n_classes img_t: torch.Tensor img_t, *_ = convert_data_type(img, torch.Tensor) # type: ignore if argmax or self.argmax: img_t = torch.argmax(img_t, dim=0, keepdim=True) - if to_onehot or self.to_onehot: - _nclasses = self.num_classes if num_classes is None else num_classes - if not isinstance(_nclasses, int): - raise AssertionError("One of self.num_classes or num_classes must be an integer") - img_t = one_hot(img_t, num_classes=_nclasses, dim=0) + to_onehot = to_onehot or self.to_onehot + if to_onehot is not None: + if not isinstance(to_onehot, int): + raise AssertionError("the number of classes for One-Hot must be an integer.") + img_t = one_hot(img_t, num_classes=to_onehot, dim=0) - if threshold_values or self.threshold_values: - img_t = img_t >= (self.logit_thresh if logit_thresh is None else logit_thresh) + threshold_values = threshold_values or self.threshold_values + if threshold_values is not None: + img_t = img_t >= threshold_values rounding = self.rounding if rounding is None else rounding if rounding is not None: From 17bb3d9af75893dcfb867416ce6bf438f0f7e673 Mon Sep 17 00:00:00 2001 From: Nic Ma Date: Wed, 17 Nov 2021 20:28:13 +0800 Subject: [PATCH 2/5] [DLMED] update tests Signed-off-by: Nic Ma --- monai/data/test_time_augmentation.py | 2 +- monai/transforms/post/array.py | 2 +- monai/transforms/post/dictionary.py | 33 ++++++------------- .../transforms/utils_create_transform_ims.py | 11 ++----- tests/test_as_discrete.py | 8 ++--- tests/test_as_discreted.py | 27 ++------------- tests/test_compute_roc_auc.py | 18 +++++----- tests/test_handler_decollate_batch.py | 2 +- tests/test_handler_post_processing.py | 2 +- tests/test_handler_rocauc.py | 2 +- tests/test_handler_rocauc_dist.py | 2 +- tests/test_integration_classification_2d.py | 2 +- tests/test_integration_fast_train.py | 4 +-- tests/test_integration_segmentation_3d.py | 4 +-- tests/test_integration_workflows.py | 6 ++-- tests/test_testtimeaugmentation.py | 2 +- 16 files changed, 43 insertions(+), 84 deletions(-) diff --git a/monai/data/test_time_augmentation.py b/monai/data/test_time_augmentation.py index 948e85e131..9c1a25d5b9 100644 --- a/monai/data/test_time_augmentation.py +++ b/monai/data/test_time_augmentation.py @@ -86,7 +86,7 @@ class TestTimeAugmentation: .. code-block:: python transform = RandAffined(keys, ...) - post_trans = Compose([Activations(sigmoid=True), AsDiscrete(threshold_values=True)]) + post_trans = Compose([Activations(sigmoid=True), AsDiscrete(threshold_values=0.5)]) tt_aug = TestTimeAugmentation( transform, batch_size=5, num_workers=0, inferrer_fn=lambda x: post_trans(model(x)), device=device diff --git a/monai/transforms/post/array.py b/monai/transforms/post/array.py index 444980478b..67a591101b 100644 --- a/monai/transforms/post/array.py +++ b/monai/transforms/post/array.py @@ -25,7 +25,7 @@ from monai.transforms.transform import Transform from monai.transforms.utils import fill_holes, get_largest_connected_component_mask from monai.transforms.utils_pytorch_numpy_unification import unravel_index -from monai.utils import TransformBackends, convert_data_type, deprecated_arg, ensure_tuple, look_up_option +from monai.utils import TransformBackends, convert_data_type, ensure_tuple, look_up_option from monai.utils.type_conversion import convert_to_dst_type __all__ = [ diff --git a/monai/transforms/post/dictionary.py b/monai/transforms/post/dictionary.py index 226f5953e2..88c11d8c45 100644 --- a/monai/transforms/post/dictionary.py +++ b/monai/transforms/post/dictionary.py @@ -38,7 +38,7 @@ from monai.transforms.transform import MapTransform from monai.transforms.utility.array import ToTensor from monai.transforms.utils import allow_missing_keys_mode, convert_inverse_interp_mode -from monai.utils import deprecated_arg, ensure_tuple, ensure_tuple_rep +from monai.utils import ensure_tuple, ensure_tuple_rep __all__ = [ "ActivationsD", @@ -128,18 +128,14 @@ class AsDiscreted(MapTransform): backend = AsDiscrete.backend - @deprecated_arg("n_classes", since="0.6") def __init__( self, keys: KeysCollection, argmax: Union[Sequence[bool], bool] = False, - to_onehot: Union[Sequence[bool], bool] = False, - num_classes: Optional[Union[Sequence[int], int]] = None, - threshold_values: Union[Sequence[bool], bool] = False, - logit_thresh: Union[Sequence[float], float] = 0.5, + to_onehot: Union[Sequence[Optional[int]], Optional[int]] = None, + threshold_values: Union[Sequence[Optional[float]], Optional[float]] = None, rounding: Union[Sequence[Optional[str]], Optional[str]] = None, allow_missing_keys: bool = False, - n_classes: Optional[int] = None, ) -> None: """ Args: @@ -147,14 +143,10 @@ def __init__( See also: :py:class:`monai.transforms.compose.MapTransform` argmax: whether to execute argmax function on input data before transform. it also can be a sequence of bool, each element corresponds to a key in ``keys``. - to_onehot: whether to convert input data into the one-hot format. Defaults to False. - it also can be a sequence of bool, each element corresponds to a key in ``keys``. - num_classes: the number of classes to convert to One-Hot format. it also can be a - sequence of int, each element corresponds to a key in ``keys``. - threshold_values: whether threshold the float value to int number 0 or 1, default is False. - it also can be a sequence of bool, each element corresponds to a key in ``keys``. - logit_thresh: the threshold value for thresholding operation, default is 0.5. - it also can be a sequence of float, each element corresponds to a key in ``keys``. + to_onehot: if not None, convert input data into the one-hot format with specified number of classes. + defaults to ``None``. it also can be a sequence, each element corresponds to a key in ``keys``. + threshold_values: if not None, threshold the float values to int number 0 or 1 with specified theashold. + defaults to ``None``. it also can be a sequence, each element corresponds to a key in ``keys``. rounding: if not None, round the data according to the specified option, available options: ["torchrounding"]. it also can be a sequence of str or None, each element corresponds to a key in ``keys``. @@ -164,24 +156,19 @@ def __init__( ``n_classes`` is deprecated, use ``num_classes`` instead. """ - # in case the new num_classes is default but you still call deprecated n_classes - if n_classes is not None and num_classes is None: - num_classes = n_classes super().__init__(keys, allow_missing_keys) self.argmax = ensure_tuple_rep(argmax, len(self.keys)) self.to_onehot = ensure_tuple_rep(to_onehot, len(self.keys)) - self.num_classes = ensure_tuple_rep(num_classes, len(self.keys)) self.threshold_values = ensure_tuple_rep(threshold_values, len(self.keys)) - self.logit_thresh = ensure_tuple_rep(logit_thresh, len(self.keys)) self.rounding = ensure_tuple_rep(rounding, len(self.keys)) self.converter = AsDiscrete() def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Dict[Hashable, NdarrayOrTensor]: d = dict(data) - for key, argmax, to_onehot, num_classes, threshold_values, logit_thresh, rounding in self.key_iterator( - d, self.argmax, self.to_onehot, self.num_classes, self.threshold_values, self.logit_thresh, self.rounding + for key, argmax, to_onehot, threshold_values, rounding in self.key_iterator( + d, self.argmax, self.to_onehot, self.threshold_values, self.rounding ): - d[key] = self.converter(d[key], argmax, to_onehot, num_classes, threshold_values, logit_thresh, rounding) + d[key] = self.converter(d[key], argmax, to_onehot, threshold_values, rounding) return d diff --git a/monai/transforms/utils_create_transform_ims.py b/monai/transforms/utils_create_transform_ims.py index 84aaa348fe..2abbc0628b 100644 --- a/monai/transforms/utils_create_transform_ims.py +++ b/monai/transforms/utils_create_transform_ims.py @@ -640,15 +640,8 @@ def create_transform_im( create_transform_im(RandScaleCropd, dict(keys=keys, roi_scale=0.4), data) create_transform_im(CenterScaleCrop, dict(roi_scale=0.4), data) create_transform_im(CenterScaleCropd, dict(keys=keys, roi_scale=0.4), data) - create_transform_im( - AsDiscrete, dict(num_classes=2, threshold_values=True, logit_thresh=10), data, is_post=True, colorbar=True - ) - create_transform_im( - AsDiscreted, - dict(keys=CommonKeys.LABEL, num_classes=2, threshold_values=True, logit_thresh=10), - data, - is_post=True, - ) + create_transform_im(AsDiscrete, dict(to_onehot=2, threshold_values=10), data, is_post=True, colorbar=True) + create_transform_im(AsDiscreted, dict(keys=CommonKeys.LABEL, to_onehot=2, threshold_values=10), data, is_post=True) create_transform_im(LabelFilter, dict(applied_labels=(1, 2, 3, 4, 5, 6)), data, is_post=True) create_transform_im( LabelFilterd, dict(keys=CommonKeys.LABEL, applied_labels=(1, 2, 3, 4, 5, 6)), data, is_post=True diff --git a/tests/test_as_discrete.py b/tests/test_as_discrete.py index 75f6a38d3c..ecaed68f41 100644 --- a/tests/test_as_discrete.py +++ b/tests/test_as_discrete.py @@ -20,7 +20,7 @@ for p in TEST_NDARRAYS: TEST_CASES.append( [ - {"argmax": True, "to_onehot": False, "num_classes": None, "threshold_values": False, "logit_thresh": 0.5}, + {"argmax": True, "to_onehot": None, "threshold_values": 0.5}, p([[[0.0, 1.0]], [[2.0, 3.0]]]), p([[[1.0, 1.0]]]), (1, 1, 2), @@ -29,7 +29,7 @@ TEST_CASES.append( [ - {"argmax": True, "to_onehot": True, "num_classes": 2, "threshold_values": False, "logit_thresh": 0.5}, + {"argmax": True, "to_onehot": 2, "threshold_values": 0.5}, p([[[0.0, 1.0]], [[2.0, 3.0]]]), p([[[0.0, 0.0]], [[1.0, 1.0]]]), (2, 1, 2), @@ -38,14 +38,14 @@ TEST_CASES.append( [ - {"argmax": False, "to_onehot": False, "num_classes": None, "threshold_values": True, "logit_thresh": 0.6}, + {"argmax": False, "to_onehot": None, "threshold_values": 0.6}, p([[[0.0, 1.0], [2.0, 3.0]]]), p([[[0.0, 1.0], [1.0, 1.0]]]), (1, 2, 2), ] ) - TEST_CASES.append([{"argmax": False, "to_onehot": True, "num_classes": 3}, p(1), p([0.0, 1.0, 0.0]), (3,)]) + TEST_CASES.append([{"argmax": False, "to_onehot": 3}, p(1), p([0.0, 1.0, 0.0]), (3,)]) TEST_CASES.append( [{"rounding": "torchrounding"}, p([[[0.123, 1.345], [2.567, 3.789]]]), p([[[0.0, 1.0], [3.0, 4.0]]]), (1, 2, 2)] diff --git a/tests/test_as_discreted.py b/tests/test_as_discreted.py index dc160d5e46..d577fb9f22 100644 --- a/tests/test_as_discreted.py +++ b/tests/test_as_discreted.py @@ -20,14 +20,7 @@ for p in TEST_NDARRAYS: TEST_CASES.append( [ - { - "keys": ["pred", "label"], - "argmax": [True, False], - "to_onehot": True, - "num_classes": 2, - "threshold_values": False, - "logit_thresh": 0.5, - }, + {"keys": ["pred", "label"], "argmax": [True, False], "to_onehot": 2, "threshold_values": 0.5}, {"pred": p([[[0.0, 1.0]], [[2.0, 3.0]]]), "label": p([[[0, 1]]])}, {"pred": p([[[0.0, 0.0]], [[1.0, 1.0]]]), "label": p([[[1.0, 0.0]], [[0.0, 1.0]]])}, (2, 1, 2), @@ -36,14 +29,7 @@ TEST_CASES.append( [ - { - "keys": ["pred", "label"], - "argmax": False, - "to_onehot": False, - "num_classes": None, - "threshold_values": [True, False], - "logit_thresh": 0.6, - }, + {"keys": ["pred", "label"], "argmax": False, "to_onehot": None, "threshold_values": [0.6, None]}, {"pred": p([[[0.0, 1.0], [2.0, 3.0]]]), "label": p([[[0, 1], [1, 1]]])}, {"pred": p([[[0.0, 1.0], [1.0, 1.0]]]), "label": p([[[0.0, 1.0], [1.0, 1.0]]])}, (1, 2, 2), @@ -52,14 +38,7 @@ TEST_CASES.append( [ - { - "keys": ["pred"], - "argmax": True, - "to_onehot": True, - "num_classes": 2, - "threshold_values": False, - "logit_thresh": 0.5, - }, + {"keys": ["pred"], "argmax": True, "to_onehot": 2, "threshold_values": 0.5}, {"pred": p([[[0.0, 1.0]], [[2.0, 3.0]]])}, {"pred": p([[[0.0, 0.0]], [[1.0, 1.0]]])}, (2, 1, 2), diff --git a/tests/test_compute_roc_auc.py b/tests/test_compute_roc_auc.py index bfecb4ce5b..02e2f2b24f 100644 --- a/tests/test_compute_roc_auc.py +++ b/tests/test_compute_roc_auc.py @@ -23,7 +23,7 @@ torch.tensor([[0.1, 0.9], [0.3, 1.4], [0.2, 0.1], [0.1, 0.5]]), torch.tensor([[0], [1], [0], [1]]), True, - True, + 2, "macro", 0.75, ] @@ -32,20 +32,20 @@ torch.tensor([[0.5], [0.5], [0.2], [8.3]]), torch.tensor([[0], [1], [0], [1]]), False, - False, + None, "macro", 0.875, ] -TEST_CASE_3 = [torch.tensor([[0.5], [0.5], [0.2], [8.3]]), torch.tensor([0, 1, 0, 1]), False, False, "macro", 0.875] +TEST_CASE_3 = [torch.tensor([[0.5], [0.5], [0.2], [8.3]]), torch.tensor([0, 1, 0, 1]), False, None, "macro", 0.875] -TEST_CASE_4 = [torch.tensor([0.5, 0.5, 0.2, 8.3]), torch.tensor([0, 1, 0, 1]), False, False, "macro", 0.875] +TEST_CASE_4 = [torch.tensor([0.5, 0.5, 0.2, 8.3]), torch.tensor([0, 1, 0, 1]), False, None, "macro", 0.875] TEST_CASE_5 = [ torch.tensor([[0.1, 0.9], [0.3, 1.4], [0.2, 0.1], [0.1, 0.5]]), torch.tensor([[0], [1], [0], [1]]), True, - True, + 2, "none", [0.75, 0.75], ] @@ -54,7 +54,7 @@ torch.tensor([[0.1, 0.9], [0.3, 1.4], [0.2, 0.1], [0.1, 0.5], [0.1, 0.5]]), torch.tensor([[1, 0], [0, 1], [0, 0], [1, 1], [0, 1]]), True, - False, + None, "weighted", 0.56667, ] @@ -63,7 +63,7 @@ torch.tensor([[0.1, 0.9], [0.3, 1.4], [0.2, 0.1], [0.1, 0.5], [0.1, 0.5]]), torch.tensor([[1, 0], [0, 1], [0, 0], [1, 1], [0, 1]]), True, - False, + None, "micro", 0.62, ] @@ -73,7 +73,7 @@ class TestComputeROCAUC(unittest.TestCase): @parameterized.expand([TEST_CASE_1, TEST_CASE_2, TEST_CASE_3, TEST_CASE_4, TEST_CASE_5, TEST_CASE_6, TEST_CASE_7]) def test_value(self, y_pred, y, softmax, to_onehot, average, expected_value): y_pred_trans = Compose([ToTensor(), Activations(softmax=softmax)]) - y_trans = Compose([ToTensor(), AsDiscrete(to_onehot=to_onehot, num_classes=2)]) + y_trans = Compose([ToTensor(), AsDiscrete(to_onehot=to_onehot)]) y_pred = torch.stack([y_pred_trans(i) for i in decollate_batch(y_pred)], dim=0) y = torch.stack([y_trans(i) for i in decollate_batch(y)], dim=0) result = compute_roc_auc(y_pred=y_pred, y=y, average=average) @@ -82,7 +82,7 @@ def test_value(self, y_pred, y, softmax, to_onehot, average, expected_value): @parameterized.expand([TEST_CASE_1, TEST_CASE_2, TEST_CASE_3, TEST_CASE_4, TEST_CASE_5, TEST_CASE_6, TEST_CASE_7]) def test_class_value(self, y_pred, y, softmax, to_onehot, average, expected_value): y_pred_trans = Compose([ToTensor(), Activations(softmax=softmax)]) - y_trans = Compose([ToTensor(), AsDiscrete(to_onehot=to_onehot, num_classes=2)]) + y_trans = Compose([ToTensor(), AsDiscrete(to_onehot=to_onehot)]) y_pred = [y_pred_trans(i) for i in decollate_batch(y_pred)] y = [y_trans(i) for i in decollate_batch(y)] metric = ROCAUCMetric(average=average) diff --git a/tests/test_handler_decollate_batch.py b/tests/test_handler_decollate_batch.py index 8f0ffb2b5c..584fb33116 100644 --- a/tests/test_handler_decollate_batch.py +++ b/tests/test_handler_decollate_batch.py @@ -32,7 +32,7 @@ def test_compute(self): [ Activationsd(keys="pred", sigmoid=True), CopyItemsd(keys="filename", times=1, names="filename_bak"), - AsDiscreted(keys="pred", threshold_values=True, to_onehot=True, num_classes=2), + AsDiscreted(keys="pred", threshold_values=0.5, to_onehot=2), ] ) ), diff --git a/tests/test_handler_post_processing.py b/tests/test_handler_post_processing.py index e9d57128cb..fd670113b7 100644 --- a/tests/test_handler_post_processing.py +++ b/tests/test_handler_post_processing.py @@ -26,7 +26,7 @@ "transform": Compose( [ CopyItemsd(keys="filename", times=1, names="filename_bak"), - AsDiscreted(keys="pred", threshold_values=True, to_onehot=True, num_classes=2), + AsDiscreted(keys="pred", threshold_values=0.5, to_onehot=2), ] ), "event": "iteration_completed", diff --git a/tests/test_handler_rocauc.py b/tests/test_handler_rocauc.py index 5b80bc43eb..bd32922777 100644 --- a/tests/test_handler_rocauc.py +++ b/tests/test_handler_rocauc.py @@ -22,7 +22,7 @@ class TestHandlerROCAUC(unittest.TestCase): def test_compute(self): auc_metric = ROCAUC() act = Activations(softmax=True) - to_onehot = AsDiscrete(to_onehot=True, num_classes=2) + to_onehot = AsDiscrete(to_onehot=2) y_pred = [torch.Tensor([0.1, 0.9]), torch.Tensor([0.3, 1.4])] y = [torch.Tensor([0]), torch.Tensor([1])] diff --git a/tests/test_handler_rocauc_dist.py b/tests/test_handler_rocauc_dist.py index 8316d4c4b6..0905816868 100644 --- a/tests/test_handler_rocauc_dist.py +++ b/tests/test_handler_rocauc_dist.py @@ -26,7 +26,7 @@ class DistributedROCAUC(DistTestCase): def test_compute(self): auc_metric = ROCAUC() act = Activations(softmax=True) - to_onehot = AsDiscrete(to_onehot=True, num_classes=2) + to_onehot = AsDiscrete(to_onehot=2) device = f"cuda:{dist.get_rank()}" if torch.cuda.is_available() else "cpu" if dist.get_rank() == 0: diff --git a/tests/test_integration_classification_2d.py b/tests/test_integration_classification_2d.py index 7a94780f82..cafad9dcf0 100644 --- a/tests/test_integration_classification_2d.py +++ b/tests/test_integration_classification_2d.py @@ -80,7 +80,7 @@ def run_training_test(root_dir, train_x, train_y, val_x, val_y, device="cuda:0", [LoadImage(image_only=True), AddChannel(), Transpose(indices=[0, 2, 1]), ScaleIntensity(), ToTensor()] ) y_pred_trans = Compose([ToTensor(), Activations(softmax=True)]) - y_trans = Compose([ToTensor(), AsDiscrete(to_onehot=True, num_classes=len(np.unique(train_y)))]) + y_trans = Compose([ToTensor(), AsDiscrete(to_onehot=len(np.unique(train_y)))]) auc_metric = ROCAUCMetric() # create train, val data loaders diff --git a/tests/test_integration_fast_train.py b/tests/test_integration_fast_train.py index 9fd37a0897..b2706dbb47 100644 --- a/tests/test_integration_fast_train.py +++ b/tests/test_integration_fast_train.py @@ -170,8 +170,8 @@ def test_train_timing(self): optimizer = Novograd(model.parameters(), learning_rate * 10) scaler = torch.cuda.amp.GradScaler() - post_pred = Compose([EnsureType(), AsDiscrete(argmax=True, to_onehot=True, num_classes=2)]) - post_label = Compose([EnsureType(), AsDiscrete(to_onehot=True, num_classes=2)]) + post_pred = Compose([EnsureType(), AsDiscrete(argmax=True, to_onehot=2)]) + post_label = Compose([EnsureType(), AsDiscrete(to_onehot=2)]) dice_metric = DiceMetric(include_background=True, reduction="mean", get_not_nans=False) diff --git a/tests/test_integration_segmentation_3d.py b/tests/test_integration_segmentation_3d.py index 215a5b3f9a..197dfeecd5 100644 --- a/tests/test_integration_segmentation_3d.py +++ b/tests/test_integration_segmentation_3d.py @@ -95,7 +95,7 @@ def run_training_test(root_dir, device="cuda:0", cachedataset=0, readers=(None, # create a validation data loader val_ds = monai.data.Dataset(data=val_files, transform=val_transforms) val_loader = monai.data.DataLoader(val_ds, batch_size=1, num_workers=4) - val_post_tran = Compose([ToTensor(), Activations(sigmoid=True), AsDiscrete(threshold_values=True)]) + val_post_tran = Compose([ToTensor(), Activations(sigmoid=True), AsDiscrete(threshold_values=0.5)]) dice_metric = DiceMetric(include_background=True, reduction="mean", get_not_nans=False) # create UNet, DiceLoss and Adam optimizer @@ -195,7 +195,7 @@ def run_inference_test(root_dir, device="cuda:0"): val_ds = monai.data.Dataset(data=val_files, transform=val_transforms) # sliding window inference need to input 1 image in every iteration val_loader = monai.data.DataLoader(val_ds, batch_size=1, num_workers=4) - val_post_tran = Compose([ToTensor(), Activations(sigmoid=True), AsDiscrete(threshold_values=True)]) + val_post_tran = Compose([ToTensor(), Activations(sigmoid=True), AsDiscrete(threshold_values=0.5)]) dice_metric = DiceMetric(include_background=True, reduction="mean", get_not_nans=False) model = UNet( diff --git a/tests/test_integration_workflows.py b/tests/test_integration_workflows.py index 0ffef2935b..b3d80c8b84 100644 --- a/tests/test_integration_workflows.py +++ b/tests/test_integration_workflows.py @@ -114,7 +114,7 @@ def run_training_test(root_dir, device="cuda:0", amp=False, num_workers=4): [ ToTensord(keys=["pred", "label"]), Activationsd(keys="pred", sigmoid=True), - AsDiscreted(keys="pred", threshold_values=True), + AsDiscreted(keys="pred", threshold_values=0.5), KeepLargestConnectedComponentd(keys="pred", applied_labels=[1]), ] ) @@ -155,7 +155,7 @@ def _forward_completed(self, engine): [ ToTensord(keys=["pred", "label"]), Activationsd(keys="pred", sigmoid=True), - AsDiscreted(keys="pred", threshold_values=True), + AsDiscreted(keys="pred", threshold_values=0.5), KeepLargestConnectedComponentd(keys="pred", applied_labels=[1]), ] ) @@ -242,7 +242,7 @@ def run_inference_test(root_dir, model_file, device="cuda:0", amp=False, num_wor [ ToTensord(keys=["pred", "label"]), Activationsd(keys="pred", sigmoid=True), - AsDiscreted(keys="pred", threshold_values=True), + AsDiscreted(keys="pred", threshold_values=0.5), KeepLargestConnectedComponentd(keys="pred", applied_labels=[1]), # test the case that `pred` in `engine.state.output`, while `image_meta_dict` in `engine.state.batch` SaveImaged(keys="pred", meta_keys="image_meta_dict", output_dir=root_dir, output_postfix="seg_transform"), diff --git a/tests/test_testtimeaugmentation.py b/tests/test_testtimeaugmentation.py index e7303ca524..568571b380 100644 --- a/tests/test_testtimeaugmentation.py +++ b/tests/test_testtimeaugmentation.py @@ -122,7 +122,7 @@ def test_test_time_augmentation(self): epoch_loss /= len(train_loader) - post_trans = Compose([Activations(sigmoid=True), AsDiscrete(threshold_values=True)]) + post_trans = Compose([Activations(sigmoid=True), AsDiscrete(threshold_values=0.5)]) def inferrer_fn(x): return post_trans(model(x)) From 70919d3a7d3ddce0f6c2dc0dce1e2794999694e7 Mon Sep 17 00:00:00 2001 From: Nic Ma Date: Wed, 17 Nov 2021 22:27:26 +0800 Subject: [PATCH 3/5] [DLMED] update according to comments Signed-off-by: Nic Ma --- monai/data/test_time_augmentation.py | 2 +- monai/transforms/post/array.py | 23 +++++++++++-------- monai/transforms/post/dictionary.py | 15 ++++++------ .../transforms/utils_create_transform_ims.py | 4 ++-- tests/test_as_discrete.py | 6 ++--- tests/test_as_discreted.py | 6 ++--- tests/test_handler_decollate_batch.py | 2 +- tests/test_handler_post_processing.py | 2 +- tests/test_integration_segmentation_3d.py | 4 ++-- tests/test_integration_workflows.py | 6 ++--- tests/test_testtimeaugmentation.py | 2 +- 11 files changed, 38 insertions(+), 34 deletions(-) diff --git a/monai/data/test_time_augmentation.py b/monai/data/test_time_augmentation.py index 9c1a25d5b9..4df3283ab6 100644 --- a/monai/data/test_time_augmentation.py +++ b/monai/data/test_time_augmentation.py @@ -86,7 +86,7 @@ class TestTimeAugmentation: .. code-block:: python transform = RandAffined(keys, ...) - post_trans = Compose([Activations(sigmoid=True), AsDiscrete(threshold_values=0.5)]) + post_trans = Compose([Activations(sigmoid=True), AsDiscrete(threshold=0.5)]) tt_aug = TestTimeAugmentation( transform, batch_size=5, num_workers=0, inferrer_fn=lambda x: post_trans(model(x)), device=device diff --git a/monai/transforms/post/array.py b/monai/transforms/post/array.py index 67a591101b..71cd349fba 100644 --- a/monai/transforms/post/array.py +++ b/monai/transforms/post/array.py @@ -25,7 +25,7 @@ from monai.transforms.transform import Transform from monai.transforms.utils import fill_holes, get_largest_connected_component_mask from monai.transforms.utils_pytorch_numpy_unification import unravel_index -from monai.utils import TransformBackends, convert_data_type, ensure_tuple, look_up_option +from monai.utils import TransformBackends, convert_data_type, deprecated_arg, ensure_tuple, look_up_option from monai.utils.type_conversion import convert_to_dst_type __all__ = [ @@ -124,7 +124,7 @@ class AsDiscrete(Transform): Defaults to ``False``. to_onehot: if not None, convert input data into the one-hot format with specified number of classes. Defaults to ``None``. - threshold_values: if not None, threshold the float values to int number 0 or 1 with specified theashold. + threshold: if not None, threshold the float values to int number 0 or 1 with specified theashold. Defaults to ``None``. rounding: if not None, round the data according to the specified option, available options: ["torchrounding"]. @@ -133,16 +133,19 @@ class AsDiscrete(Transform): backend = [TransformBackends.TORCH] + @deprecated_arg(name="threshold_values", new_name="threshold", since="0.7") def __init__( self, argmax: bool = False, to_onehot: Optional[int] = None, - threshold_values: Optional[float] = None, + threshold: Optional[float] = None, rounding: Optional[str] = None, ) -> None: self.argmax = argmax self.to_onehot = to_onehot - self.threshold_values = threshold_values + if threshold is True: + raise ValueError("`threshold_values=True` is deprecated, please use `threashold=value` instead.") + self.threshold = threshold self.rounding = rounding def __call__( @@ -150,7 +153,7 @@ def __call__( img: NdarrayOrTensor, argmax: Optional[bool] = None, to_onehot: Optional[int] = None, - threshold_values: Optional[float] = None, + threshold: Optional[float] = None, rounding: Optional[str] = None, ) -> NdarrayOrTensor: """ @@ -161,8 +164,8 @@ def __call__( Defaults to ``self.argmax``. to_onehot: if not None, convert input data into the one-hot format with specified number of classes. Defaults to ``self.to_onehot``. - threshold_values: if not None, threshold the float values to int number 0 or 1 with specified theashold. - Defaults to ``self.threshold_values``. + threshold: if not None, threshold the float values to int number 0 or 1 with specified theashold value. + Defaults to ``self.threshold``. rounding: if not None, round the data according to the specified option, available options: ["torchrounding"]. @@ -178,9 +181,9 @@ def __call__( raise AssertionError("the number of classes for One-Hot must be an integer.") img_t = one_hot(img_t, num_classes=to_onehot, dim=0) - threshold_values = threshold_values or self.threshold_values - if threshold_values is not None: - img_t = img_t >= threshold_values + threshold = threshold or self.threshold + if threshold is not None: + img_t = img_t >= threshold rounding = self.rounding if rounding is None else rounding if rounding is not None: diff --git a/monai/transforms/post/dictionary.py b/monai/transforms/post/dictionary.py index 88c11d8c45..e7119dc2e5 100644 --- a/monai/transforms/post/dictionary.py +++ b/monai/transforms/post/dictionary.py @@ -38,7 +38,7 @@ from monai.transforms.transform import MapTransform from monai.transforms.utility.array import ToTensor from monai.transforms.utils import allow_missing_keys_mode, convert_inverse_interp_mode -from monai.utils import ensure_tuple, ensure_tuple_rep +from monai.utils import deprecated_arg, ensure_tuple, ensure_tuple_rep __all__ = [ "ActivationsD", @@ -128,12 +128,13 @@ class AsDiscreted(MapTransform): backend = AsDiscrete.backend + @deprecated_arg(name="threshold_values", new_name="threshold", since="0.7") def __init__( self, keys: KeysCollection, argmax: Union[Sequence[bool], bool] = False, to_onehot: Union[Sequence[Optional[int]], Optional[int]] = None, - threshold_values: Union[Sequence[Optional[float]], Optional[float]] = None, + threshold: Union[Sequence[Optional[float]], Optional[float]] = None, rounding: Union[Sequence[Optional[str]], Optional[str]] = None, allow_missing_keys: bool = False, ) -> None: @@ -145,7 +146,7 @@ def __init__( it also can be a sequence of bool, each element corresponds to a key in ``keys``. to_onehot: if not None, convert input data into the one-hot format with specified number of classes. defaults to ``None``. it also can be a sequence, each element corresponds to a key in ``keys``. - threshold_values: if not None, threshold the float values to int number 0 or 1 with specified theashold. + threshold: if not None, threshold the float values to int number 0 or 1 with specified theashold value. defaults to ``None``. it also can be a sequence, each element corresponds to a key in ``keys``. rounding: if not None, round the data according to the specified option, available options: ["torchrounding"]. it also can be a sequence of str or None, @@ -159,16 +160,16 @@ def __init__( super().__init__(keys, allow_missing_keys) self.argmax = ensure_tuple_rep(argmax, len(self.keys)) self.to_onehot = ensure_tuple_rep(to_onehot, len(self.keys)) - self.threshold_values = ensure_tuple_rep(threshold_values, len(self.keys)) + self.threshold = ensure_tuple_rep(threshold, len(self.keys)) self.rounding = ensure_tuple_rep(rounding, len(self.keys)) self.converter = AsDiscrete() def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Dict[Hashable, NdarrayOrTensor]: d = dict(data) - for key, argmax, to_onehot, threshold_values, rounding in self.key_iterator( - d, self.argmax, self.to_onehot, self.threshold_values, self.rounding + for key, argmax, to_onehot, threshold, rounding in self.key_iterator( + d, self.argmax, self.to_onehot, self.threshold, self.rounding ): - d[key] = self.converter(d[key], argmax, to_onehot, threshold_values, rounding) + d[key] = self.converter(d[key], argmax, to_onehot, threshold, rounding) return d diff --git a/monai/transforms/utils_create_transform_ims.py b/monai/transforms/utils_create_transform_ims.py index 2abbc0628b..59d359639b 100644 --- a/monai/transforms/utils_create_transform_ims.py +++ b/monai/transforms/utils_create_transform_ims.py @@ -640,8 +640,8 @@ def create_transform_im( create_transform_im(RandScaleCropd, dict(keys=keys, roi_scale=0.4), data) create_transform_im(CenterScaleCrop, dict(roi_scale=0.4), data) create_transform_im(CenterScaleCropd, dict(keys=keys, roi_scale=0.4), data) - create_transform_im(AsDiscrete, dict(to_onehot=2, threshold_values=10), data, is_post=True, colorbar=True) - create_transform_im(AsDiscreted, dict(keys=CommonKeys.LABEL, to_onehot=2, threshold_values=10), data, is_post=True) + create_transform_im(AsDiscrete, dict(to_onehot=2, threshold=10), data, is_post=True, colorbar=True) + create_transform_im(AsDiscreted, dict(keys=CommonKeys.LABEL, to_onehot=2, threshold=10), data, is_post=True) create_transform_im(LabelFilter, dict(applied_labels=(1, 2, 3, 4, 5, 6)), data, is_post=True) create_transform_im( LabelFilterd, dict(keys=CommonKeys.LABEL, applied_labels=(1, 2, 3, 4, 5, 6)), data, is_post=True diff --git a/tests/test_as_discrete.py b/tests/test_as_discrete.py index ecaed68f41..e8db2052f0 100644 --- a/tests/test_as_discrete.py +++ b/tests/test_as_discrete.py @@ -20,7 +20,7 @@ for p in TEST_NDARRAYS: TEST_CASES.append( [ - {"argmax": True, "to_onehot": None, "threshold_values": 0.5}, + {"argmax": True, "to_onehot": None, "threshold": 0.5}, p([[[0.0, 1.0]], [[2.0, 3.0]]]), p([[[1.0, 1.0]]]), (1, 1, 2), @@ -29,7 +29,7 @@ TEST_CASES.append( [ - {"argmax": True, "to_onehot": 2, "threshold_values": 0.5}, + {"argmax": True, "to_onehot": 2, "threshold": 0.5}, p([[[0.0, 1.0]], [[2.0, 3.0]]]), p([[[0.0, 0.0]], [[1.0, 1.0]]]), (2, 1, 2), @@ -38,7 +38,7 @@ TEST_CASES.append( [ - {"argmax": False, "to_onehot": None, "threshold_values": 0.6}, + {"argmax": False, "to_onehot": None, "threshold": 0.6}, p([[[0.0, 1.0], [2.0, 3.0]]]), p([[[0.0, 1.0], [1.0, 1.0]]]), (1, 2, 2), diff --git a/tests/test_as_discreted.py b/tests/test_as_discreted.py index d577fb9f22..8532f84d33 100644 --- a/tests/test_as_discreted.py +++ b/tests/test_as_discreted.py @@ -20,7 +20,7 @@ for p in TEST_NDARRAYS: TEST_CASES.append( [ - {"keys": ["pred", "label"], "argmax": [True, False], "to_onehot": 2, "threshold_values": 0.5}, + {"keys": ["pred", "label"], "argmax": [True, False], "to_onehot": 2, "threshold": 0.5}, {"pred": p([[[0.0, 1.0]], [[2.0, 3.0]]]), "label": p([[[0, 1]]])}, {"pred": p([[[0.0, 0.0]], [[1.0, 1.0]]]), "label": p([[[1.0, 0.0]], [[0.0, 1.0]]])}, (2, 1, 2), @@ -29,7 +29,7 @@ TEST_CASES.append( [ - {"keys": ["pred", "label"], "argmax": False, "to_onehot": None, "threshold_values": [0.6, None]}, + {"keys": ["pred", "label"], "argmax": False, "to_onehot": None, "threshold": [0.6, None]}, {"pred": p([[[0.0, 1.0], [2.0, 3.0]]]), "label": p([[[0, 1], [1, 1]]])}, {"pred": p([[[0.0, 1.0], [1.0, 1.0]]]), "label": p([[[0.0, 1.0], [1.0, 1.0]]])}, (1, 2, 2), @@ -38,7 +38,7 @@ TEST_CASES.append( [ - {"keys": ["pred"], "argmax": True, "to_onehot": 2, "threshold_values": 0.5}, + {"keys": ["pred"], "argmax": True, "to_onehot": 2, "threshold": 0.5}, {"pred": p([[[0.0, 1.0]], [[2.0, 3.0]]])}, {"pred": p([[[0.0, 0.0]], [[1.0, 1.0]]])}, (2, 1, 2), diff --git a/tests/test_handler_decollate_batch.py b/tests/test_handler_decollate_batch.py index 584fb33116..d2282a971f 100644 --- a/tests/test_handler_decollate_batch.py +++ b/tests/test_handler_decollate_batch.py @@ -32,7 +32,7 @@ def test_compute(self): [ Activationsd(keys="pred", sigmoid=True), CopyItemsd(keys="filename", times=1, names="filename_bak"), - AsDiscreted(keys="pred", threshold_values=0.5, to_onehot=2), + AsDiscreted(keys="pred", threshold=0.5, to_onehot=2), ] ) ), diff --git a/tests/test_handler_post_processing.py b/tests/test_handler_post_processing.py index fd670113b7..4b47ece063 100644 --- a/tests/test_handler_post_processing.py +++ b/tests/test_handler_post_processing.py @@ -26,7 +26,7 @@ "transform": Compose( [ CopyItemsd(keys="filename", times=1, names="filename_bak"), - AsDiscreted(keys="pred", threshold_values=0.5, to_onehot=2), + AsDiscreted(keys="pred", threshold=0.5, to_onehot=2), ] ), "event": "iteration_completed", diff --git a/tests/test_integration_segmentation_3d.py b/tests/test_integration_segmentation_3d.py index 197dfeecd5..8898bcdbf8 100644 --- a/tests/test_integration_segmentation_3d.py +++ b/tests/test_integration_segmentation_3d.py @@ -95,7 +95,7 @@ def run_training_test(root_dir, device="cuda:0", cachedataset=0, readers=(None, # create a validation data loader val_ds = monai.data.Dataset(data=val_files, transform=val_transforms) val_loader = monai.data.DataLoader(val_ds, batch_size=1, num_workers=4) - val_post_tran = Compose([ToTensor(), Activations(sigmoid=True), AsDiscrete(threshold_values=0.5)]) + val_post_tran = Compose([ToTensor(), Activations(sigmoid=True), AsDiscrete(threshold=0.5)]) dice_metric = DiceMetric(include_background=True, reduction="mean", get_not_nans=False) # create UNet, DiceLoss and Adam optimizer @@ -195,7 +195,7 @@ def run_inference_test(root_dir, device="cuda:0"): val_ds = monai.data.Dataset(data=val_files, transform=val_transforms) # sliding window inference need to input 1 image in every iteration val_loader = monai.data.DataLoader(val_ds, batch_size=1, num_workers=4) - val_post_tran = Compose([ToTensor(), Activations(sigmoid=True), AsDiscrete(threshold_values=0.5)]) + val_post_tran = Compose([ToTensor(), Activations(sigmoid=True), AsDiscrete(threshold=0.5)]) dice_metric = DiceMetric(include_background=True, reduction="mean", get_not_nans=False) model = UNet( diff --git a/tests/test_integration_workflows.py b/tests/test_integration_workflows.py index b3d80c8b84..f188ab626f 100644 --- a/tests/test_integration_workflows.py +++ b/tests/test_integration_workflows.py @@ -114,7 +114,7 @@ def run_training_test(root_dir, device="cuda:0", amp=False, num_workers=4): [ ToTensord(keys=["pred", "label"]), Activationsd(keys="pred", sigmoid=True), - AsDiscreted(keys="pred", threshold_values=0.5), + AsDiscreted(keys="pred", threshold=0.5), KeepLargestConnectedComponentd(keys="pred", applied_labels=[1]), ] ) @@ -155,7 +155,7 @@ def _forward_completed(self, engine): [ ToTensord(keys=["pred", "label"]), Activationsd(keys="pred", sigmoid=True), - AsDiscreted(keys="pred", threshold_values=0.5), + AsDiscreted(keys="pred", threshold=0.5), KeepLargestConnectedComponentd(keys="pred", applied_labels=[1]), ] ) @@ -242,7 +242,7 @@ def run_inference_test(root_dir, model_file, device="cuda:0", amp=False, num_wor [ ToTensord(keys=["pred", "label"]), Activationsd(keys="pred", sigmoid=True), - AsDiscreted(keys="pred", threshold_values=0.5), + AsDiscreted(keys="pred", threshold=0.5), KeepLargestConnectedComponentd(keys="pred", applied_labels=[1]), # test the case that `pred` in `engine.state.output`, while `image_meta_dict` in `engine.state.batch` SaveImaged(keys="pred", meta_keys="image_meta_dict", output_dir=root_dir, output_postfix="seg_transform"), diff --git a/tests/test_testtimeaugmentation.py b/tests/test_testtimeaugmentation.py index 568571b380..09a7f1c2ed 100644 --- a/tests/test_testtimeaugmentation.py +++ b/tests/test_testtimeaugmentation.py @@ -122,7 +122,7 @@ def test_test_time_augmentation(self): epoch_loss /= len(train_loader) - post_trans = Compose([Activations(sigmoid=True), AsDiscrete(threshold_values=0.5)]) + post_trans = Compose([Activations(sigmoid=True), AsDiscrete(threshold=0.5)]) def inferrer_fn(x): return post_trans(model(x)) From e348f84165651f995a8734bd480fe8b883c71931 Mon Sep 17 00:00:00 2001 From: Nic Ma Date: Thu, 18 Nov 2021 16:09:47 +0800 Subject: [PATCH 4/5] [DLMED] update according to comments Signed-off-by: Nic Ma --- monai/transforms/post/array.py | 44 +++++++++++++++++++++++++++-- monai/transforms/post/dictionary.py | 14 ++++++++- 2 files changed, 55 insertions(+), 3 deletions(-) diff --git a/monai/transforms/post/array.py b/monai/transforms/post/array.py index 71cd349fba..683735940b 100644 --- a/monai/transforms/post/array.py +++ b/monai/transforms/post/array.py @@ -129,10 +129,21 @@ class AsDiscrete(Transform): rounding: if not None, round the data according to the specified option, available options: ["torchrounding"]. + .. deprecated:: 0.6.0 + ``n_classes`` is deprecated, use ``to_onehot`` instead. + + .. deprecated:: 0.7.0 + ``num_classes`` is deprecated, use ``to_onehot`` instead. + ``logit_thresh`` is deprecated, use ``threshold`` instead. + ``threshold_values`` is deprecated, use ``threshold`` instead. + """ backend = [TransformBackends.TORCH] + @deprecated_arg("n_classes", since="0.6") + @deprecated_arg("num_classes", since="0.7") + @deprecated_arg("logit_thresh", since="0.7") @deprecated_arg(name="threshold_values", new_name="threshold", since="0.7") def __init__( self, @@ -140,14 +151,26 @@ def __init__( to_onehot: Optional[int] = None, threshold: Optional[float] = None, rounding: Optional[str] = None, + n_classes: Optional[int] = None, + num_classes: Optional[int] = None, + logit_thresh: float = 0.5, + threshold_values: bool = False, ) -> None: self.argmax = argmax + if isinstance(to_onehot, bool): + raise ValueError("`to_onehot=True/False` is deprecated, please use `to_onehot=num_classes` instead.") self.to_onehot = to_onehot - if threshold is True: - raise ValueError("`threshold_values=True` is deprecated, please use `threashold=value` instead.") + + if isinstance(threshold, bool): + raise ValueError("`threshold_values=True/False` is deprecated, please use `threashold=value` instead.") self.threshold = threshold + self.rounding = rounding + @deprecated_arg("n_classes", since="0.6") + @deprecated_arg("num_classes", since="0.7") + @deprecated_arg("logit_thresh", since="0.7") + @deprecated_arg(name="threshold_values", new_name="threshold", since="0.7") def __call__( self, img: NdarrayOrTensor, @@ -155,6 +178,10 @@ def __call__( to_onehot: Optional[int] = None, threshold: Optional[float] = None, rounding: Optional[str] = None, + n_classes: Optional[int] = None, + num_classes: Optional[int] = None, + logit_thresh: Optional[float] = None, + threshold_values: Optional[bool] = None, ) -> NdarrayOrTensor: """ Args: @@ -169,7 +196,20 @@ def __call__( rounding: if not None, round the data according to the specified option, available options: ["torchrounding"]. + .. deprecated:: 0.6.0 + ``n_classes`` is deprecated, use ``to_onehot`` instead. + + .. deprecated:: 0.7.0 + ``num_classes`` is deprecated, use ``to_onehot`` instead. + ``logit_thresh`` is deprecated, use ``threshold`` instead. + ``threshold_values`` is deprecated, use ``threshold`` instead. + """ + if isinstance(to_onehot, bool): + raise ValueError("`to_onehot=True/False` is deprecated, please use `to_onehot=num_classes` instead.") + if isinstance(threshold, bool): + raise ValueError("`threshold_values=True/False` is deprecated, please use `threashold=value` instead.") + img_t: torch.Tensor img_t, *_ = convert_data_type(img, torch.Tensor) # type: ignore if argmax or self.argmax: diff --git a/monai/transforms/post/dictionary.py b/monai/transforms/post/dictionary.py index e7119dc2e5..8f97114a69 100644 --- a/monai/transforms/post/dictionary.py +++ b/monai/transforms/post/dictionary.py @@ -128,6 +128,9 @@ class AsDiscreted(MapTransform): backend = AsDiscrete.backend + @deprecated_arg("n_classes", since="0.6") + @deprecated_arg("num_classes", since="0.7") + @deprecated_arg("logit_thresh", since="0.7") @deprecated_arg(name="threshold_values", new_name="threshold", since="0.7") def __init__( self, @@ -137,6 +140,10 @@ def __init__( threshold: Union[Sequence[Optional[float]], Optional[float]] = None, rounding: Union[Sequence[Optional[str]], Optional[str]] = None, allow_missing_keys: bool = False, + n_classes: Optional[Union[Sequence[int], int]] = None, + num_classes: Optional[Union[Sequence[int], int]] = None, + logit_thresh: Union[Sequence[float], float] = 0.5, + threshold_values: Union[Sequence[bool], bool] = False, ) -> None: """ Args: @@ -154,7 +161,12 @@ def __init__( allow_missing_keys: don't raise exception if key is missing. .. deprecated:: 0.6.0 - ``n_classes`` is deprecated, use ``num_classes`` instead. + ``n_classes`` is deprecated, use ``to_onehot`` instead. + + .. deprecated:: 0.7.0 + ``num_classes`` is deprecated, use ``to_onehot`` instead. + ``logit_thresh`` is deprecated, use ``threshold`` instead. + ``threshold_values`` is deprecated, use ``threshold`` instead. """ super().__init__(keys, allow_missing_keys) From e8acfe8e4aec2a55b56efcb25d4929dd2dac41b9 Mon Sep 17 00:00:00 2001 From: Nic Ma Date: Thu, 18 Nov 2021 22:02:54 +0800 Subject: [PATCH 5/5] [DLMED] add examples Signed-off-by: Nic Ma --- monai/transforms/post/array.py | 14 ++++++++++++++ 1 file changed, 14 insertions(+) diff --git a/monai/transforms/post/array.py b/monai/transforms/post/array.py index 683735940b..0f9133037a 100644 --- a/monai/transforms/post/array.py +++ b/monai/transforms/post/array.py @@ -129,6 +129,20 @@ class AsDiscrete(Transform): rounding: if not None, round the data according to the specified option, available options: ["torchrounding"]. + Example: + + >>> transform = AsDiscrete(argmax=True) + >>> print(transform(np.array([[[0.0, 1.0]], [[2.0, 3.0]]]))) + # [[[1.0, 1.0]]] + + >>> transform = AsDiscrete(threshold=0.6) + >>> print(transform(np.array([[[0.0, 0.5], [0.8, 3.0]]]))) + # [[[0.0, 0.0], [1.0, 1.0]]] + + >>> transform = AsDiscrete(argmax=True, to_onehot=2, threshold=0.5) + >>> print(transform(np.array([[[0.0, 1.0]], [[2.0, 3.0]]]))) + # [[[0.0, 0.0]], [[1.0, 1.0]]] + .. deprecated:: 0.6.0 ``n_classes`` is deprecated, use ``to_onehot`` instead.