Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 2 additions & 3 deletions monai/apps/datasets.py
Original file line number Diff line number Diff line change
Expand Up @@ -291,10 +291,9 @@ def get_properties(self, keys: Optional[Union[Sequence[str], str]] = None):
"""
if keys is None:
return self._properties
elif self._properties is not None:
if self._properties is not None:
return {key: self._properties[key] for key in ensure_tuple(keys)}
else:
return {}
return {}

def _generate_data_list(self, dataset_dir: str) -> List[Dict]:
section = "training" if self.section in ["training", "validation"] else "test"
Expand Down
44 changes: 28 additions & 16 deletions monai/csrc/resample/pushpull_cpu.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -97,19 +97,25 @@ MONAI_NAMESPACE_DEVICE { // cpu
bool do_sgrad)
: dim(dim),
bound0(bound.size() > 0 ? bound[0] : BoundType::Replicate),
bound1(bound.size() > 1 ? bound[1] : bound.size() > 0 ? bound[0] : BoundType::Replicate),
bound1(
bound.size() > 1 ? bound[1]
: bound.size() > 0 ? bound[0]
: BoundType::Replicate),
bound2(
bound.size() > 2 ? bound[2]
: bound.size() > 1 ? bound[1] : bound.size() > 0 ? bound[0] : BoundType::Replicate),
bound.size() > 2 ? bound[2]
: bound.size() > 1 ? bound[1]
: bound.size() > 0 ? bound[0]
: BoundType::Replicate),
interpolation0(interpolation.size() > 0 ? interpolation[0] : InterpolationType::Linear),
interpolation1(
interpolation.size() > 1 ? interpolation[1]
: interpolation.size() > 0 ? interpolation[0] : InterpolationType::Linear),
interpolation.size() > 1 ? interpolation[1]
: interpolation.size() > 0 ? interpolation[0]
: InterpolationType::Linear),
interpolation2(
interpolation.size() > 2
? interpolation[2]
interpolation.size() > 2 ? interpolation[2]
: interpolation.size() > 1 ? interpolation[1]
: interpolation.size() > 0 ? interpolation[0] : InterpolationType::Linear),
: interpolation.size() > 0 ? interpolation[0]
: InterpolationType::Linear),
extrapolate(extrapolate),
do_pull(do_pull),
do_push(do_push),
Expand All @@ -136,13 +142,14 @@ MONAI_NAMESPACE_DEVICE { // cpu
bound2(bound),
interpolation0(interpolation.size() > 0 ? interpolation[0] : InterpolationType::Linear),
interpolation1(
interpolation.size() > 1 ? interpolation[1]
: interpolation.size() > 0 ? interpolation[0] : InterpolationType::Linear),
interpolation.size() > 1 ? interpolation[1]
: interpolation.size() > 0 ? interpolation[0]
: InterpolationType::Linear),
interpolation2(
interpolation.size() > 2
? interpolation[2]
interpolation.size() > 2 ? interpolation[2]
: interpolation.size() > 1 ? interpolation[1]
: interpolation.size() > 0 ? interpolation[0] : InterpolationType::Linear),
: interpolation.size() > 0 ? interpolation[0]
: InterpolationType::Linear),
extrapolate(extrapolate),
do_pull(do_pull),
do_push(do_push),
Expand All @@ -165,10 +172,15 @@ MONAI_NAMESPACE_DEVICE { // cpu
bool do_sgrad)
: dim(dim),
bound0(bound.size() > 0 ? bound[0] : BoundType::Replicate),
bound1(bound.size() > 1 ? bound[1] : bound.size() > 0 ? bound[0] : BoundType::Replicate),
bound1(
bound.size() > 1 ? bound[1]
: bound.size() > 0 ? bound[0]
: BoundType::Replicate),
bound2(
bound.size() > 2 ? bound[2]
: bound.size() > 1 ? bound[1] : bound.size() > 0 ? bound[0] : BoundType::Replicate),
bound.size() > 2 ? bound[2]
: bound.size() > 1 ? bound[1]
: bound.size() > 0 ? bound[0]
: BoundType::Replicate),
interpolation0(interpolation),
interpolation1(interpolation),
interpolation2(interpolation),
Expand Down
44 changes: 28 additions & 16 deletions monai/csrc/resample/pushpull_cuda.cu
Original file line number Diff line number Diff line change
Expand Up @@ -94,19 +94,25 @@ MONAI_NAMESPACE_DEVICE { // cuda
bool do_sgrad)
: dim(dim),
bound0(bound.size() > 0 ? bound[0] : BoundType::Replicate),
bound1(bound.size() > 1 ? bound[1] : bound.size() > 0 ? bound[0] : BoundType::Replicate),
bound1(
bound.size() > 1 ? bound[1]
: bound.size() > 0 ? bound[0]
: BoundType::Replicate),
bound2(
bound.size() > 2 ? bound[2]
: bound.size() > 1 ? bound[1] : bound.size() > 0 ? bound[0] : BoundType::Replicate),
bound.size() > 2 ? bound[2]
: bound.size() > 1 ? bound[1]
: bound.size() > 0 ? bound[0]
: BoundType::Replicate),
interpolation0(interpolation.size() > 0 ? interpolation[0] : InterpolationType::Linear),
interpolation1(
interpolation.size() > 1 ? interpolation[1]
: interpolation.size() > 0 ? interpolation[0] : InterpolationType::Linear),
interpolation.size() > 1 ? interpolation[1]
: interpolation.size() > 0 ? interpolation[0]
: InterpolationType::Linear),
interpolation2(
interpolation.size() > 2
? interpolation[2]
interpolation.size() > 2 ? interpolation[2]
: interpolation.size() > 1 ? interpolation[1]
: interpolation.size() > 0 ? interpolation[0] : InterpolationType::Linear),
: interpolation.size() > 0 ? interpolation[0]
: InterpolationType::Linear),
extrapolate(extrapolate),
do_pull(do_pull),
do_push(do_push),
Expand All @@ -133,13 +139,14 @@ MONAI_NAMESPACE_DEVICE { // cuda
bound2(bound),
interpolation0(interpolation.size() > 0 ? interpolation[0] : InterpolationType::Linear),
interpolation1(
interpolation.size() > 1 ? interpolation[1]
: interpolation.size() > 0 ? interpolation[0] : InterpolationType::Linear),
interpolation.size() > 1 ? interpolation[1]
: interpolation.size() > 0 ? interpolation[0]
: InterpolationType::Linear),
interpolation2(
interpolation.size() > 2
? interpolation[2]
interpolation.size() > 2 ? interpolation[2]
: interpolation.size() > 1 ? interpolation[1]
: interpolation.size() > 0 ? interpolation[0] : InterpolationType::Linear),
: interpolation.size() > 0 ? interpolation[0]
: InterpolationType::Linear),
extrapolate(extrapolate),
do_pull(do_pull),
do_push(do_push),
Expand All @@ -162,10 +169,15 @@ MONAI_NAMESPACE_DEVICE { // cuda
bool do_sgrad)
: dim(dim),
bound0(bound.size() > 0 ? bound[0] : BoundType::Replicate),
bound1(bound.size() > 1 ? bound[1] : bound.size() > 0 ? bound[0] : BoundType::Replicate),
bound1(
bound.size() > 1 ? bound[1]
: bound.size() > 0 ? bound[0]
: BoundType::Replicate),
bound2(
bound.size() > 2 ? bound[2]
: bound.size() > 1 ? bound[1] : bound.size() > 0 ? bound[0] : BoundType::Replicate),
bound.size() > 2 ? bound[2]
: bound.size() > 1 ? bound[1]
: bound.size() > 0 ? bound[0]
: BoundType::Replicate),
interpolation0(interpolation),
interpolation1(interpolation),
interpolation2(interpolation),
Expand Down
3 changes: 1 addition & 2 deletions monai/data/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -699,8 +699,7 @@ def _try_shutdown(self):
self._round = 0
self._replace_done = False
return True
else:
return False
return False

def shutdown(self):
"""
Expand Down
5 changes: 2 additions & 3 deletions monai/data/decathlon_datalist.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,13 +39,12 @@ def _compute_path(base_dir, element):
"""
if isinstance(element, str):
return os.path.normpath(os.path.join(base_dir, element))
elif isinstance(element, list):
if isinstance(element, list):
for e in element:
if not isinstance(e, str):
raise TypeError(f"Every file path in element must be a str but got {type(element).__name__}.")
return [os.path.normpath(os.path.join(base_dir, e)) for e in element]
else:
raise TypeError(f"element must be one of (str, list) but is {type(element).__name__}.")
raise TypeError(f"element must be one of (str, list) but is {type(element).__name__}.")


def _append_paths(base_dir: str, is_segmentation: bool, items: List[Dict]) -> List[Dict]:
Expand Down
5 changes: 2 additions & 3 deletions monai/engines/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,10 +95,9 @@ def default_prepare_batch(
batchdata[CommonKeys.IMAGE].to(device=device, non_blocking=non_blocking),
batchdata[CommonKeys.LABEL].to(device=device, non_blocking=non_blocking),
)
elif GanKeys.REALS in batchdata:
if GanKeys.REALS in batchdata:
return batchdata[GanKeys.REALS].to(device=device, non_blocking=non_blocking)
else:
return batchdata[CommonKeys.IMAGE].to(device=device, non_blocking=non_blocking), None
return batchdata[CommonKeys.IMAGE].to(device=device, non_blocking=non_blocking), None


def default_make_latent(
Expand Down
5 changes: 2 additions & 3 deletions monai/handlers/confusion_matrix.py
Original file line number Diff line number Diff line change
Expand Up @@ -122,6 +122,5 @@ def compute(self):
"ConfusionMatrix metric must have at least one example before it can be computed."
)
return self._sum / self._num_examples
else:
confusion_matrix = torch.tensor([self._total_tp, self._total_fp, self._total_tn, self._total_fn])
return compute_confusion_matrix_metric(self.metric_name, confusion_matrix)
confusion_matrix = torch.tensor([self._total_tp, self._total_fp, self._total_tn, self._total_fn])
return compute_confusion_matrix_metric(self.metric_name, confusion_matrix)
57 changes: 27 additions & 30 deletions monai/metrics/confusion_matrix.py
Original file line number Diff line number Diff line change
Expand Up @@ -103,16 +103,15 @@ def __call__(self, y_pred: torch.Tensor, y: torch.Tensor):
confusion_matrix = compute_confusion_matrix_metric(self.metric_name, confusion_matrix)
f, not_nans = do_metric_reduction(confusion_matrix, self.reduction)
return f, not_nans
else:
if len(self.metric_name) < 1:
raise ValueError("the sequence should at least has on metric name.")
results = []
for metric_name in self.metric_name:
sub_confusion_matrix = compute_confusion_matrix_metric(metric_name, confusion_matrix)
f, not_nans = do_metric_reduction(sub_confusion_matrix, self.reduction)
results.append(f)
results.append(not_nans)
return results
if len(self.metric_name) < 1:
raise ValueError("the sequence should at least has on metric name.")
results = []
for metric_name in self.metric_name:
sub_confusion_matrix = compute_confusion_matrix_metric(metric_name, confusion_matrix)
f, not_nans = do_metric_reduction(sub_confusion_matrix, self.reduction)
results.append(f)
results.append(not_nans)
return results
else:
return confusion_matrix

Expand Down Expand Up @@ -264,8 +263,7 @@ def compute_confusion_matrix_metric(metric_name: str, confusion_matrix: torch.Te

if isinstance(denominator, torch.Tensor):
return torch.where(denominator != 0, numerator / denominator, nan_tensor)
else:
return numerator / denominator
return numerator / denominator


def check_confusion_matrix_metric_name(metric_name: str):
Expand All @@ -284,37 +282,36 @@ def check_confusion_matrix_metric_name(metric_name: str):
metric_name = metric_name.lower()
if metric_name in ["sensitivity", "recall", "hit_rate", "true_positive_rate", "tpr"]:
return "tpr"
elif metric_name in ["specificity", "selectivity", "true_negative_rate", "tnr"]:
if metric_name in ["specificity", "selectivity", "true_negative_rate", "tnr"]:
return "tnr"
elif metric_name in ["precision", "positive_predictive_value", "ppv"]:
if metric_name in ["precision", "positive_predictive_value", "ppv"]:
return "ppv"
elif metric_name in ["negative_predictive_value", "npv"]:
if metric_name in ["negative_predictive_value", "npv"]:
return "npv"
elif metric_name in ["miss_rate", "false_negative_rate", "fnr"]:
if metric_name in ["miss_rate", "false_negative_rate", "fnr"]:
return "fnr"
elif metric_name in ["fall_out", "false_positive_rate", "fpr"]:
if metric_name in ["fall_out", "false_positive_rate", "fpr"]:
return "fpr"
elif metric_name in ["false_discovery_rate", "fdr"]:
if metric_name in ["false_discovery_rate", "fdr"]:
return "fdr"
elif metric_name in ["false_omission_rate", "for"]:
if metric_name in ["false_omission_rate", "for"]:
return "for"
elif metric_name in ["prevalence_threshold", "pt"]:
if metric_name in ["prevalence_threshold", "pt"]:
return "pt"
elif metric_name in ["threat_score", "critical_success_index", "ts", "csi"]:
if metric_name in ["threat_score", "critical_success_index", "ts", "csi"]:
return "ts"
elif metric_name in ["accuracy", "acc"]:
if metric_name in ["accuracy", "acc"]:
return "acc"
elif metric_name in ["balanced_accuracy", "ba"]:
if metric_name in ["balanced_accuracy", "ba"]:
return "ba"
elif metric_name in ["f1_score", "f1"]:
if metric_name in ["f1_score", "f1"]:
return "f1"
elif metric_name in ["matthews_correlation_coefficient", "mcc"]:
if metric_name in ["matthews_correlation_coefficient", "mcc"]:
return "mcc"
elif metric_name in ["fowlkes_mallows_index", "fm"]:
if metric_name in ["fowlkes_mallows_index", "fm"]:
return "fm"
elif metric_name in ["informedness", "bookmaker_informedness", "bm"]:
if metric_name in ["informedness", "bookmaker_informedness", "bm"]:
return "bm"
elif metric_name in ["markedness", "deltap", "mk"]:
if metric_name in ["markedness", "deltap", "mk"]:
return "mk"
else:
raise NotImplementedError("the metric is not implemented.")
raise NotImplementedError("the metric is not implemented.")
5 changes: 2 additions & 3 deletions monai/metrics/hausdorff_distance.py
Original file line number Diff line number Diff line change
Expand Up @@ -166,7 +166,6 @@ def compute_percent_hausdorff_distance(
if not percentile:
return surface_distance.max()

elif 0 <= percentile <= 100:
if 0 <= percentile <= 100:
return np.percentile(surface_distance, percentile)
else:
raise ValueError(f"percentile should be a value between 0 and 100, get {percentile}.")
raise ValueError(f"percentile should be a value between 0 and 100, get {percentile}.")
57 changes: 27 additions & 30 deletions monai/metrics/rocauc.py
Original file line number Diff line number Diff line change
Expand Up @@ -114,33 +114,30 @@ def compute_roc_auc(
if softmax:
warnings.warn("y_pred has only one channel, softmax=True ignored.")
return _calculate(y, y_pred)
else:
n_classes = y_pred.shape[1]
if to_onehot_y:
y = one_hot(y, n_classes)
if softmax and other_act is not None:
raise ValueError("Incompatible values: softmax=True and other_act is not None.")
if softmax:
y_pred = y_pred.float().softmax(dim=1)
if other_act is not None:
if not callable(other_act):
raise TypeError(f"other_act must be None or callable but is {type(other_act).__name__}.")
y_pred = other_act(y_pred)

assert y.shape == y_pred.shape, "data shapes of y_pred and y do not match."

average = Average(average)
if average == Average.MICRO:
return _calculate(y.flatten(), y_pred.flatten())
y, y_pred = y.transpose(0, 1), y_pred.transpose(0, 1)
auc_values = [_calculate(y_, y_pred_) for y_, y_pred_ in zip(y, y_pred)]
if average == Average.NONE:
return auc_values
if average == Average.MACRO:
return np.mean(auc_values)
if average == Average.WEIGHTED:
weights = [sum(y_) for y_ in y]
return np.average(auc_values, weights=weights)
raise ValueError(
f'Unsupported average: {average}, available options are ["macro", "weighted", "micro", "none"].'
)
n_classes = y_pred.shape[1]
if to_onehot_y:
y = one_hot(y, n_classes)
if softmax and other_act is not None:
raise ValueError("Incompatible values: softmax=True and other_act is not None.")
if softmax:
y_pred = y_pred.float().softmax(dim=1)
if other_act is not None:
if not callable(other_act):
raise TypeError(f"other_act must be None or callable but is {type(other_act).__name__}.")
y_pred = other_act(y_pred)

assert y.shape == y_pred.shape, "data shapes of y_pred and y do not match."

average = Average(average)
if average == Average.MICRO:
return _calculate(y.flatten(), y_pred.flatten())
y, y_pred = y.transpose(0, 1), y_pred.transpose(0, 1)
auc_values = [_calculate(y_, y_pred_) for y_, y_pred_ in zip(y, y_pred)]
if average == Average.NONE:
return auc_values
if average == Average.MACRO:
return np.mean(auc_values)
if average == Average.WEIGHTED:
weights = [sum(y_) for y_ in y]
return np.average(auc_values, weights=weights)
raise ValueError(f'Unsupported average: {average}, available options are ["macro", "weighted", "micro", "none"].')
11 changes: 5 additions & 6 deletions monai/networks/layers/factories.py
Original file line number Diff line number Diff line change
Expand Up @@ -178,14 +178,13 @@ def split_args(args):

if isinstance(args, str):
return args, {}
else:
name_obj, name_args = args
name_obj, name_args = args

if not isinstance(name_obj, (str, Callable)) or not isinstance(name_args, dict):
msg = "Layer specifiers must be single strings or pairs of the form (name/object-types, argument dict)"
raise TypeError(msg)
if not isinstance(name_obj, (str, Callable)) or not isinstance(name_args, dict):
msg = "Layer specifiers must be single strings or pairs of the form (name/object-types, argument dict)"
raise TypeError(msg)

return name_obj, name_args
return name_obj, name_args


# Define factories for these layer types
Expand Down
Loading