diff --git a/diffpass/base.py b/diffpass/base.py index baa14d6..064a3f5 100644 --- a/diffpass/base.py +++ b/diffpass/base.py @@ -76,14 +76,14 @@ class DiffPaSSResults: ] # Hard losses hard_losses: Union[ - GradientDescentList[GroupByGroupList[np.ndarray]], - BootstrapList[GradientDescentList[GroupByGroupList[np.ndarray]]], + GradientDescentList[GroupByGroupList[float]], + BootstrapList[GradientDescentList[GroupByGroupList[float]]], ] # Soft losses soft_losses: Optional[ Union[ - GradientDescentList[GroupByGroupList[np.ndarray]], - BootstrapList[GradientDescentList[GroupByGroupList[np.ndarray]]], + GradientDescentList[GroupByGroupList[float]], + BootstrapList[GradientDescentList[GroupByGroupList[float]]], ] ] @@ -318,7 +318,7 @@ def _hard_pass( for perms_this_group in perms ] ) - results.hard_losses.append(dccn(loss)) + results.hard_losses.append(loss.item()) def _soft_pass( self, @@ -338,10 +338,7 @@ def _soft_pass( [dccn(perms_this_group) for perms_this_group in perms] ) if record_soft_losses: - results.soft_losses.append(dccn(loss)) - - # Compute total loss - loss = loss.sum() + results.soft_losses.append(loss.item()) return loss diff --git a/nbs/base.ipynb b/nbs/base.ipynb index 32ccc11..a6bb344 100644 --- a/nbs/base.ipynb +++ b/nbs/base.ipynb @@ -144,14 +144,14 @@ " ]\n", " # Hard losses\n", " hard_losses: Union[\n", - " GradientDescentList[GroupByGroupList[np.ndarray]],\n", - " BootstrapList[GradientDescentList[GroupByGroupList[np.ndarray]]],\n", + " GradientDescentList[GroupByGroupList[float]],\n", + " BootstrapList[GradientDescentList[GroupByGroupList[float]]],\n", " ]\n", " # Soft losses\n", " soft_losses: Optional[\n", " Union[\n", - " GradientDescentList[GroupByGroupList[np.ndarray]],\n", - " BootstrapList[GradientDescentList[GroupByGroupList[np.ndarray]]],\n", + " GradientDescentList[GroupByGroupList[float]],\n", + " BootstrapList[GradientDescentList[GroupByGroupList[float]]],\n", " ]\n", " ]\n", "\n", @@ -386,7 +386,7 @@ " for perms_this_group in perms\n", " ]\n", " )\n", - " results.hard_losses.append(dccn(loss))\n", + " results.hard_losses.append(loss.item())\n", "\n", " def _soft_pass(\n", " self,\n", @@ -406,10 +406,7 @@ " [dccn(perms_this_group) for perms_this_group in perms]\n", " )\n", " if record_soft_losses:\n", - " results.soft_losses.append(dccn(loss))\n", - "\n", - " # Compute total loss\n", - " loss = loss.sum()\n", + " results.soft_losses.append(loss.item())\n", "\n", " return loss\n", "\n",