Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion monai/metrics/confusion_matrix.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,7 +87,7 @@ def __call__(self, y_pred: torch.Tensor, y: torch.Tensor):
dims = y_pred.ndimension()
if dims < 2:
raise ValueError("y_pred should have at least two dimensions.")
elif dims == 2 or (dims == 3 and y_pred.shape[-1] == 1):
if dims == 2 or (dims == 3 and y_pred.shape[-1] == 1):
if self.compute_sample:
warnings.warn("As for classification task, compute_sample should be False.")
self.compute_sample = False
Expand Down
11 changes: 5 additions & 6 deletions monai/networks/blocks/dynunet_block.py
Original file line number Diff line number Diff line change
Expand Up @@ -228,13 +228,12 @@ def get_acti_layer(act: Union[Tuple[str, Dict], str]):
def get_norm_layer(spatial_dims: int, out_channels: int, norm_name: str, num_groups: int = 16):
if norm_name not in ["batch", "instance", "group"]:
raise ValueError(f"Unsupported normalization mode: {norm_name}")
if norm_name == "group":
assert out_channels % num_groups == 0, "out_channels should be divisible by num_groups."
norm = Norm[norm_name](num_groups=num_groups, num_channels=out_channels, affine=True)
else:
if norm_name == "group":
assert out_channels % num_groups == 0, "out_channels should be divisible by num_groups."
norm = Norm[norm_name](num_groups=num_groups, num_channels=out_channels, affine=True)
else:
norm = Norm[norm_name, spatial_dims](out_channels, affine=True)
return norm
norm = Norm[norm_name, spatial_dims](out_channels, affine=True)
return norm


def get_conv_layer(
Expand Down
17 changes: 8 additions & 9 deletions monai/networks/blocks/segresnet_block.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,16 +22,15 @@
def get_norm_layer(spatial_dims: int, in_channels: int, norm_name: str, num_groups: int = 8):
if norm_name not in ["batch", "instance", "group"]:
raise ValueError(f"Unsupported normalization mode: {norm_name}")
if norm_name == "group":
norm = Norm[norm_name](num_groups=num_groups, num_channels=in_channels)
else:
if norm_name == "group":
norm = Norm[norm_name](num_groups=num_groups, num_channels=in_channels)
else:
norm = Norm[norm_name, spatial_dims](in_channels)
if norm.bias is not None:
nn.init.zeros_(norm.bias)
if norm.weight is not None:
nn.init.ones_(norm.weight)
return norm
norm = Norm[norm_name, spatial_dims](in_channels)
if norm.bias is not None:
nn.init.zeros_(norm.bias)
if norm.weight is not None:
nn.init.ones_(norm.weight)
return norm


def get_conv_layer(
Expand Down
3 changes: 1 addition & 2 deletions monai/networks/layers/simplelayers.py
Original file line number Diff line number Diff line change
Expand Up @@ -233,8 +233,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
x = torch.as_tensor(x, device=x.device if torch.is_tensor(x) else None)
if torch.is_complex(x):
raise ValueError("x must be real.")
else:
x = x.to(dtype=torch.float)
x = x.to(dtype=torch.float)

if (self.axis < 0) or (self.axis > len(x.shape) - 1):
raise ValueError("Invalid axis for shape of x.")
Expand Down
2 changes: 1 addition & 1 deletion monai/transforms/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -451,7 +451,7 @@ def create_rotate(spatial_dims: int, radians: Union[Sequence[float], float]) ->
return np.array([[cos_, -sin_, 0.0], [sin_, cos_, 0.0], [0.0, 0.0, 1.0]])
raise ValueError("radians must be non empty.")

elif spatial_dims == 3:
if spatial_dims == 3:
affine = None
if len(radians) >= 1:
sin_, cos_ = np.sin(radians[0]), np.cos(radians[0])
Expand Down
3 changes: 1 addition & 2 deletions monai/utils/aliases.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,8 +91,7 @@ def resolve_name(name):
modnames = [m.__name__ for m in foundmods]
msg = f"Multiple modules ({modnames!r}) with declaration name {name!r} found, resolution is ambiguous."
raise ValueError(msg)
else:
mods = list(foundmods)
mods = list(foundmods)

obj = getattr(mods[0], name)

Expand Down
3 changes: 1 addition & 2 deletions tests/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -407,8 +407,7 @@ def _wrapper(*args, **kwargs):
if isinstance(res, Exception): # other errors from obj
if hasattr(res, "traceback"):
raise RuntimeError(res.traceback) from res
else:
raise res
raise res
if timeout_error: # no force_quit finished
raise timeout_error
return res
Expand Down