From e95e00e05cc59ba6c55fb5c7c590b1d1cbb2257a Mon Sep 17 00:00:00 2001 From: Umberto Lupo Date: Tue, 14 May 2024 12:31:46 +0200 Subject: [PATCH] Fix #5 - Add `exclude_diagonal` kwarg to the constructor and default it to `True` to keep compatibility. - Improve docstrings for IntraGroupSimilarityLoss and InterGroupSimilarityLoss --- diffpass/model.py | 32 +++++++++++++++++++++++--------- nbs/model.ipynb | 32 +++++++++++++++++++++++--------- 2 files changed, 46 insertions(+), 18 deletions(-) diff --git a/diffpass/model.py b/diffpass/model.py index a3e66bf..2704f51 100644 --- a/diffpass/model.py +++ b/diffpass/model.py @@ -542,13 +542,18 @@ class InterGroupSimilarityLoss(Module): relationships. Similarity matrices are expected to be square and symmetric. The loss is computed - by comparing the (unrolled and concatenated) upper triangular blocks containing - inter-group similarities.""" + by comparing the (flattened and concatenated) blocks containing inter-group + similarities.""" def __init__( self, *, + # Number of entries in each group (e.g. species). Groups are assumed to be + # contiguous in the input similarity matrices group_sizes: Iterable[int], + # If not ``None``, custom callable to compute the differentiable score between + # the flattened and concatenated inter-group blocks of the similarity matrices. + # Default: dot product score_fn: Union[callable, None] = None, ) -> None: super().__init__() @@ -588,17 +593,24 @@ class IntraGroupSimilarityLoss(Module): relationships. Similarity matrices are expected to be square and symmetric. Their diagonal - elements are ignored. - If `group_sizes` is provided, the loss is computed by comparing the (unrolled - and concatenated) upper triangular blocks containing intra-group similarities. + elements are ignored if `exclude_diagonal` is set to True. + If `group_sizes` is provided, the loss is computed by comparing the flattened + and concatenated upper triangular blocks containing intra-group similarities. Otherwise, the loss is computed by comparing the upper triangular part of the - full similarity matrices, excluding the main diagonal.""" + full similarity matrices.""" def __init__( self, *, + # Number of entries in each group (e.g. species). Groups are assumed to be + # contiguous in the input similarity matrices group_sizes: Optional[Iterable[int]] = None, + # If not ``None``, custom callable to compute the differentiable score between + # the flattened and concatenated intra-group blocks of the similarity matrices + # Default: dot product score_fn: Union[callable, None] = None, + # If ``True``, exclude the diagonal elements from the computation + exclude_diagonal: bool = True, ) -> None: super().__init__() self.group_sizes = ( @@ -607,15 +619,17 @@ def __init__( self.score_fn = ( partial(torch.tensordot, dims=1) if score_fn is None else score_fn ) + self.exclude_diagonal = exclude_diagonal if self.group_sizes is not None: # Boolean mask for the main diagonal blocks corresponding to groups diag_blocks_mask = torch.block_diag( *[torch.ones((s, s), dtype=torch.bool) for s in self.group_sizes] ) - # Extract the upper triangular part, excluding the main diagonal + # Extract the upper triangular part self.register_buffer( - "_upper_diag_blocks_mask", torch.triu(diag_blocks_mask, diagonal=1) + "_upper_diag_blocks_mask", + torch.triu(diag_blocks_mask, diagonal=int(self.exclude_diagonal)), ) else: self._upper_diag_blocks_mask = None @@ -638,7 +652,7 @@ def forward( layout=similarities_x.layout, device=similarities_x.device, ), - diagonal=1, + diagonal=int(self.exclude_diagonal), ) else: mask = self._upper_diag_blocks_mask diff --git a/nbs/model.ipynb b/nbs/model.ipynb index b0eff91..a434c64 100644 --- a/nbs/model.ipynb +++ b/nbs/model.ipynb @@ -1042,13 +1042,18 @@ " relationships.\n", "\n", " Similarity matrices are expected to be square and symmetric. The loss is computed\n", - " by comparing the (unrolled and concatenated) upper triangular blocks containing\n", - " inter-group similarities.\"\"\"\n", + " by comparing the (flattened and concatenated) blocks containing inter-group\n", + " similarities.\"\"\"\n", "\n", " def __init__(\n", " self,\n", " *,\n", + " # Number of entries in each group (e.g. species). Groups are assumed to be\n", + " # contiguous in the input similarity matrices\n", " group_sizes: Iterable[int],\n", + " # If not ``None``, custom callable to compute the differentiable score between\n", + " # the flattened and concatenated inter-group blocks of the similarity matrices.\n", + " # Default: dot product\n", " score_fn: Union[callable, None] = None,\n", " ) -> None:\n", " super().__init__()\n", @@ -1088,17 +1093,24 @@ " relationships.\n", "\n", " Similarity matrices are expected to be square and symmetric. Their diagonal\n", - " elements are ignored.\n", - " If `group_sizes` is provided, the loss is computed by comparing the (unrolled\n", - " and concatenated) upper triangular blocks containing intra-group similarities.\n", + " elements are ignored if `exclude_diagonal` is set to True.\n", + " If `group_sizes` is provided, the loss is computed by comparing the flattened\n", + " and concatenated upper triangular blocks containing intra-group similarities.\n", " Otherwise, the loss is computed by comparing the upper triangular part of the\n", - " full similarity matrices, excluding the main diagonal.\"\"\"\n", + " full similarity matrices.\"\"\"\n", "\n", " def __init__(\n", " self,\n", " *,\n", + " # Number of entries in each group (e.g. species). Groups are assumed to be\n", + " # contiguous in the input similarity matrices\n", " group_sizes: Optional[Iterable[int]] = None,\n", + " # If not ``None``, custom callable to compute the differentiable score between\n", + " # the flattened and concatenated intra-group blocks of the similarity matrices\n", + " # Default: dot product\n", " score_fn: Union[callable, None] = None,\n", + " # If ``True``, exclude the diagonal elements from the computation\n", + " exclude_diagonal: bool = True,\n", " ) -> None:\n", " super().__init__()\n", " self.group_sizes = (\n", @@ -1107,15 +1119,17 @@ " self.score_fn = (\n", " partial(torch.tensordot, dims=1) if score_fn is None else score_fn\n", " )\n", + " self.exclude_diagonal = exclude_diagonal\n", "\n", " if self.group_sizes is not None:\n", " # Boolean mask for the main diagonal blocks corresponding to groups\n", " diag_blocks_mask = torch.block_diag(\n", " *[torch.ones((s, s), dtype=torch.bool) for s in self.group_sizes]\n", " )\n", - " # Extract the upper triangular part, excluding the main diagonal\n", + " # Extract the upper triangular part\n", " self.register_buffer(\n", - " \"_upper_diag_blocks_mask\", torch.triu(diag_blocks_mask, diagonal=1)\n", + " \"_upper_diag_blocks_mask\",\n", + " torch.triu(diag_blocks_mask, diagonal=int(self.exclude_diagonal)),\n", " )\n", " else:\n", " self._upper_diag_blocks_mask = None\n", @@ -1138,7 +1152,7 @@ " layout=similarities_x.layout,\n", " device=similarities_x.device,\n", " ),\n", - " diagonal=1,\n", + " diagonal=int(self.exclude_diagonal),\n", " )\n", " else:\n", " mask = self._upper_diag_blocks_mask\n",