From 714cba4390e6b07097263bb1649c129699fc3ff0 Mon Sep 17 00:00:00 2001 From: dongy Date: Thu, 21 Apr 2022 10:50:03 -0700 Subject: [PATCH 1/3] fixed a bug Signed-off-by: dongy --- monai/networks/blocks/dints_block.py | 42 ++- monai/networks/nets/dints.py | 407 ++++++++++++++++++++++----- 2 files changed, 361 insertions(+), 88 deletions(-) diff --git a/monai/networks/blocks/dints_block.py b/monai/networks/blocks/dints_block.py index f76e125fe0..c40746b634 100644 --- a/monai/networks/blocks/dints_block.py +++ b/monai/networks/blocks/dints_block.py @@ -17,7 +17,12 @@ from monai.networks.layers.factories import Conv from monai.networks.layers.utils import get_act_layer, get_norm_layer -__all__ = ["FactorizedIncreaseBlock", "FactorizedReduceBlock", "P3DActiConvNormBlock", "ActiConvNormBlock"] +__all__ = [ + "FactorizedIncreaseBlock", + "FactorizedReduceBlock", + "P3DActiConvNormBlock", + "ActiConvNormBlock", +] class FactorizedIncreaseBlock(torch.nn.Sequential): @@ -31,7 +36,7 @@ def __init__( out_channel: int, spatial_dims: int = 3, act_name: Union[Tuple, str] = "RELU", - norm_name: Union[Tuple, str] = "INSTANCE", + norm_name: Union[Tuple, str] = ("INSTANCE", {"affine": True}), ): """ Args: @@ -50,7 +55,9 @@ def __init__( conv_type = Conv[Conv.CONV, self._spatial_dims] mode = "trilinear" if self._spatial_dims == 3 else "bilinear" - self.add_module("up", torch.nn.Upsample(scale_factor=2, mode=mode, align_corners=True)) + self.add_module( + "up", torch.nn.Upsample(scale_factor=2, mode=mode, align_corners=True) + ) self.add_module("acti", get_act_layer(name=act_name)) self.add_module( "conv", @@ -66,7 +73,12 @@ def __init__( ), ) self.add_module( - "norm", get_norm_layer(name=norm_name, spatial_dims=self._spatial_dims, channels=self._out_channel) + "norm", + get_norm_layer( + name=norm_name, + spatial_dims=self._spatial_dims, + channels=self._out_channel, + ), ) @@ -82,7 +94,7 @@ def __init__( out_channel: int, spatial_dims: int = 3, act_name: Union[Tuple, str] = "RELU", - norm_name: Union[Tuple, str] = "INSTANCE", + norm_name: Union[Tuple, str] = ("INSTANCE", {"affine": True}), ): """ Args: @@ -122,7 +134,9 @@ def __init__( bias=False, dilation=1, ) - self.norm = get_norm_layer(name=norm_name, spatial_dims=self._spatial_dims, channels=self._out_channel) + self.norm = get_norm_layer( + name=norm_name, spatial_dims=self._spatial_dims, channels=self._out_channel + ) def forward(self, x: torch.Tensor) -> torch.Tensor: """ @@ -150,7 +164,7 @@ def __init__( padding: int, mode: int = 0, act_name: Union[Tuple, str] = "RELU", - norm_name: Union[Tuple, str] = "INSTANCE", + norm_name: Union[Tuple, str] = ("INSTANCE", {"affine": True}), ): """ Args: @@ -219,7 +233,10 @@ def __init__( dilation=1, ), ) - self.add_module("norm", get_norm_layer(name=norm_name, spatial_dims=3, channels=self._out_channel)) + self.add_module( + "norm", + get_norm_layer(name=norm_name, spatial_dims=3, channels=self._out_channel), + ) class ActiConvNormBlock(torch.nn.Sequential): @@ -235,7 +252,7 @@ def __init__( padding: int = 1, spatial_dims: int = 3, act_name: Union[Tuple, str] = "RELU", - norm_name: Union[Tuple, str] = "INSTANCE", + norm_name: Union[Tuple, str] = ("INSTANCE", {"affine": True}), ): """ Args: @@ -268,5 +285,10 @@ def __init__( ), ) self.add_module( - "norm", get_norm_layer(name=norm_name, spatial_dims=self._spatial_dims, channels=self._out_channel) + "norm", + get_norm_layer( + name=norm_name, + spatial_dims=self._spatial_dims, + channels=self._out_channel, + ), ) diff --git a/monai/networks/nets/dints.py b/monai/networks/nets/dints.py index 978695c5d0..7b6051d4aa 100644 --- a/monai/networks/nets/dints.py +++ b/monai/networks/nets/dints.py @@ -9,6 +9,7 @@ # See the License for the specific language governing permissions and # limitations under the License. + import warnings from typing import List, Optional, Tuple, Union @@ -96,22 +97,57 @@ class _ActiConvNormBlockWithRAMCost(ActiConvNormBlock): ram_cost = total_ram/output_size = 2 * in_channel/out_channel + 1 """ - def __init__(self, in_channel: int, out_channel: int, kernel_size: int, padding: int, spatial_dims: int = 3): - super().__init__(in_channel, out_channel, kernel_size, padding, spatial_dims) + def __init__( + self, + in_channel: int, + out_channel: int, + kernel_size: int, + padding: int, + spatial_dims: int = 3, + act_name: Union[Tuple, str] = "RELU", + norm_name: Union[Tuple, str] = ("INSTANCE", {"affine": True}), + ): + super().__init__( + in_channel, + out_channel, + kernel_size, + padding, + spatial_dims, + act_name, + norm_name, + ) self.ram_cost = 1 + in_channel / out_channel * 2 class _P3DActiConvNormBlockWithRAMCost(P3DActiConvNormBlock): - def __init__(self, in_channel: int, out_channel: int, kernel_size: int, padding: int, p3dmode: int = 0): - super().__init__(in_channel, out_channel, kernel_size, padding, p3dmode) + def __init__( + self, + in_channel: int, + out_channel: int, + kernel_size: int, + padding: int, + p3dmode: int = 0, + act_name: Union[Tuple, str] = "RELU", + norm_name: Union[Tuple, str] = ("INSTANCE", {"affine": True}), + ): + super().__init__( + in_channel, out_channel, kernel_size, padding, p3dmode, act_name, norm_name + ) # 1 in_channel (activation) + 1 in_channel (convolution) + # 1 out_channel (convolution) + 1 out_channel (normalization) self.ram_cost = 2 + 2 * in_channel / out_channel class _FactorizedIncreaseBlockWithRAMCost(FactorizedIncreaseBlock): - def __init__(self, in_channel: int, out_channel: int, spatial_dims: int = 3): - super().__init__(in_channel, out_channel, spatial_dims) + def __init__( + self, + in_channel: int, + out_channel: int, + spatial_dims: int = 3, + act_name: Union[Tuple, str] = "RELU", + norm_name: Union[Tuple, str] = ("INSTANCE", {"affine": True}), + ): + super().__init__(in_channel, out_channel, spatial_dims, act_name, norm_name) # s0 is upsampled 2x from s1, representing feature sizes at two resolutions. # 2 * in_channel * s0 (upsample + activation) + 2 * out_channel * s0 (conv + normalization) # s0 = output_size/out_channel @@ -119,12 +155,19 @@ def __init__(self, in_channel: int, out_channel: int, spatial_dims: int = 3): class _FactorizedReduceBlockWithRAMCost(FactorizedReduceBlock): - def __init__(self, in_channel: int, out_channel: int, spatial_dims: int = 3): - super().__init__(in_channel, out_channel, spatial_dims) + def __init__( + self, + in_channel: int, + out_channel: int, + spatial_dims: int = 3, + act_name: Union[Tuple, str] = "RELU", + norm_name: Union[Tuple, str] = ("INSTANCE", {"affine": True}), + ): + super().__init__(in_channel, out_channel, spatial_dims, act_name, norm_name) # s0 is upsampled 2x from s1, representing feature sizes at two resolutions. # in_channel * s0 (activation) + 3 * out_channel * s1 (convolution, concatenation, normalization) # s0 = s1 * 2^(spatial_dims) = output_size / out_channel * 2^(spatial_dims) - self.ram_cost = in_channel / out_channel * 2**self._spatial_dims + 3 + self.ram_cost = in_channel / out_channel * 2 ** self._spatial_dims + 3 class MixedOp(nn.Module): @@ -190,33 +233,105 @@ class Cell(CellInterface): "align_channels": _ActiConvNormBlockWithRAMCost, } - # Define 2D operation set, parameterized by the number of channels - OPS2D = { - "skip_connect": lambda _c: _IdentityWithRAMCost(), - "conv_3x3": lambda c: _ActiConvNormBlockWithRAMCost(c, c, 3, padding=1, spatial_dims=2), - } - - # Define 3D operation set, parameterized by the number of channels - OPS3D = { - "skip_connect": lambda _c: _IdentityWithRAMCost(), - "conv_3x3x3": lambda c: _ActiConvNormBlockWithRAMCost(c, c, 3, padding=1, spatial_dims=3), - "conv_3x3x1": lambda c: _P3DActiConvNormBlockWithRAMCost(c, c, 3, padding=1, p3dmode=0), - "conv_3x1x3": lambda c: _P3DActiConvNormBlockWithRAMCost(c, c, 3, padding=1, p3dmode=1), - "conv_1x3x3": lambda c: _P3DActiConvNormBlockWithRAMCost(c, c, 3, padding=1, p3dmode=2), - } - - def __init__(self, c_prev: int, c: int, rate: int, arch_code_c=None, spatial_dims: int = 3): + def __init__( + self, + c_prev: int, + c: int, + rate: int, + arch_code_c=None, + spatial_dims: int = 3, + act_name: Union[Tuple, str] = "RELU", + norm_name: Union[Tuple, str] = ("INSTANCE", {"affine": True}), + ): super().__init__() self._spatial_dims = spatial_dims + self._act_name = act_name + self._norm_name = norm_name + if rate == -1: # downsample - self.preprocess = self.ConnOPS["down"](c_prev, c, spatial_dims=self._spatial_dims) + self.preprocess = self.ConnOPS["down"]( + c_prev, + c, + spatial_dims=self._spatial_dims, + act_name=self._act_name, + norm_name=self._norm_name, + ) elif rate == 1: # upsample - self.preprocess = self.ConnOPS["up"](c_prev, c, spatial_dims=self._spatial_dims) + self.preprocess = self.ConnOPS["up"]( + c_prev, + c, + spatial_dims=self._spatial_dims, + act_name=self._act_name, + norm_name=self._norm_name, + ) else: if c_prev == c: self.preprocess = self.ConnOPS["identity"]() else: - self.preprocess = self.ConnOPS["align_channels"](c_prev, c, 1, 0, spatial_dims=self._spatial_dims) + self.preprocess = self.ConnOPS["align_channels"]( + c_prev, + c, + 1, + 0, + spatial_dims=self._spatial_dims, + act_name=self._act_name, + norm_name=self._norm_name, + ) + + # Define 2D operation set, parameterized by the number of channels + self.OPS2D = { + "skip_connect": lambda _c: _IdentityWithRAMCost(), + "conv_3x3": lambda c: _ActiConvNormBlockWithRAMCost( + c, + c, + 3, + padding=1, + spatial_dims=2, + act_name=self._act_name, + norm_name=self._norm_name, + ), + } + + # Define 3D operation set, parameterized by the number of channels + self.OPS3D = { + "skip_connect": lambda _c: _IdentityWithRAMCost(), + "conv_3x3x3": lambda c: _ActiConvNormBlockWithRAMCost( + c, + c, + 3, + padding=1, + spatial_dims=3, + act_name=self._act_name, + norm_name=self._norm_name, + ), + "conv_3x3x1": lambda c: _P3DActiConvNormBlockWithRAMCost( + c, + c, + 3, + padding=1, + p3dmode=0, + act_name=self._act_name, + norm_name=self._norm_name, + ), + "conv_3x1x3": lambda c: _P3DActiConvNormBlockWithRAMCost( + c, + c, + 3, + padding=1, + p3dmode=1, + act_name=self._act_name, + norm_name=self._norm_name, + ), + "conv_1x3x3": lambda c: _P3DActiConvNormBlockWithRAMCost( + c, + c, + 3, + padding=1, + p3dmode=2, + act_name=self._act_name, + norm_name=self._norm_name, + ), + } self.OPS = {} if self._spatial_dims == 2: @@ -224,7 +339,9 @@ def __init__(self, c_prev: int, c: int, rate: int, arch_code_c=None, spatial_dim elif self._spatial_dims == 3: self.OPS = self.OPS3D else: - raise NotImplementedError(f"Spatial dimensions {self._spatial_dims} is not supported.") + raise NotImplementedError( + f"Spatial dimensions {self._spatial_dims} is not supported." + ) self.op = MixedOp(c, self.OPS, arch_code_c) @@ -283,7 +400,7 @@ def __init__( in_channels: int, num_classes: int, act_name: Union[Tuple, str] = "RELU", - norm_name: Union[Tuple, str] = "INSTANCE", + norm_name: Union[Tuple, str] = ("INSTANCE", {"affine": True}), spatial_dims: int = 3, use_downsample: bool = True, node_a=None, @@ -295,7 +412,9 @@ def __init__( self.num_blocks = dints_space.num_blocks self.num_depths = dints_space.num_depths if spatial_dims not in (2, 3): - raise NotImplementedError(f"Spatial dimensions {spatial_dims} is not supported.") + raise NotImplementedError( + f"Spatial dimensions {spatial_dims} is not supported." + ) self._spatial_dims = spatial_dims if node_a is None: self.node_a = torch.ones((self.num_blocks + 1, self.num_depths)) @@ -330,7 +449,9 @@ def __init__( # define downsample stems before DiNTS search if use_downsample: self.stem_down[str(res_idx)] = StemTS( - nn.Upsample(scale_factor=1 / (2**res_idx), mode=mode, align_corners=True), + nn.Upsample( + scale_factor=1 / (2 ** res_idx), mode=mode, align_corners=True + ), conv_type( in_channels=in_channels, out_channels=self.filter_nums[res_idx], @@ -341,7 +462,11 @@ def __init__( bias=False, dilation=1, ), - get_norm_layer(name=norm_name, spatial_dims=spatial_dims, channels=self.filter_nums[res_idx]), + get_norm_layer( + name=norm_name, + spatial_dims=spatial_dims, + channels=self.filter_nums[res_idx], + ), get_act_layer(name=act_name), conv_type( in_channels=self.filter_nums[res_idx], @@ -353,7 +478,11 @@ def __init__( bias=False, dilation=1, ), - get_norm_layer(name=norm_name, spatial_dims=spatial_dims, channels=self.filter_nums[res_idx + 1]), + get_norm_layer( + name=norm_name, + spatial_dims=spatial_dims, + channels=self.filter_nums[res_idx + 1], + ), ) self.stem_up[str(res_idx)] = StemTS( get_act_layer(name=act_name), @@ -367,13 +496,19 @@ def __init__( bias=False, dilation=1, ), - get_norm_layer(name=norm_name, spatial_dims=spatial_dims, channels=self.filter_nums[res_idx]), + get_norm_layer( + name=norm_name, + spatial_dims=spatial_dims, + channels=self.filter_nums[res_idx], + ), nn.Upsample(scale_factor=2, mode=mode, align_corners=True), ) else: self.stem_down[str(res_idx)] = StemTS( - nn.Upsample(scale_factor=1 / (2**res_idx), mode=mode, align_corners=True), + nn.Upsample( + scale_factor=1 / (2 ** res_idx), mode=mode, align_corners=True + ), conv_type( in_channels=in_channels, out_channels=self.filter_nums[res_idx], @@ -384,7 +519,11 @@ def __init__( bias=False, dilation=1, ), - get_norm_layer(name=norm_name, spatial_dims=spatial_dims, channels=self.filter_nums[res_idx]), + get_norm_layer( + name=norm_name, + spatial_dims=spatial_dims, + channels=self.filter_nums[res_idx], + ), ) self.stem_up[str(res_idx)] = StemTS( get_act_layer(name=act_name), @@ -398,8 +537,14 @@ def __init__( bias=False, dilation=1, ), - get_norm_layer(name=norm_name, spatial_dims=spatial_dims, channels=self.filter_nums[res_idx - 1]), - nn.Upsample(scale_factor=2 ** (res_idx != 0), mode=mode, align_corners=True), + get_norm_layer( + name=norm_name, + spatial_dims=spatial_dims, + channels=self.filter_nums[res_idx - 1], + ), + nn.Upsample( + scale_factor=2 ** (res_idx != 0), mode=mode, align_corners=True + ), ) def weight_parameters(self): @@ -484,16 +629,22 @@ def __init__( num_blocks: int = 6, num_depths: int = 3, spatial_dims: int = 3, + act_name: Union[Tuple, str] = "RELU", + norm_name: Union[Tuple, str] = ("INSTANCE", {"affine": True}), use_downsample: bool = True, device: str = "cpu", ): super().__init__() - self.filter_nums = [int(n_feat * channel_mul) for n_feat in (32, 64, 128, 256, 512)] + self.filter_nums = [ + int(n_feat * channel_mul) for n_feat in (32, 64, 128, 256, 512) + ] self.num_blocks = num_blocks self.num_depths = num_depths self._spatial_dims = spatial_dims + self._act_name = act_name + self._norm_name = norm_name self.use_downsample = use_downsample self.device = device self.num_cell_ops = 0 @@ -505,7 +656,9 @@ def __init__( # Calculate predefined parameters for topology search and decoding arch_code2in, arch_code2out = [], [] for i in range(Cell.DIRECTIONS * self.num_depths - 2): - arch_code2in.append((i + 1) // Cell.DIRECTIONS - 1 + (i + 1) % Cell.DIRECTIONS) + arch_code2in.append( + (i + 1) // Cell.DIRECTIONS - 1 + (i + 1) % Cell.DIRECTIONS + ) arch_code2ops = ([-1, 0, 1] * self.num_depths)[1:-1] for m in range(self.num_depths): arch_code2out.extend([m, m, m]) @@ -516,11 +669,17 @@ def __init__( # define NAS search space if arch_code is None: - arch_code_a = torch.ones((self.num_blocks, len(self.arch_code2out))).to(self.device) - arch_code_c = torch.ones((self.num_blocks, len(self.arch_code2out), self.num_cell_ops)).to(self.device) + arch_code_a = torch.ones((self.num_blocks, len(self.arch_code2out))).to( + self.device + ) + arch_code_c = torch.ones( + (self.num_blocks, len(self.arch_code2out), self.num_cell_ops) + ).to(self.device) else: arch_code_a = torch.from_numpy(arch_code[0]).to(self.device) - arch_code_c = F.one_hot(torch.from_numpy(arch_code[1]).to(torch.int64), self.num_cell_ops).to(self.device) + arch_code_c = F.one_hot( + torch.from_numpy(arch_code[1]).to(torch.int64), self.num_cell_ops + ).to(self.device) self.arch_code_a = arch_code_a self.arch_code_c = arch_code_c @@ -530,11 +689,17 @@ def __init__( for res_idx in range(len(self.arch_code2out)): if self.arch_code_a[blk_idx, res_idx] == 1: self.cell_tree[str((blk_idx, res_idx))] = cell( - self.filter_nums[self.arch_code2in[res_idx] + int(use_downsample)], - self.filter_nums[self.arch_code2out[res_idx] + int(use_downsample)], + self.filter_nums[ + self.arch_code2in[res_idx] + int(use_downsample) + ], + self.filter_nums[ + self.arch_code2out[res_idx] + int(use_downsample) + ], self.arch_code2ops[res_idx], self.arch_code_c[blk_idx, res_idx], self._spatial_dims, + self._act_name, + self._norm_name, ) def forward(self, x): @@ -555,6 +720,8 @@ def __init__( num_blocks: int = 6, num_depths: int = 3, spatial_dims: int = 3, + act_name: Union[Tuple, str] = "RELU", + norm_name: Union[Tuple, str] = ("INSTANCE", {"affine": True}), use_downsample: bool = True, device: str = "cpu", ): @@ -571,6 +738,8 @@ def __init__( num_blocks=num_blocks, num_depths=num_depths, spatial_dims=spatial_dims, + act_name=act_name, + norm_name=norm_name, use_downsample=use_downsample, device=device, ) @@ -588,10 +757,13 @@ def forward(self, x: List[torch.Tensor]) -> List[torch.Tensor]: if activation: mod: CellInterface = self.cell_tree[str((blk_idx, res_idx))] _out = mod.forward( - x=inputs[self.arch_code2in[res_idx]], weight=torch.ones_like(self.arch_code_c[blk_idx, res_idx]) + x=inputs[self.arch_code2in[res_idx]], + weight=torch.ones_like(self.arch_code_c[blk_idx, res_idx]), + ) + outputs[self.arch_code2out[res_idx]] = ( + outputs[self.arch_code2out[res_idx]] + _out ) - outputs[self.arch_code2out[res_idx]] = outputs[self.arch_code2out[res_idx]] + _out - inputs = outputs + inputs = outputs return inputs @@ -650,6 +822,8 @@ def __init__( num_blocks: int = 6, num_depths: int = 3, spatial_dims: int = 3, + act_name: Union[Tuple, str] = "RELU", + norm_name: Union[Tuple, str] = ("INSTANCE", {"affine": True}), use_downsample: bool = True, device: str = "cpu", ): @@ -663,6 +837,8 @@ def __init__( num_blocks=num_blocks, num_depths=num_depths, spatial_dims=spatial_dims, + act_name=act_name, + norm_name=norm_name, use_downsample=use_downsample, device=device, ) @@ -670,23 +846,34 @@ def __init__( tidx = [] _d = Cell.DIRECTIONS for i in range(_d * self.num_depths - 2): - tidx.append((i + 1) // _d * self.num_depths + (i + 1) // _d - 1 + (i + 1) % _d) + tidx.append( + (i + 1) // _d * self.num_depths + (i + 1) // _d - 1 + (i + 1) % _d + ) self.tidx = tidx transfer_mtx, node_act_list, child_list = self.gen_mtx(num_depths) self.node_act_list = np.asarray(node_act_list) - self.node_act_dict = {str(self.node_act_list[i]): i for i in range(len(self.node_act_list))} + self.node_act_dict = { + str(self.node_act_list[i]): i for i in range(len(self.node_act_list)) + } self.transfer_mtx = transfer_mtx self.child_list = np.asarray(child_list) - self.ram_cost = np.zeros((self.num_blocks, len(self.arch_code2out), self.num_cell_ops)) + self.ram_cost = np.zeros( + (self.num_blocks, len(self.arch_code2out), self.num_cell_ops) + ) for blk_idx in range(self.num_blocks): for res_idx in range(len(self.arch_code2out)): if self.arch_code_a[blk_idx, res_idx] == 1: self.ram_cost[blk_idx, res_idx] = np.array( [ - op.ram_cost + self.cell_tree[str((blk_idx, res_idx))].preprocess.ram_cost - for op in self.cell_tree[str((blk_idx, res_idx))].op.ops[: self.num_cell_ops] + op.ram_cost + + self.cell_tree[ + str((blk_idx, res_idx)) + ].preprocess.ram_cost + for op in self.cell_tree[str((blk_idx, res_idx))].op.ops[ + : self.num_cell_ops + ] ] ) @@ -698,7 +885,10 @@ def __init__( .requires_grad_() ) self.log_alpha_a = nn.Parameter( - torch.zeros(self.num_blocks, len(self.arch_code2out)).normal_(0, 0.01).to(self.device).requires_grad_() + torch.zeros(self.num_blocks, len(self.arch_code2out)) + .normal_(0, 0.01) + .to(self.device) + .requires_grad_() ) self._arch_param_names = ["log_alpha_a", "log_alpha_c"] @@ -728,7 +918,10 @@ def gen_mtx(self, depth: int): # convert path activation [1,paths] to path activation matrix [depth, depth] ma = np.zeros((depth, depth)) for i in range(paths): - ma[(i + 1) // Cell.DIRECTIONS, (i + 1) // Cell.DIRECTIONS - 1 + (i + 1) % Cell.DIRECTIONS] = m[i] + ma[ + (i + 1) // Cell.DIRECTIONS, + (i + 1) // Cell.DIRECTIONS - 1 + (i + 1) % Cell.DIRECTIONS, + ] = m[i] mtx.append(ma) # define all possible node activation @@ -736,13 +929,21 @@ def gen_mtx(self, depth: int): transfer_mtx = {} for arch_code in node_act_list: # make sure each activated node has an active connection, inactivated node has no connection - arch_code_mtx = [_ for _ in mtx if ((np.sum(_, 0) > 0).astype(int) == np.array(arch_code)).all()] + arch_code_mtx = [ + _ + for _ in mtx + if ((np.sum(_, 0) > 0).astype(int) == np.array(arch_code)).all() + ] transfer_mtx[str(np.array(arch_code))] = arch_code_mtx return transfer_mtx, node_act_list, all_connect[1:] def weight_parameters(self): - return [param for name, param in self.named_parameters() if name not in self._arch_param_names] + return [ + param + for name, param in self.named_parameters() + if name not in self._arch_param_names + ] def get_prob_a(self, child: bool = False): """ @@ -789,8 +990,14 @@ def get_ram_cost_usage(self, in_size, full: bool = False): image_size = np.array(in_size[-self._spatial_dims :]) sizes = [] for res_idx in range(self.num_depths): - sizes.append(batch_size * self.filter_nums[res_idx] * (image_size // (2**res_idx)).prod()) - sizes = torch.tensor(sizes).to(torch.float32).to(self.device) / (2 ** (int(self.use_downsample))) + sizes.append( + batch_size + * self.filter_nums[res_idx] + * (image_size // (2 ** res_idx)).prod() + ) + sizes = torch.tensor(sizes).to(torch.float32).to(self.device) / ( + 2 ** (int(self.use_downsample)) + ) probs_a, arch_code_prob_a = self.get_prob_a(child=False) cell_prob = F.softmax(self.log_alpha_c, dim=-1) if full: @@ -804,10 +1011,15 @@ def get_ram_cost_usage(self, in_size, full: bool = False): for path_idx in range(len(self.arch_code2out)): usage += ( arch_code_prob_a[blk_idx, path_idx] - * (1 + (ram_cost[blk_idx, path_idx] * cell_prob[blk_idx, path_idx]).sum()) + * ( + 1 + + ( + ram_cost[blk_idx, path_idx] * cell_prob[blk_idx, path_idx] + ).sum() + ) * sizes[self.arch_code2out[path_idx]] ) - return usage * 32 / 8 / 1024**2 + return usage * 32 / 8 / 1024 ** 2 def get_topology_entropy(self, probs): """ @@ -825,10 +1037,16 @@ def get_topology_entropy(self, probs): # node activation index to feasible output child_idx node2out = [[] for _ in range(len(self.node_act_list))] for child_idx in range(len(self.child_list)): - _node_in, _node_out = np.zeros(self.num_depths), np.zeros(self.num_depths) + _node_in, _node_out = np.zeros(self.num_depths), np.zeros( + self.num_depths + ) for res_idx in range(len(self.arch_code2out)): - _node_out[self.arch_code2out[res_idx]] += self.child_list[child_idx][res_idx] - _node_in[self.arch_code2in[res_idx]] += self.child_list[child_idx][res_idx] + _node_out[self.arch_code2out[res_idx]] += self.child_list[ + child_idx + ][res_idx] + _node_in[self.arch_code2in[res_idx]] += self.child_list[child_idx][ + res_idx + ] _node_in = (_node_in >= 1).astype(int) _node_out = (_node_out >= 1).astype(int) node2in[self.node_act_dict[str(_node_out)]].append(child_idx) @@ -843,7 +1061,10 @@ def get_topology_entropy(self, probs): for node_idx in range(len(self.node_act_list)): _node_p = probs[blk_idx, node2in[node_idx]].sum() _out_probs = probs[blk_idx + 1, node2out[node_idx]].sum() - blk_ent += -(_node_p * torch.log(_out_probs + 1e-5) + (1 - _node_p) * torch.log(1 - _out_probs + 1e-5)) + blk_ent += -( + _node_p * torch.log(_out_probs + 1e-5) + + (1 - _node_p) * torch.log(1 - _out_probs + 1e-5) + ) ent += blk_ent return ent @@ -865,12 +1086,17 @@ def decode(self): """ probs, arch_code_prob_a = self.get_prob_a(child=True) arch_code_a_max = self.child_list[torch.argmax(probs, -1).data.cpu().numpy()] - arch_code_c = torch.argmax(F.softmax(self.log_alpha_c, -1), -1).data.cpu().numpy() + arch_code_c = ( + torch.argmax(F.softmax(self.log_alpha_c, -1), -1).data.cpu().numpy() + ) probs = probs.data.cpu().numpy() # define adjacency matrix amtx = np.zeros( - (1 + len(self.child_list) * self.num_blocks + 1, 1 + len(self.child_list) * self.num_blocks + 1) + ( + 1 + len(self.child_list) * self.num_blocks + 1, + 1 + len(self.child_list) * self.num_blocks + 1, + ) ) # build a path activation to child index searching dictionary @@ -881,10 +1107,14 @@ def decode(self): for child_idx in range(len(self.child_list)): _node_act = np.zeros(self.num_depths).astype(int) for path_idx in range(len(self.child_list[child_idx])): - _node_act[self.arch_code2out[path_idx]] += self.child_list[child_idx][path_idx] + _node_act[self.arch_code2out[path_idx]] += self.child_list[child_idx][ + path_idx + ] _node_act = (_node_act >= 1).astype(int) for mtx in self.transfer_mtx[str(_node_act)]: - connect_child_idx = path2child[str(mtx.flatten()[self.tidx].astype(int))] + connect_child_idx = path2child[ + str(mtx.flatten()[self.tidx].astype(int)) + ] sub_amtx[child_idx, connect_child_idx] = 1 # fill in source to first block, add 1e-5/1e-3 to avoid log0 and negative edge weights @@ -893,16 +1123,31 @@ def decode(self): # fill in the rest blocks for blk_idx in range(1, self.num_blocks): amtx[ - 1 + (blk_idx - 1) * len(self.child_list) : 1 + blk_idx * len(self.child_list), - 1 + blk_idx * len(self.child_list) : 1 + (blk_idx + 1) * len(self.child_list), - ] = sub_amtx * np.tile(-np.log(probs[blk_idx] + 1e-5) + 0.001, (len(self.child_list), 1)) + 1 + + (blk_idx - 1) * len(self.child_list) : 1 + + blk_idx * len(self.child_list), + 1 + + blk_idx * len(self.child_list) : 1 + + (blk_idx + 1) * len(self.child_list), + ] = sub_amtx * np.tile( + -np.log(probs[blk_idx] + 1e-5) + 0.001, (len(self.child_list), 1) + ) # fill in the last to the sink - amtx[1 + (self.num_blocks - 1) * len(self.child_list) : 1 + self.num_blocks * len(self.child_list), -1] = 0.001 + amtx[ + 1 + + (self.num_blocks - 1) * len(self.child_list) : 1 + + self.num_blocks * len(self.child_list), + -1, + ] = 0.001 graph = csr_matrix(amtx) dist_matrix, predecessors, sources = dijkstra( - csgraph=graph, directed=True, indices=0, min_only=True, return_predecessors=True + csgraph=graph, + directed=True, + indices=0, + min_only=True, + return_predecessors=True, ) index, a_idx = -1, -1 arch_code_a = np.zeros((self.num_blocks, len(self.arch_code2out))) @@ -916,7 +1161,9 @@ def decode(self): child_idx = (index - 1) % len(self.child_list) arch_code_a[a_idx, :] = self.child_list[child_idx] for res_idx in range(len(self.arch_code2out)): - node_a[a_idx, self.arch_code2out[res_idx]] += arch_code_a[a_idx, res_idx] + node_a[a_idx, self.arch_code2out[res_idx]] += arch_code_a[ + a_idx, res_idx + ] a_idx -= 1 for res_idx in range(len(self.arch_code2out)): node_a[a_idx, self.arch_code2in[res_idx]] += arch_code_a[0, res_idx] @@ -936,11 +1183,15 @@ def forward(self, x): inputs = x for blk_idx in range(self.num_blocks): outputs = [0.0] * self.num_depths - for res_idx, activation in enumerate(self.arch_code_a[blk_idx].data.cpu().numpy()): + for res_idx, activation in enumerate( + self.arch_code_a[blk_idx].data.cpu().numpy() + ): if activation: _w = F.softmax(self.log_alpha_c[blk_idx, res_idx], dim=-1) outputs[self.arch_code2out[res_idx]] += ( - self.cell_tree[str((blk_idx, res_idx))](inputs[self.arch_code2in[res_idx]], weight=_w) + self.cell_tree[str((blk_idx, res_idx))]( + inputs[self.arch_code2in[res_idx]], weight=_w + ) * arch_code_prob_a[blk_idx, res_idx] ) inputs = outputs From b617703a6d015777ff60cda790b4b7517f396ad5 Mon Sep 17 00:00:00 2001 From: dongy Date: Thu, 21 Apr 2022 11:29:57 -0700 Subject: [PATCH 2/3] autofix Signed-off-by: dongy --- monai/networks/blocks/dints_block.py | 34 +-- monai/networks/nets/dints.py | 309 +++++++-------------------- 2 files changed, 80 insertions(+), 263 deletions(-) diff --git a/monai/networks/blocks/dints_block.py b/monai/networks/blocks/dints_block.py index c40746b634..b7365f50e3 100644 --- a/monai/networks/blocks/dints_block.py +++ b/monai/networks/blocks/dints_block.py @@ -17,12 +17,7 @@ from monai.networks.layers.factories import Conv from monai.networks.layers.utils import get_act_layer, get_norm_layer -__all__ = [ - "FactorizedIncreaseBlock", - "FactorizedReduceBlock", - "P3DActiConvNormBlock", - "ActiConvNormBlock", -] +__all__ = ["FactorizedIncreaseBlock", "FactorizedReduceBlock", "P3DActiConvNormBlock", "ActiConvNormBlock"] class FactorizedIncreaseBlock(torch.nn.Sequential): @@ -55,9 +50,7 @@ def __init__( conv_type = Conv[Conv.CONV, self._spatial_dims] mode = "trilinear" if self._spatial_dims == 3 else "bilinear" - self.add_module( - "up", torch.nn.Upsample(scale_factor=2, mode=mode, align_corners=True) - ) + self.add_module("up", torch.nn.Upsample(scale_factor=2, mode=mode, align_corners=True)) self.add_module("acti", get_act_layer(name=act_name)) self.add_module( "conv", @@ -73,12 +66,7 @@ def __init__( ), ) self.add_module( - "norm", - get_norm_layer( - name=norm_name, - spatial_dims=self._spatial_dims, - channels=self._out_channel, - ), + "norm", get_norm_layer(name=norm_name, spatial_dims=self._spatial_dims, channels=self._out_channel) ) @@ -134,9 +122,7 @@ def __init__( bias=False, dilation=1, ) - self.norm = get_norm_layer( - name=norm_name, spatial_dims=self._spatial_dims, channels=self._out_channel - ) + self.norm = get_norm_layer(name=norm_name, spatial_dims=self._spatial_dims, channels=self._out_channel) def forward(self, x: torch.Tensor) -> torch.Tensor: """ @@ -233,10 +219,7 @@ def __init__( dilation=1, ), ) - self.add_module( - "norm", - get_norm_layer(name=norm_name, spatial_dims=3, channels=self._out_channel), - ) + self.add_module("norm", get_norm_layer(name=norm_name, spatial_dims=3, channels=self._out_channel)) class ActiConvNormBlock(torch.nn.Sequential): @@ -285,10 +268,5 @@ def __init__( ), ) self.add_module( - "norm", - get_norm_layer( - name=norm_name, - spatial_dims=self._spatial_dims, - channels=self._out_channel, - ), + "norm", get_norm_layer(name=norm_name, spatial_dims=self._spatial_dims, channels=self._out_channel) ) diff --git a/monai/networks/nets/dints.py b/monai/networks/nets/dints.py index 7b6051d4aa..842f19b01b 100644 --- a/monai/networks/nets/dints.py +++ b/monai/networks/nets/dints.py @@ -107,15 +107,7 @@ def __init__( act_name: Union[Tuple, str] = "RELU", norm_name: Union[Tuple, str] = ("INSTANCE", {"affine": True}), ): - super().__init__( - in_channel, - out_channel, - kernel_size, - padding, - spatial_dims, - act_name, - norm_name, - ) + super().__init__(in_channel, out_channel, kernel_size, padding, spatial_dims, act_name, norm_name) self.ram_cost = 1 + in_channel / out_channel * 2 @@ -130,9 +122,7 @@ def __init__( act_name: Union[Tuple, str] = "RELU", norm_name: Union[Tuple, str] = ("INSTANCE", {"affine": True}), ): - super().__init__( - in_channel, out_channel, kernel_size, padding, p3dmode, act_name, norm_name - ) + super().__init__(in_channel, out_channel, kernel_size, padding, p3dmode, act_name, norm_name) # 1 in_channel (activation) + 1 in_channel (convolution) + # 1 out_channel (convolution) + 1 out_channel (normalization) self.ram_cost = 2 + 2 * in_channel / out_channel @@ -167,7 +157,7 @@ def __init__( # s0 is upsampled 2x from s1, representing feature sizes at two resolutions. # in_channel * s0 (activation) + 3 * out_channel * s1 (convolution, concatenation, normalization) # s0 = s1 * 2^(spatial_dims) = output_size / out_channel * 2^(spatial_dims) - self.ram_cost = in_channel / out_channel * 2 ** self._spatial_dims + 3 + self.ram_cost = in_channel / out_channel * 2**self._spatial_dims + 3 class MixedOp(nn.Module): @@ -225,6 +215,21 @@ class Cell(CellInterface): # \ # - Downsample + # Define 2D operation set, parameterized by the number of channels + OPS2D = { + "skip_connect": lambda _c: _IdentityWithRAMCost(), + "conv_3x3": lambda c: _ActiConvNormBlockWithRAMCost(c, c, 3, padding=1, spatial_dims=2), + } + + # Define 3D operation set, parameterized by the number of channels + OPS3D = { + "skip_connect": lambda _c: _IdentityWithRAMCost(), + "conv_3x3x3": lambda c: _ActiConvNormBlockWithRAMCost(c, c, 3, padding=1, spatial_dims=3), + "conv_3x3x1": lambda c: _P3DActiConvNormBlockWithRAMCost(c, c, 3, padding=1, p3dmode=0), + "conv_3x1x3": lambda c: _P3DActiConvNormBlockWithRAMCost(c, c, 3, padding=1, p3dmode=1), + "conv_1x3x3": lambda c: _P3DActiConvNormBlockWithRAMCost(c, c, 3, padding=1, p3dmode=2), + } + # Define connection operation set, parameterized by the number of channels ConnOPS = { "up": _FactorizedIncreaseBlockWithRAMCost, @@ -250,45 +255,25 @@ def __init__( if rate == -1: # downsample self.preprocess = self.ConnOPS["down"]( - c_prev, - c, - spatial_dims=self._spatial_dims, - act_name=self._act_name, - norm_name=self._norm_name, + c_prev, c, spatial_dims=self._spatial_dims, act_name=self._act_name, norm_name=self._norm_name ) elif rate == 1: # upsample self.preprocess = self.ConnOPS["up"]( - c_prev, - c, - spatial_dims=self._spatial_dims, - act_name=self._act_name, - norm_name=self._norm_name, + c_prev, c, spatial_dims=self._spatial_dims, act_name=self._act_name, norm_name=self._norm_name ) else: if c_prev == c: self.preprocess = self.ConnOPS["identity"]() else: self.preprocess = self.ConnOPS["align_channels"]( - c_prev, - c, - 1, - 0, - spatial_dims=self._spatial_dims, - act_name=self._act_name, - norm_name=self._norm_name, + c_prev, c, 1, 0, spatial_dims=self._spatial_dims, act_name=self._act_name, norm_name=self._norm_name ) # Define 2D operation set, parameterized by the number of channels self.OPS2D = { "skip_connect": lambda _c: _IdentityWithRAMCost(), "conv_3x3": lambda c: _ActiConvNormBlockWithRAMCost( - c, - c, - 3, - padding=1, - spatial_dims=2, - act_name=self._act_name, - norm_name=self._norm_name, + c, c, 3, padding=1, spatial_dims=2, act_name=self._act_name, norm_name=self._norm_name ), } @@ -296,40 +281,16 @@ def __init__( self.OPS3D = { "skip_connect": lambda _c: _IdentityWithRAMCost(), "conv_3x3x3": lambda c: _ActiConvNormBlockWithRAMCost( - c, - c, - 3, - padding=1, - spatial_dims=3, - act_name=self._act_name, - norm_name=self._norm_name, + c, c, 3, padding=1, spatial_dims=3, act_name=self._act_name, norm_name=self._norm_name ), "conv_3x3x1": lambda c: _P3DActiConvNormBlockWithRAMCost( - c, - c, - 3, - padding=1, - p3dmode=0, - act_name=self._act_name, - norm_name=self._norm_name, + c, c, 3, padding=1, p3dmode=0, act_name=self._act_name, norm_name=self._norm_name ), "conv_3x1x3": lambda c: _P3DActiConvNormBlockWithRAMCost( - c, - c, - 3, - padding=1, - p3dmode=1, - act_name=self._act_name, - norm_name=self._norm_name, + c, c, 3, padding=1, p3dmode=1, act_name=self._act_name, norm_name=self._norm_name ), "conv_1x3x3": lambda c: _P3DActiConvNormBlockWithRAMCost( - c, - c, - 3, - padding=1, - p3dmode=2, - act_name=self._act_name, - norm_name=self._norm_name, + c, c, 3, padding=1, p3dmode=2, act_name=self._act_name, norm_name=self._norm_name ), } @@ -339,9 +300,7 @@ def __init__( elif self._spatial_dims == 3: self.OPS = self.OPS3D else: - raise NotImplementedError( - f"Spatial dimensions {self._spatial_dims} is not supported." - ) + raise NotImplementedError(f"Spatial dimensions {self._spatial_dims} is not supported.") self.op = MixedOp(c, self.OPS, arch_code_c) @@ -412,9 +371,7 @@ def __init__( self.num_blocks = dints_space.num_blocks self.num_depths = dints_space.num_depths if spatial_dims not in (2, 3): - raise NotImplementedError( - f"Spatial dimensions {spatial_dims} is not supported." - ) + raise NotImplementedError(f"Spatial dimensions {spatial_dims} is not supported.") self._spatial_dims = spatial_dims if node_a is None: self.node_a = torch.ones((self.num_blocks + 1, self.num_depths)) @@ -449,9 +406,7 @@ def __init__( # define downsample stems before DiNTS search if use_downsample: self.stem_down[str(res_idx)] = StemTS( - nn.Upsample( - scale_factor=1 / (2 ** res_idx), mode=mode, align_corners=True - ), + nn.Upsample(scale_factor=1 / (2**res_idx), mode=mode, align_corners=True), conv_type( in_channels=in_channels, out_channels=self.filter_nums[res_idx], @@ -462,11 +417,7 @@ def __init__( bias=False, dilation=1, ), - get_norm_layer( - name=norm_name, - spatial_dims=spatial_dims, - channels=self.filter_nums[res_idx], - ), + get_norm_layer(name=norm_name, spatial_dims=spatial_dims, channels=self.filter_nums[res_idx]), get_act_layer(name=act_name), conv_type( in_channels=self.filter_nums[res_idx], @@ -478,11 +429,7 @@ def __init__( bias=False, dilation=1, ), - get_norm_layer( - name=norm_name, - spatial_dims=spatial_dims, - channels=self.filter_nums[res_idx + 1], - ), + get_norm_layer(name=norm_name, spatial_dims=spatial_dims, channels=self.filter_nums[res_idx + 1]), ) self.stem_up[str(res_idx)] = StemTS( get_act_layer(name=act_name), @@ -496,19 +443,13 @@ def __init__( bias=False, dilation=1, ), - get_norm_layer( - name=norm_name, - spatial_dims=spatial_dims, - channels=self.filter_nums[res_idx], - ), + get_norm_layer(name=norm_name, spatial_dims=spatial_dims, channels=self.filter_nums[res_idx]), nn.Upsample(scale_factor=2, mode=mode, align_corners=True), ) else: self.stem_down[str(res_idx)] = StemTS( - nn.Upsample( - scale_factor=1 / (2 ** res_idx), mode=mode, align_corners=True - ), + nn.Upsample(scale_factor=1 / (2**res_idx), mode=mode, align_corners=True), conv_type( in_channels=in_channels, out_channels=self.filter_nums[res_idx], @@ -519,11 +460,7 @@ def __init__( bias=False, dilation=1, ), - get_norm_layer( - name=norm_name, - spatial_dims=spatial_dims, - channels=self.filter_nums[res_idx], - ), + get_norm_layer(name=norm_name, spatial_dims=spatial_dims, channels=self.filter_nums[res_idx]), ) self.stem_up[str(res_idx)] = StemTS( get_act_layer(name=act_name), @@ -537,14 +474,8 @@ def __init__( bias=False, dilation=1, ), - get_norm_layer( - name=norm_name, - spatial_dims=spatial_dims, - channels=self.filter_nums[res_idx - 1], - ), - nn.Upsample( - scale_factor=2 ** (res_idx != 0), mode=mode, align_corners=True - ), + get_norm_layer(name=norm_name, spatial_dims=spatial_dims, channels=self.filter_nums[res_idx - 1]), + nn.Upsample(scale_factor=2 ** (res_idx != 0), mode=mode, align_corners=True), ) def weight_parameters(self): @@ -637,9 +568,7 @@ def __init__( super().__init__() - self.filter_nums = [ - int(n_feat * channel_mul) for n_feat in (32, 64, 128, 256, 512) - ] + self.filter_nums = [int(n_feat * channel_mul) for n_feat in (32, 64, 128, 256, 512)] self.num_blocks = num_blocks self.num_depths = num_depths self._spatial_dims = spatial_dims @@ -656,9 +585,7 @@ def __init__( # Calculate predefined parameters for topology search and decoding arch_code2in, arch_code2out = [], [] for i in range(Cell.DIRECTIONS * self.num_depths - 2): - arch_code2in.append( - (i + 1) // Cell.DIRECTIONS - 1 + (i + 1) % Cell.DIRECTIONS - ) + arch_code2in.append((i + 1) // Cell.DIRECTIONS - 1 + (i + 1) % Cell.DIRECTIONS) arch_code2ops = ([-1, 0, 1] * self.num_depths)[1:-1] for m in range(self.num_depths): arch_code2out.extend([m, m, m]) @@ -669,17 +596,11 @@ def __init__( # define NAS search space if arch_code is None: - arch_code_a = torch.ones((self.num_blocks, len(self.arch_code2out))).to( - self.device - ) - arch_code_c = torch.ones( - (self.num_blocks, len(self.arch_code2out), self.num_cell_ops) - ).to(self.device) + arch_code_a = torch.ones((self.num_blocks, len(self.arch_code2out))).to(self.device) + arch_code_c = torch.ones((self.num_blocks, len(self.arch_code2out), self.num_cell_ops)).to(self.device) else: arch_code_a = torch.from_numpy(arch_code[0]).to(self.device) - arch_code_c = F.one_hot( - torch.from_numpy(arch_code[1]).to(torch.int64), self.num_cell_ops - ).to(self.device) + arch_code_c = F.one_hot(torch.from_numpy(arch_code[1]).to(torch.int64), self.num_cell_ops).to(self.device) self.arch_code_a = arch_code_a self.arch_code_c = arch_code_c @@ -689,12 +610,8 @@ def __init__( for res_idx in range(len(self.arch_code2out)): if self.arch_code_a[blk_idx, res_idx] == 1: self.cell_tree[str((blk_idx, res_idx))] = cell( - self.filter_nums[ - self.arch_code2in[res_idx] + int(use_downsample) - ], - self.filter_nums[ - self.arch_code2out[res_idx] + int(use_downsample) - ], + self.filter_nums[self.arch_code2in[res_idx] + int(use_downsample)], + self.filter_nums[self.arch_code2out[res_idx] + int(use_downsample)], self.arch_code2ops[res_idx], self.arch_code_c[blk_idx, res_idx], self._spatial_dims, @@ -757,12 +674,9 @@ def forward(self, x: List[torch.Tensor]) -> List[torch.Tensor]: if activation: mod: CellInterface = self.cell_tree[str((blk_idx, res_idx))] _out = mod.forward( - x=inputs[self.arch_code2in[res_idx]], - weight=torch.ones_like(self.arch_code_c[blk_idx, res_idx]), - ) - outputs[self.arch_code2out[res_idx]] = ( - outputs[self.arch_code2out[res_idx]] + _out + x=inputs[self.arch_code2in[res_idx]], weight=torch.ones_like(self.arch_code_c[blk_idx, res_idx]) ) + outputs[self.arch_code2out[res_idx]] = outputs[self.arch_code2out[res_idx]] + _out inputs = outputs return inputs @@ -846,34 +760,23 @@ def __init__( tidx = [] _d = Cell.DIRECTIONS for i in range(_d * self.num_depths - 2): - tidx.append( - (i + 1) // _d * self.num_depths + (i + 1) // _d - 1 + (i + 1) % _d - ) + tidx.append((i + 1) // _d * self.num_depths + (i + 1) // _d - 1 + (i + 1) % _d) self.tidx = tidx transfer_mtx, node_act_list, child_list = self.gen_mtx(num_depths) self.node_act_list = np.asarray(node_act_list) - self.node_act_dict = { - str(self.node_act_list[i]): i for i in range(len(self.node_act_list)) - } + self.node_act_dict = {str(self.node_act_list[i]): i for i in range(len(self.node_act_list))} self.transfer_mtx = transfer_mtx self.child_list = np.asarray(child_list) - self.ram_cost = np.zeros( - (self.num_blocks, len(self.arch_code2out), self.num_cell_ops) - ) + self.ram_cost = np.zeros((self.num_blocks, len(self.arch_code2out), self.num_cell_ops)) for blk_idx in range(self.num_blocks): for res_idx in range(len(self.arch_code2out)): if self.arch_code_a[blk_idx, res_idx] == 1: self.ram_cost[blk_idx, res_idx] = np.array( [ - op.ram_cost - + self.cell_tree[ - str((blk_idx, res_idx)) - ].preprocess.ram_cost - for op in self.cell_tree[str((blk_idx, res_idx))].op.ops[ - : self.num_cell_ops - ] + op.ram_cost + self.cell_tree[str((blk_idx, res_idx))].preprocess.ram_cost + for op in self.cell_tree[str((blk_idx, res_idx))].op.ops[: self.num_cell_ops] ] ) @@ -885,10 +788,7 @@ def __init__( .requires_grad_() ) self.log_alpha_a = nn.Parameter( - torch.zeros(self.num_blocks, len(self.arch_code2out)) - .normal_(0, 0.01) - .to(self.device) - .requires_grad_() + torch.zeros(self.num_blocks, len(self.arch_code2out)).normal_(0, 0.01).to(self.device).requires_grad_() ) self._arch_param_names = ["log_alpha_a", "log_alpha_c"] @@ -918,10 +818,7 @@ def gen_mtx(self, depth: int): # convert path activation [1,paths] to path activation matrix [depth, depth] ma = np.zeros((depth, depth)) for i in range(paths): - ma[ - (i + 1) // Cell.DIRECTIONS, - (i + 1) // Cell.DIRECTIONS - 1 + (i + 1) % Cell.DIRECTIONS, - ] = m[i] + ma[(i + 1) // Cell.DIRECTIONS, (i + 1) // Cell.DIRECTIONS - 1 + (i + 1) % Cell.DIRECTIONS] = m[i] mtx.append(ma) # define all possible node activation @@ -929,21 +826,13 @@ def gen_mtx(self, depth: int): transfer_mtx = {} for arch_code in node_act_list: # make sure each activated node has an active connection, inactivated node has no connection - arch_code_mtx = [ - _ - for _ in mtx - if ((np.sum(_, 0) > 0).astype(int) == np.array(arch_code)).all() - ] + arch_code_mtx = [_ for _ in mtx if ((np.sum(_, 0) > 0).astype(int) == np.array(arch_code)).all()] transfer_mtx[str(np.array(arch_code))] = arch_code_mtx return transfer_mtx, node_act_list, all_connect[1:] def weight_parameters(self): - return [ - param - for name, param in self.named_parameters() - if name not in self._arch_param_names - ] + return [param for name, param in self.named_parameters() if name not in self._arch_param_names] def get_prob_a(self, child: bool = False): """ @@ -990,14 +879,8 @@ def get_ram_cost_usage(self, in_size, full: bool = False): image_size = np.array(in_size[-self._spatial_dims :]) sizes = [] for res_idx in range(self.num_depths): - sizes.append( - batch_size - * self.filter_nums[res_idx] - * (image_size // (2 ** res_idx)).prod() - ) - sizes = torch.tensor(sizes).to(torch.float32).to(self.device) / ( - 2 ** (int(self.use_downsample)) - ) + sizes.append(batch_size * self.filter_nums[res_idx] * (image_size // (2**res_idx)).prod()) + sizes = torch.tensor(sizes).to(torch.float32).to(self.device) / (2 ** (int(self.use_downsample))) probs_a, arch_code_prob_a = self.get_prob_a(child=False) cell_prob = F.softmax(self.log_alpha_c, dim=-1) if full: @@ -1011,15 +894,10 @@ def get_ram_cost_usage(self, in_size, full: bool = False): for path_idx in range(len(self.arch_code2out)): usage += ( arch_code_prob_a[blk_idx, path_idx] - * ( - 1 - + ( - ram_cost[blk_idx, path_idx] * cell_prob[blk_idx, path_idx] - ).sum() - ) + * (1 + (ram_cost[blk_idx, path_idx] * cell_prob[blk_idx, path_idx]).sum()) * sizes[self.arch_code2out[path_idx]] ) - return usage * 32 / 8 / 1024 ** 2 + return usage * 32 / 8 / 1024**2 def get_topology_entropy(self, probs): """ @@ -1037,16 +915,10 @@ def get_topology_entropy(self, probs): # node activation index to feasible output child_idx node2out = [[] for _ in range(len(self.node_act_list))] for child_idx in range(len(self.child_list)): - _node_in, _node_out = np.zeros(self.num_depths), np.zeros( - self.num_depths - ) + _node_in, _node_out = np.zeros(self.num_depths), np.zeros(self.num_depths) for res_idx in range(len(self.arch_code2out)): - _node_out[self.arch_code2out[res_idx]] += self.child_list[ - child_idx - ][res_idx] - _node_in[self.arch_code2in[res_idx]] += self.child_list[child_idx][ - res_idx - ] + _node_out[self.arch_code2out[res_idx]] += self.child_list[child_idx][res_idx] + _node_in[self.arch_code2in[res_idx]] += self.child_list[child_idx][res_idx] _node_in = (_node_in >= 1).astype(int) _node_out = (_node_out >= 1).astype(int) node2in[self.node_act_dict[str(_node_out)]].append(child_idx) @@ -1061,10 +933,7 @@ def get_topology_entropy(self, probs): for node_idx in range(len(self.node_act_list)): _node_p = probs[blk_idx, node2in[node_idx]].sum() _out_probs = probs[blk_idx + 1, node2out[node_idx]].sum() - blk_ent += -( - _node_p * torch.log(_out_probs + 1e-5) - + (1 - _node_p) * torch.log(1 - _out_probs + 1e-5) - ) + blk_ent += -(_node_p * torch.log(_out_probs + 1e-5) + (1 - _node_p) * torch.log(1 - _out_probs + 1e-5)) ent += blk_ent return ent @@ -1086,17 +955,12 @@ def decode(self): """ probs, arch_code_prob_a = self.get_prob_a(child=True) arch_code_a_max = self.child_list[torch.argmax(probs, -1).data.cpu().numpy()] - arch_code_c = ( - torch.argmax(F.softmax(self.log_alpha_c, -1), -1).data.cpu().numpy() - ) + arch_code_c = torch.argmax(F.softmax(self.log_alpha_c, -1), -1).data.cpu().numpy() probs = probs.data.cpu().numpy() # define adjacency matrix amtx = np.zeros( - ( - 1 + len(self.child_list) * self.num_blocks + 1, - 1 + len(self.child_list) * self.num_blocks + 1, - ) + (1 + len(self.child_list) * self.num_blocks + 1, 1 + len(self.child_list) * self.num_blocks + 1) ) # build a path activation to child index searching dictionary @@ -1107,14 +971,10 @@ def decode(self): for child_idx in range(len(self.child_list)): _node_act = np.zeros(self.num_depths).astype(int) for path_idx in range(len(self.child_list[child_idx])): - _node_act[self.arch_code2out[path_idx]] += self.child_list[child_idx][ - path_idx - ] + _node_act[self.arch_code2out[path_idx]] += self.child_list[child_idx][path_idx] _node_act = (_node_act >= 1).astype(int) for mtx in self.transfer_mtx[str(_node_act)]: - connect_child_idx = path2child[ - str(mtx.flatten()[self.tidx].astype(int)) - ] + connect_child_idx = path2child[str(mtx.flatten()[self.tidx].astype(int))] sub_amtx[child_idx, connect_child_idx] = 1 # fill in source to first block, add 1e-5/1e-3 to avoid log0 and negative edge weights @@ -1123,31 +983,16 @@ def decode(self): # fill in the rest blocks for blk_idx in range(1, self.num_blocks): amtx[ - 1 - + (blk_idx - 1) * len(self.child_list) : 1 - + blk_idx * len(self.child_list), - 1 - + blk_idx * len(self.child_list) : 1 - + (blk_idx + 1) * len(self.child_list), - ] = sub_amtx * np.tile( - -np.log(probs[blk_idx] + 1e-5) + 0.001, (len(self.child_list), 1) - ) + 1 + (blk_idx - 1) * len(self.child_list) : 1 + blk_idx * len(self.child_list), + 1 + blk_idx * len(self.child_list) : 1 + (blk_idx + 1) * len(self.child_list), + ] = sub_amtx * np.tile(-np.log(probs[blk_idx] + 1e-5) + 0.001, (len(self.child_list), 1)) # fill in the last to the sink - amtx[ - 1 - + (self.num_blocks - 1) * len(self.child_list) : 1 - + self.num_blocks * len(self.child_list), - -1, - ] = 0.001 + amtx[1 + (self.num_blocks - 1) * len(self.child_list) : 1 + self.num_blocks * len(self.child_list), -1] = 0.001 graph = csr_matrix(amtx) dist_matrix, predecessors, sources = dijkstra( - csgraph=graph, - directed=True, - indices=0, - min_only=True, - return_predecessors=True, + csgraph=graph, directed=True, indices=0, min_only=True, return_predecessors=True ) index, a_idx = -1, -1 arch_code_a = np.zeros((self.num_blocks, len(self.arch_code2out))) @@ -1161,9 +1006,7 @@ def decode(self): child_idx = (index - 1) % len(self.child_list) arch_code_a[a_idx, :] = self.child_list[child_idx] for res_idx in range(len(self.arch_code2out)): - node_a[a_idx, self.arch_code2out[res_idx]] += arch_code_a[ - a_idx, res_idx - ] + node_a[a_idx, self.arch_code2out[res_idx]] += arch_code_a[a_idx, res_idx] a_idx -= 1 for res_idx in range(len(self.arch_code2out)): node_a[a_idx, self.arch_code2in[res_idx]] += arch_code_a[0, res_idx] @@ -1183,15 +1026,11 @@ def forward(self, x): inputs = x for blk_idx in range(self.num_blocks): outputs = [0.0] * self.num_depths - for res_idx, activation in enumerate( - self.arch_code_a[blk_idx].data.cpu().numpy() - ): + for res_idx, activation in enumerate(self.arch_code_a[blk_idx].data.cpu().numpy()): if activation: _w = F.softmax(self.log_alpha_c[blk_idx, res_idx], dim=-1) outputs[self.arch_code2out[res_idx]] += ( - self.cell_tree[str((blk_idx, res_idx))]( - inputs[self.arch_code2in[res_idx]], weight=_w - ) + self.cell_tree[str((blk_idx, res_idx))](inputs[self.arch_code2in[res_idx]], weight=_w) * arch_code_prob_a[blk_idx, res_idx] ) inputs = outputs From 167f65ff61d4210fee7996f25393c87628a64888 Mon Sep 17 00:00:00 2001 From: dongy Date: Fri, 22 Apr 2022 09:25:22 -0700 Subject: [PATCH 3/3] update test case Signed-off-by: dongy --- monai/networks/nets/dints.py | 4 +++- tests/test_dints_cell.py | 40 +++++++++++++++++++++++++++++++----- tests/test_dints_network.py | 4 ++-- 3 files changed, 40 insertions(+), 8 deletions(-) diff --git a/monai/networks/nets/dints.py b/monai/networks/nets/dints.py index 842f19b01b..b7f3921a47 100644 --- a/monai/networks/nets/dints.py +++ b/monai/networks/nets/dints.py @@ -474,7 +474,9 @@ def __init__( bias=False, dilation=1, ), - get_norm_layer(name=norm_name, spatial_dims=spatial_dims, channels=self.filter_nums[res_idx - 1]), + get_norm_layer( + name=norm_name, spatial_dims=spatial_dims, channels=self.filter_nums[max(res_idx - 1, 0)] + ), nn.Upsample(scale_factor=2 ** (res_idx != 0), mode=mode, align_corners=True), ) diff --git a/tests/test_dints_cell.py b/tests/test_dints_cell.py index d480235b70..a5da39bae9 100644 --- a/tests/test_dints_cell.py +++ b/tests/test_dints_cell.py @@ -32,21 +32,28 @@ (2, 4, 64, 32, 16), ], [ - {"c_prev": 8, "c": 8, "rate": 0, "arch_code_c": None}, + {"c_prev": 8, "c": 8, "rate": 0, "arch_code_c": None, "act_name": "SELU", "norm_name": "BATCH"}, torch.tensor([1, 1, 1, 1, 1]), torch.tensor([0, 0, 0, 1, 0]), (2, 8, 32, 16, 8), (2, 8, 32, 16, 8), ], [ - {"c_prev": 8, "c": 8, "rate": -1, "arch_code_c": None}, + { + "c_prev": 8, + "c": 8, + "rate": -1, + "arch_code_c": None, + "act_name": "PRELU", + "norm_name": ("BATCH", {"affine": False}), + }, torch.tensor([1, 1, 1, 1, 1]), torch.tensor([1, 1, 1, 1, 1]), (2, 8, 32, 16, 8), (2, 8, 16, 8, 4), ], [ - {"c_prev": 8, "c": 8, "rate": -1, "arch_code_c": [1, 0, 0, 0, 1]}, + {"c_prev": 8, "c": 8, "rate": -1, "arch_code_c": [1, 0, 0, 0, 1], "act_name": "RELU", "norm_name": "INSTANCE"}, torch.tensor([1, 0, 0, 0, 1]), torch.tensor([0.2, 0.2, 0.2, 0.2, 0.2]), (2, 8, 32, 16, 8), @@ -56,12 +63,35 @@ TEST_CASES_2D = [ [ - {"c_prev": 8, "c": 7, "rate": -1, "arch_code_c": [1, 0, 0, 0, 1], "spatial_dims": 2}, + { + "c_prev": 8, + "c": 7, + "rate": -1, + "arch_code_c": [1, 0, 0, 0, 1], + "spatial_dims": 2, + "act_name": "PRELU", + "norm_name": ("BATCH", {"affine": False}), + }, torch.tensor([1, 0]), torch.tensor([0.2, 0.2]), (2, 8, 16, 8), (2, 7, 8, 4), - ] + ], + [ + { + "c_prev": 8, + "c": 8, + "rate": -1, + "arch_code_c": None, + "spatial_dims": 2, + "act_name": "SELU", + "norm_name": "INSTANCE", + }, + torch.tensor([1, 0]), + torch.tensor([0.2, 0.2]), + (2, 8, 16, 8), + (2, 8, 8, 4), + ], ] diff --git a/tests/test_dints_network.py b/tests/test_dints_network.py index 8be5eb7ccd..08e75fab98 100644 --- a/tests/test_dints_network.py +++ b/tests/test_dints_network.py @@ -33,7 +33,7 @@ "in_channels": 1, "num_classes": 3, "act_name": "RELU", - "norm_name": "INSTANCE", + "norm_name": ("INSTANCE", {"affine": True}), "use_downsample": False, "spatial_dims": 3, }, @@ -101,7 +101,7 @@ "in_channels": 1, "num_classes": 4, "act_name": "RELU", - "norm_name": "INSTANCE", + "norm_name": ("INSTANCE", {"affine": True}), "use_downsample": False, "spatial_dims": 2, },