Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
32 changes: 23 additions & 9 deletions diffpass/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -542,13 +542,18 @@ class InterGroupSimilarityLoss(Module):
relationships.

Similarity matrices are expected to be square and symmetric. The loss is computed
by comparing the (unrolled and concatenated) upper triangular blocks containing
inter-group similarities."""
by comparing the (flattened and concatenated) blocks containing inter-group
similarities."""

def __init__(
self,
*,
# Number of entries in each group (e.g. species). Groups are assumed to be
# contiguous in the input similarity matrices
group_sizes: Iterable[int],
# If not ``None``, custom callable to compute the differentiable score between
# the flattened and concatenated inter-group blocks of the similarity matrices.
# Default: dot product
score_fn: Union[callable, None] = None,
) -> None:
super().__init__()
Expand Down Expand Up @@ -588,17 +593,24 @@ class IntraGroupSimilarityLoss(Module):
relationships.

Similarity matrices are expected to be square and symmetric. Their diagonal
elements are ignored.
If `group_sizes` is provided, the loss is computed by comparing the (unrolled
and concatenated) upper triangular blocks containing intra-group similarities.
elements are ignored if `exclude_diagonal` is set to True.
If `group_sizes` is provided, the loss is computed by comparing the flattened
and concatenated upper triangular blocks containing intra-group similarities.
Otherwise, the loss is computed by comparing the upper triangular part of the
full similarity matrices, excluding the main diagonal."""
full similarity matrices."""

def __init__(
self,
*,
# Number of entries in each group (e.g. species). Groups are assumed to be
# contiguous in the input similarity matrices
group_sizes: Optional[Iterable[int]] = None,
# If not ``None``, custom callable to compute the differentiable score between
# the flattened and concatenated intra-group blocks of the similarity matrices
# Default: dot product
score_fn: Union[callable, None] = None,
# If ``True``, exclude the diagonal elements from the computation
exclude_diagonal: bool = True,
) -> None:
super().__init__()
self.group_sizes = (
Expand All @@ -607,15 +619,17 @@ def __init__(
self.score_fn = (
partial(torch.tensordot, dims=1) if score_fn is None else score_fn
)
self.exclude_diagonal = exclude_diagonal

if self.group_sizes is not None:
# Boolean mask for the main diagonal blocks corresponding to groups
diag_blocks_mask = torch.block_diag(
*[torch.ones((s, s), dtype=torch.bool) for s in self.group_sizes]
)
# Extract the upper triangular part, excluding the main diagonal
# Extract the upper triangular part
self.register_buffer(
"_upper_diag_blocks_mask", torch.triu(diag_blocks_mask, diagonal=1)
"_upper_diag_blocks_mask",
torch.triu(diag_blocks_mask, diagonal=int(self.exclude_diagonal)),
)
else:
self._upper_diag_blocks_mask = None
Expand All @@ -638,7 +652,7 @@ def forward(
layout=similarities_x.layout,
device=similarities_x.device,
),
diagonal=1,
diagonal=int(self.exclude_diagonal),
)
else:
mask = self._upper_diag_blocks_mask
Expand Down
32 changes: 23 additions & 9 deletions nbs/model.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -1042,13 +1042,18 @@
" relationships.\n",
"\n",
" Similarity matrices are expected to be square and symmetric. The loss is computed\n",
" by comparing the (unrolled and concatenated) upper triangular blocks containing\n",
" inter-group similarities.\"\"\"\n",
" by comparing the (flattened and concatenated) blocks containing inter-group\n",
" similarities.\"\"\"\n",
"\n",
" def __init__(\n",
" self,\n",
" *,\n",
" # Number of entries in each group (e.g. species). Groups are assumed to be\n",
" # contiguous in the input similarity matrices\n",
" group_sizes: Iterable[int],\n",
" # If not ``None``, custom callable to compute the differentiable score between\n",
" # the flattened and concatenated inter-group blocks of the similarity matrices.\n",
" # Default: dot product\n",
" score_fn: Union[callable, None] = None,\n",
" ) -> None:\n",
" super().__init__()\n",
Expand Down Expand Up @@ -1088,17 +1093,24 @@
" relationships.\n",
"\n",
" Similarity matrices are expected to be square and symmetric. Their diagonal\n",
" elements are ignored.\n",
" If `group_sizes` is provided, the loss is computed by comparing the (unrolled\n",
" and concatenated) upper triangular blocks containing intra-group similarities.\n",
" elements are ignored if `exclude_diagonal` is set to True.\n",
" If `group_sizes` is provided, the loss is computed by comparing the flattened\n",
" and concatenated upper triangular blocks containing intra-group similarities.\n",
" Otherwise, the loss is computed by comparing the upper triangular part of the\n",
" full similarity matrices, excluding the main diagonal.\"\"\"\n",
" full similarity matrices.\"\"\"\n",
"\n",
" def __init__(\n",
" self,\n",
" *,\n",
" # Number of entries in each group (e.g. species). Groups are assumed to be\n",
" # contiguous in the input similarity matrices\n",
" group_sizes: Optional[Iterable[int]] = None,\n",
" # If not ``None``, custom callable to compute the differentiable score between\n",
" # the flattened and concatenated intra-group blocks of the similarity matrices\n",
" # Default: dot product\n",
" score_fn: Union[callable, None] = None,\n",
" # If ``True``, exclude the diagonal elements from the computation\n",
" exclude_diagonal: bool = True,\n",
" ) -> None:\n",
" super().__init__()\n",
" self.group_sizes = (\n",
Expand All @@ -1107,15 +1119,17 @@
" self.score_fn = (\n",
" partial(torch.tensordot, dims=1) if score_fn is None else score_fn\n",
" )\n",
" self.exclude_diagonal = exclude_diagonal\n",
"\n",
" if self.group_sizes is not None:\n",
" # Boolean mask for the main diagonal blocks corresponding to groups\n",
" diag_blocks_mask = torch.block_diag(\n",
" *[torch.ones((s, s), dtype=torch.bool) for s in self.group_sizes]\n",
" )\n",
" # Extract the upper triangular part, excluding the main diagonal\n",
" # Extract the upper triangular part\n",
" self.register_buffer(\n",
" \"_upper_diag_blocks_mask\", torch.triu(diag_blocks_mask, diagonal=1)\n",
" \"_upper_diag_blocks_mask\",\n",
" torch.triu(diag_blocks_mask, diagonal=int(self.exclude_diagonal)),\n",
" )\n",
" else:\n",
" self._upper_diag_blocks_mask = None\n",
Expand All @@ -1138,7 +1152,7 @@
" layout=similarities_x.layout,\n",
" device=similarities_x.device,\n",
" ),\n",
" diagonal=1,\n",
" diagonal=int(self.exclude_diagonal),\n",
" )\n",
" else:\n",
" mask = self._upper_diag_blocks_mask\n",
Expand Down