From 4a6a690725907e0f08d054d8323ad563ea8b9bee Mon Sep 17 00:00:00 2001 From: milesial Date: Wed, 22 Feb 2023 03:30:07 +0100 Subject: [PATCH 1/3] Add weight=None option for MixedOp Signed-off-by: Alexandre Milesi --- monai/networks/nets/dints.py | 23 ++++++++++++----------- 1 file changed, 12 insertions(+), 11 deletions(-) diff --git a/monai/networks/nets/dints.py b/monai/networks/nets/dints.py index eea9f351d8..814814ab83 100644 --- a/monai/networks/nets/dints.py +++ b/monai/networks/nets/dints.py @@ -13,12 +13,12 @@ import datetime import warnings +from typing import Optional import numpy as np import torch import torch.nn as nn import torch.nn.functional as F - from monai.networks.blocks.dints_block import ( ActiConvNormBlock, FactorizedIncreaseBlock, @@ -40,7 +40,7 @@ class CellInterface(torch.nn.Module): """interface for torchscriptable Cell""" - def forward(self, x: torch.Tensor, weight: torch.Tensor) -> torch.Tensor: # type: ignore + def forward(self, x: torch.Tensor, weight: Optional[torch.Tensor]) -> torch.Tensor: # type: ignore pass @@ -170,7 +170,7 @@ def __init__(self, c: int, ops: dict, arch_code_c=None): if arch_c > 0: self.ops.append(ops[op_name](c)) - def forward(self, x: torch.Tensor, weight: torch.Tensor): + def forward(self, x: torch.Tensor, weight: Optional[torch.Tensor] = None): """ Args: x: input tensor. @@ -179,9 +179,10 @@ def forward(self, x: torch.Tensor, weight: torch.Tensor): out: weighted average of the operation results. """ out = 0.0 - weight = weight.to(x) + if weight is not None: + weight = weight.to(x) for idx, _op in enumerate(self.ops): - out = out + _op(x) * weight[idx] + out = (out + _op(x)) if weight is None else out + _op(x) * weight[idx] return out @@ -297,7 +298,7 @@ def __init__( self.op = MixedOp(c, self.OPS, arch_code_c) - def forward(self, x: torch.Tensor, weight: torch.Tensor) -> torch.Tensor: + def forward(self, x: torch.Tensor, weight: Optional[torch.Tensor]) -> torch.Tensor: """ Args: x: input tensor @@ -669,14 +670,14 @@ def forward(self, x: list[torch.Tensor]) -> list[torch.Tensor]: x: input tensor. """ # generate path activation probability - inputs, outputs = x, [torch.tensor(0.0).to(x[0])] * self.num_depths + inputs = x for blk_idx in range(self.num_blocks): - outputs = [torch.tensor(0.0).to(x[0])] * self.num_depths + outputs = [torch.tensor(0.0, dtype=x[0].dtype, device=x[0].device)] * self.num_depths for res_idx, activation in enumerate(self.arch_code_a[blk_idx].data): if activation: mod: CellInterface = self.cell_tree[str((blk_idx, res_idx))] _out = mod.forward( - x=inputs[self.arch_code2in[res_idx]], weight=torch.ones_like(self.arch_code_c[blk_idx, res_idx]) + x=inputs[self.arch_code2in[res_idx]], weight=None ) outputs[self.arch_code2out[res_idx]] = outputs[self.arch_code2out[res_idx]] + _out inputs = outputs @@ -885,13 +886,13 @@ def get_ram_cost_usage(self, in_size, full: bool = False): sizes = [] for res_idx in range(self.num_depths): sizes.append(batch_size * self.filter_nums[res_idx] * (image_size // (2**res_idx)).prod()) - sizes = torch.tensor(sizes).to(torch.float32).to(self.device) / (2 ** (int(self.use_downsample))) + sizes = torch.tensor(sizes, dtype=torch.float32, device=self.device) / (2 ** (int(self.use_downsample))) probs_a, arch_code_prob_a = self.get_prob_a(child=False) cell_prob = F.softmax(self.log_alpha_c, dim=-1) if full: arch_code_prob_a = arch_code_prob_a.detach() arch_code_prob_a.fill_(1) - ram_cost = torch.from_numpy(self.ram_cost).to(torch.float32).to(self.device) + ram_cost = torch.from_numpy(self.ram_cost).to(dtype=torch.float32, device=self.device) usage = 0.0 for blk_idx in range(self.num_blocks): # node activation for input From e2841b3aa08b117c3df354416078b544e3c0b496 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Wed, 22 Feb 2023 22:03:10 +0000 Subject: [PATCH 2/3] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- monai/networks/nets/dints.py | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/monai/networks/nets/dints.py b/monai/networks/nets/dints.py index 814814ab83..459c436b90 100644 --- a/monai/networks/nets/dints.py +++ b/monai/networks/nets/dints.py @@ -13,7 +13,6 @@ import datetime import warnings -from typing import Optional import numpy as np import torch @@ -40,7 +39,7 @@ class CellInterface(torch.nn.Module): """interface for torchscriptable Cell""" - def forward(self, x: torch.Tensor, weight: Optional[torch.Tensor]) -> torch.Tensor: # type: ignore + def forward(self, x: torch.Tensor, weight: torch.Tensor | None) -> torch.Tensor: # type: ignore pass @@ -170,7 +169,7 @@ def __init__(self, c: int, ops: dict, arch_code_c=None): if arch_c > 0: self.ops.append(ops[op_name](c)) - def forward(self, x: torch.Tensor, weight: Optional[torch.Tensor] = None): + def forward(self, x: torch.Tensor, weight: torch.Tensor | None = None): """ Args: x: input tensor. @@ -298,7 +297,7 @@ def __init__( self.op = MixedOp(c, self.OPS, arch_code_c) - def forward(self, x: torch.Tensor, weight: Optional[torch.Tensor]) -> torch.Tensor: + def forward(self, x: torch.Tensor, weight: torch.Tensor | None) -> torch.Tensor: """ Args: x: input tensor From c1d80567086edea92ee4de6e95177bad6c79830f Mon Sep 17 00:00:00 2001 From: Wenqi Li Date: Wed, 22 Feb 2023 23:13:00 +0000 Subject: [PATCH 3/3] ignore type hints in jit Signed-off-by: Wenqi Li --- monai/networks/nets/dints.py | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/monai/networks/nets/dints.py b/monai/networks/nets/dints.py index 459c436b90..437789ef0c 100644 --- a/monai/networks/nets/dints.py +++ b/monai/networks/nets/dints.py @@ -18,6 +18,7 @@ import torch import torch.nn as nn import torch.nn.functional as F + from monai.networks.blocks.dints_block import ( ActiConvNormBlock, FactorizedIncreaseBlock, @@ -39,7 +40,7 @@ class CellInterface(torch.nn.Module): """interface for torchscriptable Cell""" - def forward(self, x: torch.Tensor, weight: torch.Tensor | None) -> torch.Tensor: # type: ignore + def forward(self, x: torch.Tensor, weight) -> torch.Tensor: # type: ignore pass @@ -675,9 +676,7 @@ def forward(self, x: list[torch.Tensor]) -> list[torch.Tensor]: for res_idx, activation in enumerate(self.arch_code_a[blk_idx].data): if activation: mod: CellInterface = self.cell_tree[str((blk_idx, res_idx))] - _out = mod.forward( - x=inputs[self.arch_code2in[res_idx]], weight=None - ) + _out = mod.forward(x=inputs[self.arch_code2in[res_idx]], weight=None) outputs[self.arch_code2out[res_idx]] = outputs[self.arch_code2out[res_idx]] + _out inputs = outputs