Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 6 additions & 9 deletions diffpass/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,14 +76,14 @@ class DiffPaSSResults:
]
# Hard losses
hard_losses: Union[
GradientDescentList[GroupByGroupList[np.ndarray]],
BootstrapList[GradientDescentList[GroupByGroupList[np.ndarray]]],
GradientDescentList[GroupByGroupList[float]],
BootstrapList[GradientDescentList[GroupByGroupList[float]]],
]
# Soft losses
soft_losses: Optional[
Union[
GradientDescentList[GroupByGroupList[np.ndarray]],
BootstrapList[GradientDescentList[GroupByGroupList[np.ndarray]]],
GradientDescentList[GroupByGroupList[float]],
BootstrapList[GradientDescentList[GroupByGroupList[float]]],
]
]

Expand Down Expand Up @@ -318,7 +318,7 @@ def _hard_pass(
for perms_this_group in perms
]
)
results.hard_losses.append(dccn(loss))
results.hard_losses.append(loss.item())

def _soft_pass(
self,
Expand All @@ -338,10 +338,7 @@ def _soft_pass(
[dccn(perms_this_group) for perms_this_group in perms]
)
if record_soft_losses:
results.soft_losses.append(dccn(loss))

# Compute total loss
loss = loss.sum()
results.soft_losses.append(loss.item())

return loss

Expand Down
15 changes: 6 additions & 9 deletions nbs/base.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -144,14 +144,14 @@
" ]\n",
" # Hard losses\n",
" hard_losses: Union[\n",
" GradientDescentList[GroupByGroupList[np.ndarray]],\n",
" BootstrapList[GradientDescentList[GroupByGroupList[np.ndarray]]],\n",
" GradientDescentList[GroupByGroupList[float]],\n",
" BootstrapList[GradientDescentList[GroupByGroupList[float]]],\n",
" ]\n",
" # Soft losses\n",
" soft_losses: Optional[\n",
" Union[\n",
" GradientDescentList[GroupByGroupList[np.ndarray]],\n",
" BootstrapList[GradientDescentList[GroupByGroupList[np.ndarray]]],\n",
" GradientDescentList[GroupByGroupList[float]],\n",
" BootstrapList[GradientDescentList[GroupByGroupList[float]]],\n",
" ]\n",
" ]\n",
"\n",
Expand Down Expand Up @@ -386,7 +386,7 @@
" for perms_this_group in perms\n",
" ]\n",
" )\n",
" results.hard_losses.append(dccn(loss))\n",
" results.hard_losses.append(loss.item())\n",
"\n",
" def _soft_pass(\n",
" self,\n",
Expand All @@ -406,10 +406,7 @@
" [dccn(perms_this_group) for perms_this_group in perms]\n",
" )\n",
" if record_soft_losses:\n",
" results.soft_losses.append(dccn(loss))\n",
"\n",
" # Compute total loss\n",
" loss = loss.sum()\n",
" results.soft_losses.append(loss.item())\n",
"\n",
" return loss\n",
"\n",
Expand Down