diff --git a/monai/networks/nets/dints.py b/monai/networks/nets/dints.py
index eea9f351d8..437789ef0c 100644
--- a/monai/networks/nets/dints.py
+++ b/monai/networks/nets/dints.py
@@ -40,7 +40,7 @@ class CellInterface(torch.nn.Module):
     """interface for torchscriptable Cell"""
 
-    def forward(self, x: torch.Tensor, weight: torch.Tensor) -> torch.Tensor:  # type: ignore
+    def forward(self, x: torch.Tensor, weight) -> torch.Tensor:  # type: ignore
         pass
 
 
@@ -170,7 +170,7 @@ def __init__(self, c: int, ops: dict, arch_code_c=None):
             if arch_c > 0:
                 self.ops.append(ops[op_name](c))
 
-    def forward(self, x: torch.Tensor, weight: torch.Tensor):
+    def forward(self, x: torch.Tensor, weight: torch.Tensor | None = None):
         """
         Args:
             x: input tensor.
@@ -179,9 +179,10 @@ def forward(self, x: torch.Tensor, weight: torch.Tensor):
             out: weighted average of the operation results.
         """
         out = 0.0
-        weight = weight.to(x)
+        if weight is not None:
+            weight = weight.to(x)
         for idx, _op in enumerate(self.ops):
-            out = out + _op(x) * weight[idx]
+            out = (out + _op(x)) if weight is None else out + _op(x) * weight[idx]
         return out
 
 
@@ -297,7 +298,7 @@ def __init__(
 
         self.op = MixedOp(c, self.OPS, arch_code_c)
 
-    def forward(self, x: torch.Tensor, weight: torch.Tensor) -> torch.Tensor:
+    def forward(self, x: torch.Tensor, weight: torch.Tensor | None) -> torch.Tensor:
         """
         Args:
             x: input tensor
@@ -669,15 +670,13 @@ def forward(self, x: list[torch.Tensor]) -> list[torch.Tensor]:
             x: input tensor.
         """
         # generate path activation probability
-        inputs, outputs = x, [torch.tensor(0.0).to(x[0])] * self.num_depths
+        inputs = x
         for blk_idx in range(self.num_blocks):
-            outputs = [torch.tensor(0.0).to(x[0])] * self.num_depths
+            outputs = [torch.tensor(0.0, dtype=x[0].dtype, device=x[0].device)] * self.num_depths
             for res_idx, activation in enumerate(self.arch_code_a[blk_idx].data):
                 if activation:
                     mod: CellInterface = self.cell_tree[str((blk_idx, res_idx))]
-                    _out = mod.forward(
-                        x=inputs[self.arch_code2in[res_idx]], weight=torch.ones_like(self.arch_code_c[blk_idx, res_idx])
-                    )
+                    _out = mod.forward(x=inputs[self.arch_code2in[res_idx]], weight=None)
                     outputs[self.arch_code2out[res_idx]] = outputs[self.arch_code2out[res_idx]] + _out
             inputs = outputs
 
@@ -885,13 +884,13 @@ def get_ram_cost_usage(self, in_size, full: bool = False):
         sizes = []
         for res_idx in range(self.num_depths):
             sizes.append(batch_size * self.filter_nums[res_idx] * (image_size // (2**res_idx)).prod())
-        sizes = torch.tensor(sizes).to(torch.float32).to(self.device) / (2 ** (int(self.use_downsample)))
+        sizes = torch.tensor(sizes, dtype=torch.float32, device=self.device) / (2 ** (int(self.use_downsample)))
         probs_a, arch_code_prob_a = self.get_prob_a(child=False)
         cell_prob = F.softmax(self.log_alpha_c, dim=-1)
         if full:
             arch_code_prob_a = arch_code_prob_a.detach()
             arch_code_prob_a.fill_(1)
-        ram_cost = torch.from_numpy(self.ram_cost).to(torch.float32).to(self.device)
+        ram_cost = torch.from_numpy(self.ram_cost).to(dtype=torch.float32, device=self.device)
         usage = 0.0
         for blk_idx in range(self.num_blocks):
             # node activation for input
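
For reference, below is a minimal standalone sketch of the contract this patch introduces. It is illustrative only: `MixedOpSketch` and its two-op list are hypothetical names, not part of the patch or of MONAI. It mirrors the new Optional `weight` parameter: a tensor selects the architecture-weighted sum used during search, while `None` selects a plain sum of the operation outputs, which is why `TopologyInstance.forward` above can pass `weight=None` instead of building a `torch.ones_like` placeholder.

import torch
import torch.nn as nn


class MixedOpSketch(nn.Module):
    """Hypothetical stand-in for MixedOp, showing the Optional-weight contract."""

    def __init__(self, ops: list[nn.Module]):
        super().__init__()
        self.ops = nn.ModuleList(ops)

    def forward(self, x: torch.Tensor, weight: torch.Tensor | None = None) -> torch.Tensor:
        out = torch.zeros_like(x)
        if weight is not None:
            weight = weight.to(x)  # match input dtype/device, as in the patched MixedOp
        for idx, _op in enumerate(self.ops):
            # weight is None -> plain sum (the path TopologyInstance.forward now takes);
            # weight is a tensor -> architecture-weighted sum (the search path)
            out = (out + _op(x)) if weight is None else out + _op(x) * weight[idx]
        return out


mixed = MixedOpSketch([nn.Identity(), nn.ReLU()])
x = torch.randn(2, 3)
y_plain = mixed(x)                               # new weight=None path
y_weighted = mixed(x, torch.tensor([0.3, 0.7]))  # original weighted path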