diff --git a/monai/data/test_time_augmentation.py b/monai/data/test_time_augmentation.py
index 1d81d91a16..28a172beac 100644
--- a/monai/data/test_time_augmentation.py
+++ b/monai/data/test_time_augmentation.py
@@ -87,7 +87,7 @@ class TestTimeAugmentation:
         .. code-block:: python

             transform = RandAffined(keys, ...)
-            post_trans = Compose([Activations(sigmoid=True), AsDiscrete(threshold_values=True)])
+            post_trans = Compose([Activations(sigmoid=True), AsDiscrete(threshold=0.5)])

             tt_aug = TestTimeAugmentation(
                 transform, batch_size=5, num_workers=0, inferrer_fn=lambda x: post_trans(model(x)), device=device
diff --git a/monai/transforms/post/array.py b/monai/transforms/post/array.py
index 7cbb6aad44..0f9133037a 100644
--- a/monai/transforms/post/array.py
+++ b/monai/transforms/post/array.py
@@ -122,56 +122,80 @@ class AsDiscrete(Transform):

     Args:
         argmax: whether to execute argmax function on input data before transform.
            Defaults to ``False``.
-        to_onehot: whether to convert input data into the one-hot format.
-            Defaults to ``False``.
-        num_classes: the number of classes to convert to One-Hot format.
+        to_onehot: if not None, convert input data into the one-hot format with the specified number of classes.
+            Defaults to ``None``.
+        threshold: if not None, threshold the float values to 0 or 1 using the specified threshold value.
             Defaults to ``None``.
-        threshold_values: whether threshold the float value to int number 0 or 1.
-            Defaults to ``False``.
-        logit_thresh: the threshold value for thresholding operation..
-            Defaults to ``0.5``.
         rounding: if not None, round the data according to the specified option,
             available options: ["torchrounding"].

+    Example:
+
+        >>> transform = AsDiscrete(argmax=True)
+        >>> print(transform(np.array([[[0.0, 1.0]], [[2.0, 3.0]]])))
+        # [[[1.0, 1.0]]]
+
+        >>> transform = AsDiscrete(threshold=0.6)
+        >>> print(transform(np.array([[[0.0, 0.5], [0.8, 3.0]]])))
+        # [[[0.0, 0.0], [1.0, 1.0]]]
+
+        >>> transform = AsDiscrete(argmax=True, to_onehot=2, threshold=0.5)
+        >>> print(transform(np.array([[[0.0, 1.0]], [[2.0, 3.0]]])))
+        # [[[0.0, 0.0]], [[1.0, 1.0]]]
+
     .. deprecated:: 0.6.0
-        ``n_classes`` is deprecated, use ``num_classes`` instead.
+        ``n_classes`` is deprecated, use ``to_onehot`` instead.
+
+    .. deprecated:: 0.7.0
+        ``num_classes`` is deprecated, use ``to_onehot`` instead.
+        ``logit_thresh`` is deprecated, use ``threshold`` instead.
+        ``threshold_values`` is deprecated, use ``threshold`` instead.
""" backend = [TransformBackends.TORCH] @deprecated_arg("n_classes", since="0.6") + @deprecated_arg("num_classes", since="0.7") + @deprecated_arg("logit_thresh", since="0.7") + @deprecated_arg(name="threshold_values", new_name="threshold", since="0.7") def __init__( self, argmax: bool = False, - to_onehot: bool = False, - num_classes: Optional[int] = None, - threshold_values: bool = False, - logit_thresh: float = 0.5, + to_onehot: Optional[int] = None, + threshold: Optional[float] = None, rounding: Optional[str] = None, n_classes: Optional[int] = None, + num_classes: Optional[int] = None, + logit_thresh: float = 0.5, + threshold_values: bool = False, ) -> None: - # in case the new num_classes is default but you still call deprecated n_classes - if n_classes is not None and num_classes is None: - num_classes = n_classes self.argmax = argmax + if isinstance(to_onehot, bool): + raise ValueError("`to_onehot=True/False` is deprecated, please use `to_onehot=num_classes` instead.") self.to_onehot = to_onehot - self.num_classes = num_classes - self.threshold_values = threshold_values - self.logit_thresh = logit_thresh + + if isinstance(threshold, bool): + raise ValueError("`threshold_values=True/False` is deprecated, please use `threashold=value` instead.") + self.threshold = threshold + self.rounding = rounding @deprecated_arg("n_classes", since="0.6") + @deprecated_arg("num_classes", since="0.7") + @deprecated_arg("logit_thresh", since="0.7") + @deprecated_arg(name="threshold_values", new_name="threshold", since="0.7") def __call__( self, img: NdarrayOrTensor, argmax: Optional[bool] = None, - to_onehot: Optional[bool] = None, - num_classes: Optional[int] = None, - threshold_values: Optional[bool] = None, - logit_thresh: Optional[float] = None, + to_onehot: Optional[int] = None, + threshold: Optional[float] = None, rounding: Optional[str] = None, n_classes: Optional[int] = None, + num_classes: Optional[int] = None, + logit_thresh: Optional[float] = None, + threshold_values: Optional[bool] = None, ) -> NdarrayOrTensor: """ Args: @@ -179,37 +203,41 @@ def __call__( will automatically add it. argmax: whether to execute argmax function on input data before transform. Defaults to ``self.argmax``. - to_onehot: whether to convert input data into the one-hot format. + to_onehot: if not None, convert input data into the one-hot format with specified number of classes. Defaults to ``self.to_onehot``. - num_classes: the number of classes to convert to One-Hot format. - Defaults to ``self.num_classes``. - threshold_values: whether threshold the float value to int number 0 or 1. - Defaults to ``self.threshold_values``. - logit_thresh: the threshold value for thresholding operation.. - Defaults to ``self.logit_thresh``. + threshold: if not None, threshold the float values to int number 0 or 1 with specified theashold value. + Defaults to ``self.threshold``. rounding: if not None, round the data according to the specified option, available options: ["torchrounding"]. .. deprecated:: 0.6.0 - ``n_classes`` is deprecated, use ``num_classes`` instead. + ``n_classes`` is deprecated, use ``to_onehot`` instead. + + .. deprecated:: 0.7.0 + ``num_classes`` is deprecated, use ``to_onehot`` instead. + ``logit_thresh`` is deprecated, use ``threshold`` instead. + ``threshold_values`` is deprecated, use ``threshold`` instead. 
""" - # in case the new num_classes is default but you still call deprecated n_classes - if n_classes is not None and num_classes is None: - num_classes = n_classes + if isinstance(to_onehot, bool): + raise ValueError("`to_onehot=True/False` is deprecated, please use `to_onehot=num_classes` instead.") + if isinstance(threshold, bool): + raise ValueError("`threshold_values=True/False` is deprecated, please use `threashold=value` instead.") + img_t: torch.Tensor img_t, *_ = convert_data_type(img, torch.Tensor) # type: ignore if argmax or self.argmax: img_t = torch.argmax(img_t, dim=0, keepdim=True) - if to_onehot or self.to_onehot: - _nclasses = self.num_classes if num_classes is None else num_classes - if not isinstance(_nclasses, int): - raise AssertionError("One of self.num_classes or num_classes must be an integer") - img_t = one_hot(img_t, num_classes=_nclasses, dim=0) + to_onehot = to_onehot or self.to_onehot + if to_onehot is not None: + if not isinstance(to_onehot, int): + raise AssertionError("the number of classes for One-Hot must be an integer.") + img_t = one_hot(img_t, num_classes=to_onehot, dim=0) - if threshold_values or self.threshold_values: - img_t = img_t >= (self.logit_thresh if logit_thresh is None else logit_thresh) + threshold = threshold or self.threshold + if threshold is not None: + img_t = img_t >= threshold rounding = self.rounding if rounding is None else rounding if rounding is not None: diff --git a/monai/transforms/post/dictionary.py b/monai/transforms/post/dictionary.py index 226f5953e2..8f97114a69 100644 --- a/monai/transforms/post/dictionary.py +++ b/monai/transforms/post/dictionary.py @@ -129,17 +129,21 @@ class AsDiscreted(MapTransform): backend = AsDiscrete.backend @deprecated_arg("n_classes", since="0.6") + @deprecated_arg("num_classes", since="0.7") + @deprecated_arg("logit_thresh", since="0.7") + @deprecated_arg(name="threshold_values", new_name="threshold", since="0.7") def __init__( self, keys: KeysCollection, argmax: Union[Sequence[bool], bool] = False, - to_onehot: Union[Sequence[bool], bool] = False, - num_classes: Optional[Union[Sequence[int], int]] = None, - threshold_values: Union[Sequence[bool], bool] = False, - logit_thresh: Union[Sequence[float], float] = 0.5, + to_onehot: Union[Sequence[Optional[int]], Optional[int]] = None, + threshold: Union[Sequence[Optional[float]], Optional[float]] = None, rounding: Union[Sequence[Optional[str]], Optional[str]] = None, allow_missing_keys: bool = False, - n_classes: Optional[int] = None, + n_classes: Optional[Union[Sequence[int], int]] = None, + num_classes: Optional[Union[Sequence[int], int]] = None, + logit_thresh: Union[Sequence[float], float] = 0.5, + threshold_values: Union[Sequence[bool], bool] = False, ) -> None: """ Args: @@ -147,41 +151,37 @@ def __init__( See also: :py:class:`monai.transforms.compose.MapTransform` argmax: whether to execute argmax function on input data before transform. it also can be a sequence of bool, each element corresponds to a key in ``keys``. - to_onehot: whether to convert input data into the one-hot format. Defaults to False. - it also can be a sequence of bool, each element corresponds to a key in ``keys``. - num_classes: the number of classes to convert to One-Hot format. it also can be a - sequence of int, each element corresponds to a key in ``keys``. - threshold_values: whether threshold the float value to int number 0 or 1, default is False. - it also can be a sequence of bool, each element corresponds to a key in ``keys``. 
-            logit_thresh: the threshold value for thresholding operation, default is 0.5.
-                it also can be a sequence of float, each element corresponds to a key in ``keys``.
+            to_onehot: if not None, convert input data into the one-hot format with the specified number of classes.
+                Defaults to ``None``. it also can be a sequence, each element corresponds to a key in ``keys``.
+            threshold: if not None, threshold the float values to 0 or 1 using the specified threshold value.
+                Defaults to ``None``. it also can be a sequence, each element corresponds to a key in ``keys``.
             rounding: if not None, round the data according to the specified option,
                 available options: ["torchrounding"]. it also can be a sequence of str or None,
                 each element corresponds to a key in ``keys``.
             allow_missing_keys: don't raise exception if key is missing.

             .. deprecated:: 0.6.0
-                ``n_classes`` is deprecated, use ``num_classes`` instead.
+                ``n_classes`` is deprecated, use ``to_onehot`` instead.
+
+            .. deprecated:: 0.7.0
+                ``num_classes`` is deprecated, use ``to_onehot`` instead.
+                ``logit_thresh`` is deprecated, use ``threshold`` instead.
+                ``threshold_values`` is deprecated, use ``threshold`` instead.

         """
-        # in case the new num_classes is default but you still call deprecated n_classes
-        if n_classes is not None and num_classes is None:
-            num_classes = n_classes
         super().__init__(keys, allow_missing_keys)
         self.argmax = ensure_tuple_rep(argmax, len(self.keys))
         self.to_onehot = ensure_tuple_rep(to_onehot, len(self.keys))
-        self.num_classes = ensure_tuple_rep(num_classes, len(self.keys))
-        self.threshold_values = ensure_tuple_rep(threshold_values, len(self.keys))
-        self.logit_thresh = ensure_tuple_rep(logit_thresh, len(self.keys))
+        self.threshold = ensure_tuple_rep(threshold, len(self.keys))
         self.rounding = ensure_tuple_rep(rounding, len(self.keys))
         self.converter = AsDiscrete()

     def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Dict[Hashable, NdarrayOrTensor]:
         d = dict(data)
-        for key, argmax, to_onehot, num_classes, threshold_values, logit_thresh, rounding in self.key_iterator(
-            d, self.argmax, self.to_onehot, self.num_classes, self.threshold_values, self.logit_thresh, self.rounding
+        for key, argmax, to_onehot, threshold, rounding in self.key_iterator(
+            d, self.argmax, self.to_onehot, self.threshold, self.rounding
         ):
-            d[key] = self.converter(d[key], argmax, to_onehot, num_classes, threshold_values, logit_thresh, rounding)
+            d[key] = self.converter(d[key], argmax, to_onehot, threshold, rounding)
         return d
diff --git a/monai/transforms/utils_create_transform_ims.py b/monai/transforms/utils_create_transform_ims.py
index 84aaa348fe..59d359639b 100644
--- a/monai/transforms/utils_create_transform_ims.py
+++ b/monai/transforms/utils_create_transform_ims.py
@@ -640,15 +640,8 @@ def create_transform_im(
     create_transform_im(RandScaleCropd, dict(keys=keys, roi_scale=0.4), data)
     create_transform_im(CenterScaleCrop, dict(roi_scale=0.4), data)
     create_transform_im(CenterScaleCropd, dict(keys=keys, roi_scale=0.4), data)
-    create_transform_im(
-        AsDiscrete, dict(num_classes=2, threshold_values=True, logit_thresh=10), data, is_post=True, colorbar=True
-    )
-    create_transform_im(
-        AsDiscreted,
-        dict(keys=CommonKeys.LABEL, num_classes=2, threshold_values=True, logit_thresh=10),
-        data,
-        is_post=True,
-    )
+    create_transform_im(AsDiscrete, dict(to_onehot=2, threshold=10), data, is_post=True, colorbar=True)
+    create_transform_im(AsDiscreted, dict(keys=CommonKeys.LABEL, to_onehot=2, threshold=10), data, is_post=True)
     create_transform_im(LabelFilter, dict(applied_labels=(1, 2, 3, 4, 5, 6)), data, is_post=True)
     create_transform_im(
         LabelFilterd, dict(keys=CommonKeys.LABEL, applied_labels=(1, 2, 3, 4, 5, 6)), data, is_post=True
diff --git a/tests/test_as_discrete.py b/tests/test_as_discrete.py
index 75f6a38d3c..e8db2052f0 100644
--- a/tests/test_as_discrete.py
+++ b/tests/test_as_discrete.py
@@ -20,7 +20,7 @@
 for p in TEST_NDARRAYS:
     TEST_CASES.append(
         [
-            {"argmax": True, "to_onehot": False, "num_classes": None, "threshold_values": False, "logit_thresh": 0.5},
+            {"argmax": True, "to_onehot": None, "threshold": 0.5},
             p([[[0.0, 1.0]], [[2.0, 3.0]]]),
             p([[[1.0, 1.0]]]),
             (1, 1, 2),
@@ -29,7 +29,7 @@
     TEST_CASES.append(
         [
-            {"argmax": True, "to_onehot": True, "num_classes": 2, "threshold_values": False, "logit_thresh": 0.5},
+            {"argmax": True, "to_onehot": 2, "threshold": 0.5},
             p([[[0.0, 1.0]], [[2.0, 3.0]]]),
             p([[[0.0, 0.0]], [[1.0, 1.0]]]),
             (2, 1, 2),
@@ -38,14 +38,14 @@
     TEST_CASES.append(
         [
-            {"argmax": False, "to_onehot": False, "num_classes": None, "threshold_values": True, "logit_thresh": 0.6},
+            {"argmax": False, "to_onehot": None, "threshold": 0.6},
             p([[[0.0, 1.0], [2.0, 3.0]]]),
             p([[[0.0, 1.0], [1.0, 1.0]]]),
             (1, 2, 2),
         ]
     )

-    TEST_CASES.append([{"argmax": False, "to_onehot": True, "num_classes": 3}, p(1), p([0.0, 1.0, 0.0]), (3,)])
+    TEST_CASES.append([{"argmax": False, "to_onehot": 3}, p(1), p([0.0, 1.0, 0.0]), (3,)])

     TEST_CASES.append(
         [{"rounding": "torchrounding"}, p([[[0.123, 1.345], [2.567, 3.789]]]), p([[[0.0, 1.0], [3.0, 4.0]]]), (1, 2, 2)]
diff --git a/tests/test_as_discreted.py b/tests/test_as_discreted.py
index dc160d5e46..8532f84d33 100644
--- a/tests/test_as_discreted.py
+++ b/tests/test_as_discreted.py
@@ -20,14 +20,7 @@
 for p in TEST_NDARRAYS:
     TEST_CASES.append(
         [
-            {
-                "keys": ["pred", "label"],
-                "argmax": [True, False],
-                "to_onehot": True,
-                "num_classes": 2,
-                "threshold_values": False,
-                "logit_thresh": 0.5,
-            },
+            {"keys": ["pred", "label"], "argmax": [True, False], "to_onehot": 2, "threshold": 0.5},
             {"pred": p([[[0.0, 1.0]], [[2.0, 3.0]]]), "label": p([[[0, 1]]])},
             {"pred": p([[[0.0, 0.0]], [[1.0, 1.0]]]), "label": p([[[1.0, 0.0]], [[0.0, 1.0]]])},
             (2, 1, 2),
@@ -36,14 +29,7 @@
     TEST_CASES.append(
         [
-            {
-                "keys": ["pred", "label"],
-                "argmax": False,
-                "to_onehot": False,
-                "num_classes": None,
-                "threshold_values": [True, False],
-                "logit_thresh": 0.6,
-            },
+            {"keys": ["pred", "label"], "argmax": False, "to_onehot": None, "threshold": [0.6, None]},
             {"pred": p([[[0.0, 1.0], [2.0, 3.0]]]), "label": p([[[0, 1], [1, 1]]])},
             {"pred": p([[[0.0, 1.0], [1.0, 1.0]]]), "label": p([[[0.0, 1.0], [1.0, 1.0]]])},
             (1, 2, 2),
@@ -52,14 +38,7 @@
     TEST_CASES.append(
         [
-            {
-                "keys": ["pred"],
-                "argmax": True,
-                "to_onehot": True,
-                "num_classes": 2,
-                "threshold_values": False,
-                "logit_thresh": 0.5,
-            },
+            {"keys": ["pred"], "argmax": True, "to_onehot": 2, "threshold": 0.5},
             {"pred": p([[[0.0, 1.0]], [[2.0, 3.0]]])},
             {"pred": p([[[0.0, 0.0]], [[1.0, 1.0]]])},
             (2, 1, 2),
diff --git a/tests/test_compute_roc_auc.py b/tests/test_compute_roc_auc.py
index bfecb4ce5b..02e2f2b24f 100644
--- a/tests/test_compute_roc_auc.py
+++ b/tests/test_compute_roc_auc.py
@@ -23,7 +23,7 @@
     torch.tensor([[0.1, 0.9], [0.3, 1.4], [0.2, 0.1], [0.1, 0.5]]),
     torch.tensor([[0], [1], [0], [1]]),
     True,
-    True,
+    2,
     "macro",
     0.75,
 ]
@@ -32,20 +32,20 @@
     torch.tensor([[0.5], [0.5], [0.2], [8.3]]),
     torch.tensor([[0], [1], [0], [1]]),
     False,
-    False,
+    None,
     "macro",
     0.875,
 ]

-TEST_CASE_3 = [torch.tensor([[0.5], [0.5], [0.2], [8.3]]), torch.tensor([0, 1, 0, 1]), False, False, "macro", 0.875]
+TEST_CASE_3 = [torch.tensor([[0.5], [0.5], [0.2], [8.3]]), torch.tensor([0, 1, 0, 1]), False, None, "macro", 0.875]

-TEST_CASE_4 = [torch.tensor([0.5, 0.5, 0.2, 8.3]), torch.tensor([0, 1, 0, 1]), False, False, "macro", 0.875]
+TEST_CASE_4 = [torch.tensor([0.5, 0.5, 0.2, 8.3]), torch.tensor([0, 1, 0, 1]), False, None, "macro", 0.875]

 TEST_CASE_5 = [
     torch.tensor([[0.1, 0.9], [0.3, 1.4], [0.2, 0.1], [0.1, 0.5]]),
     torch.tensor([[0], [1], [0], [1]]),
     True,
-    True,
+    2,
     "none",
     [0.75, 0.75],
 ]
@@ -54,7 +54,7 @@
     torch.tensor([[0.1, 0.9], [0.3, 1.4], [0.2, 0.1], [0.1, 0.5], [0.1, 0.5]]),
     torch.tensor([[1, 0], [0, 1], [0, 0], [1, 1], [0, 1]]),
     True,
-    False,
+    None,
     "weighted",
     0.56667,
 ]
@@ -63,7 +63,7 @@
     torch.tensor([[0.1, 0.9], [0.3, 1.4], [0.2, 0.1], [0.1, 0.5], [0.1, 0.5]]),
     torch.tensor([[1, 0], [0, 1], [0, 0], [1, 1], [0, 1]]),
     True,
-    False,
+    None,
     "micro",
     0.62,
 ]
@@ -73,7 +73,7 @@ class TestComputeROCAUC(unittest.TestCase):
     @parameterized.expand([TEST_CASE_1, TEST_CASE_2, TEST_CASE_3, TEST_CASE_4, TEST_CASE_5, TEST_CASE_6, TEST_CASE_7])
     def test_value(self, y_pred, y, softmax, to_onehot, average, expected_value):
         y_pred_trans = Compose([ToTensor(), Activations(softmax=softmax)])
-        y_trans = Compose([ToTensor(), AsDiscrete(to_onehot=to_onehot, num_classes=2)])
+        y_trans = Compose([ToTensor(), AsDiscrete(to_onehot=to_onehot)])
         y_pred = torch.stack([y_pred_trans(i) for i in decollate_batch(y_pred)], dim=0)
         y = torch.stack([y_trans(i) for i in decollate_batch(y)], dim=0)
         result = compute_roc_auc(y_pred=y_pred, y=y, average=average)
@@ -82,7 +82,7 @@ def test_value(self, y_pred, y, softmax, to_onehot, average, expected_value):
     @parameterized.expand([TEST_CASE_1, TEST_CASE_2, TEST_CASE_3, TEST_CASE_4, TEST_CASE_5, TEST_CASE_6, TEST_CASE_7])
     def test_class_value(self, y_pred, y, softmax, to_onehot, average, expected_value):
         y_pred_trans = Compose([ToTensor(), Activations(softmax=softmax)])
-        y_trans = Compose([ToTensor(), AsDiscrete(to_onehot=to_onehot, num_classes=2)])
+        y_trans = Compose([ToTensor(), AsDiscrete(to_onehot=to_onehot)])
         y_pred = [y_pred_trans(i) for i in decollate_batch(y_pred)]
         y = [y_trans(i) for i in decollate_batch(y)]
         metric = ROCAUCMetric(average=average)
diff --git a/tests/test_handler_decollate_batch.py b/tests/test_handler_decollate_batch.py
index 8f0ffb2b5c..d2282a971f 100644
--- a/tests/test_handler_decollate_batch.py
+++ b/tests/test_handler_decollate_batch.py
@@ -32,7 +32,7 @@ def test_compute(self):
                 [
                     Activationsd(keys="pred", sigmoid=True),
                     CopyItemsd(keys="filename", times=1, names="filename_bak"),
-                    AsDiscreted(keys="pred", threshold_values=True, to_onehot=True, num_classes=2),
+                    AsDiscreted(keys="pred", threshold=0.5, to_onehot=2),
                 ]
             )
         ),
diff --git a/tests/test_handler_post_processing.py b/tests/test_handler_post_processing.py
index e9d57128cb..4b47ece063 100644
--- a/tests/test_handler_post_processing.py
+++ b/tests/test_handler_post_processing.py
@@ -26,7 +26,7 @@
         "transform": Compose(
             [
                 CopyItemsd(keys="filename", times=1, names="filename_bak"),
-                AsDiscreted(keys="pred", threshold_values=True, to_onehot=True, num_classes=2),
+                AsDiscreted(keys="pred", threshold=0.5, to_onehot=2),
             ]
         ),
         "event": "iteration_completed",
diff --git a/tests/test_handler_rocauc.py b/tests/test_handler_rocauc.py
index 5b80bc43eb..bd32922777 100644
--- a/tests/test_handler_rocauc.py
+++ b/tests/test_handler_rocauc.py
@@ -22,7 +22,7 @@ class TestHandlerROCAUC(unittest.TestCase):
     def test_compute(self):
         auc_metric = ROCAUC()
         act = Activations(softmax=True)
-        to_onehot = AsDiscrete(to_onehot=True, num_classes=2)
+        to_onehot = AsDiscrete(to_onehot=2)

         y_pred = [torch.Tensor([0.1, 0.9]), torch.Tensor([0.3, 1.4])]
         y = [torch.Tensor([0]), torch.Tensor([1])]
diff --git a/tests/test_handler_rocauc_dist.py b/tests/test_handler_rocauc_dist.py
index 8316d4c4b6..0905816868 100644
--- a/tests/test_handler_rocauc_dist.py
+++ b/tests/test_handler_rocauc_dist.py
@@ -26,7 +26,7 @@ class DistributedROCAUC(DistTestCase):
     def test_compute(self):
         auc_metric = ROCAUC()
         act = Activations(softmax=True)
-        to_onehot = AsDiscrete(to_onehot=True, num_classes=2)
+        to_onehot = AsDiscrete(to_onehot=2)

         device = f"cuda:{dist.get_rank()}" if torch.cuda.is_available() else "cpu"
         if dist.get_rank() == 0:
diff --git a/tests/test_integration_classification_2d.py b/tests/test_integration_classification_2d.py
index 7a94780f82..cafad9dcf0 100644
--- a/tests/test_integration_classification_2d.py
+++ b/tests/test_integration_classification_2d.py
@@ -80,7 +80,7 @@ def run_training_test(root_dir, train_x, train_y, val_x, val_y, device="cuda:0",
         [LoadImage(image_only=True), AddChannel(), Transpose(indices=[0, 2, 1]), ScaleIntensity(), ToTensor()]
     )
     y_pred_trans = Compose([ToTensor(), Activations(softmax=True)])
-    y_trans = Compose([ToTensor(), AsDiscrete(to_onehot=True, num_classes=len(np.unique(train_y)))])
+    y_trans = Compose([ToTensor(), AsDiscrete(to_onehot=len(np.unique(train_y)))])
     auc_metric = ROCAUCMetric()

     # create train, val data loaders
diff --git a/tests/test_integration_fast_train.py b/tests/test_integration_fast_train.py
index 9fd37a0897..b2706dbb47 100644
--- a/tests/test_integration_fast_train.py
+++ b/tests/test_integration_fast_train.py
@@ -170,8 +170,8 @@ def test_train_timing(self):
         optimizer = Novograd(model.parameters(), learning_rate * 10)
         scaler = torch.cuda.amp.GradScaler()

-        post_pred = Compose([EnsureType(), AsDiscrete(argmax=True, to_onehot=True, num_classes=2)])
-        post_label = Compose([EnsureType(), AsDiscrete(to_onehot=True, num_classes=2)])
+        post_pred = Compose([EnsureType(), AsDiscrete(argmax=True, to_onehot=2)])
+        post_label = Compose([EnsureType(), AsDiscrete(to_onehot=2)])

         dice_metric = DiceMetric(include_background=True, reduction="mean", get_not_nans=False)
diff --git a/tests/test_integration_segmentation_3d.py b/tests/test_integration_segmentation_3d.py
index 215a5b3f9a..8898bcdbf8 100644
--- a/tests/test_integration_segmentation_3d.py
+++ b/tests/test_integration_segmentation_3d.py
@@ -95,7 +95,7 @@ def run_training_test(root_dir, device="cuda:0", cachedataset=0, readers=(None,
     # create a validation data loader
     val_ds = monai.data.Dataset(data=val_files, transform=val_transforms)
     val_loader = monai.data.DataLoader(val_ds, batch_size=1, num_workers=4)
-    val_post_tran = Compose([ToTensor(), Activations(sigmoid=True), AsDiscrete(threshold_values=True)])
+    val_post_tran = Compose([ToTensor(), Activations(sigmoid=True), AsDiscrete(threshold=0.5)])
     dice_metric = DiceMetric(include_background=True, reduction="mean", get_not_nans=False)

     # create UNet, DiceLoss and Adam optimizer
@@ -195,7 +195,7 @@ def run_inference_test(root_dir, device="cuda:0"):
     val_ds = monai.data.Dataset(data=val_files, transform=val_transforms)
     # sliding window inference need to input 1 image in every iteration
     val_loader = monai.data.DataLoader(val_ds, batch_size=1, num_workers=4)
-    val_post_tran = Compose([ToTensor(), Activations(sigmoid=True), AsDiscrete(threshold_values=True)])
+    val_post_tran = Compose([ToTensor(), Activations(sigmoid=True), AsDiscrete(threshold=0.5)])
     dice_metric = DiceMetric(include_background=True, reduction="mean", get_not_nans=False)

     model = UNet(
diff --git a/tests/test_integration_workflows.py b/tests/test_integration_workflows.py
index 4d37d51e83..7018c53240 100644
--- a/tests/test_integration_workflows.py
+++ b/tests/test_integration_workflows.py
@@ -114,7 +114,7 @@ def run_training_test(root_dir, device="cuda:0", amp=False, num_workers=4):
             [
                 ToTensord(keys=["pred", "label"]),
                 Activationsd(keys="pred", sigmoid=True),
-                AsDiscreted(keys="pred", threshold_values=True),
+                AsDiscreted(keys="pred", threshold=0.5),
                 KeepLargestConnectedComponentd(keys="pred", applied_labels=[1]),
             ]
         )
@@ -155,7 +155,7 @@ def _forward_completed(self, engine):
            [
                 ToTensord(keys=["pred", "label"]),
                 Activationsd(keys="pred", sigmoid=True),
-                AsDiscreted(keys="pred", threshold_values=True),
+                AsDiscreted(keys="pred", threshold=0.5),
                 KeepLargestConnectedComponentd(keys="pred", applied_labels=[1]),
             ]
         )
@@ -242,7 +242,7 @@ def run_inference_test(root_dir, model_file, device="cuda:0", amp=False, num_wor
             [
                 ToTensord(keys=["pred", "label"]),
                 Activationsd(keys="pred", sigmoid=True),
-                AsDiscreted(keys="pred", threshold_values=True),
+                AsDiscreted(keys="pred", threshold=0.5),
                 KeepLargestConnectedComponentd(keys="pred", applied_labels=[1]),
                 # test the case that `pred` in `engine.state.output`, while `image_meta_dict` in `engine.state.batch`
                 SaveImaged(keys="pred", meta_keys="image_meta_dict", output_dir=root_dir, output_postfix="seg_transform"),
diff --git a/tests/test_testtimeaugmentation.py b/tests/test_testtimeaugmentation.py
index e7303ca524..09a7f1c2ed 100644
--- a/tests/test_testtimeaugmentation.py
+++ b/tests/test_testtimeaugmentation.py
@@ -122,7 +122,7 @@ def test_test_time_augmentation(self):
         epoch_loss /= len(train_loader)

-        post_trans = Compose([Activations(sigmoid=True), AsDiscrete(threshold_values=True)])
+        post_trans = Compose([Activations(sigmoid=True), AsDiscrete(threshold=0.5)])

         def inferrer_fn(x):
            return post_trans(model(x))
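Migration note: a minimal before/after sketch of the API change in this patch. The tensor values and variable names below are illustrative only, not taken from the test suite:

    import torch
    from monai.transforms import Activations, AsDiscrete, Compose

    # before 0.7 (now deprecated): boolean switches plus companion arguments
    #   AsDiscrete(threshold_values=True, logit_thresh=0.5)
    #   AsDiscrete(to_onehot=True, num_classes=2)

    # from 0.7: the argument value carries the configuration directly
    binarize = Compose([Activations(sigmoid=True), AsDiscrete(threshold=0.5)])
    onehot = AsDiscrete(argmax=True, to_onehot=2)

    logits = torch.tensor([[[0.2, 0.8]], [[0.9, 0.1]]])  # shape (2, 1, 2): two channels
    print(binarize(logits))  # sigmoid, then >= 0.5 -> zeros and ones
    print(onehot(logits))    # argmax over the channel dim, then 2-class one-hot

Because of the `isinstance(..., bool)` guards added above, passing `to_onehot=True/False` or a boolean `threshold` now raises a ValueError, so old call sites fail loudly instead of silently misbehaving.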