From cdced15c287d453b39cc57f09ae65e849f5414af Mon Sep 17 00:00:00 2001 From: Umberto Lupo Date: Wed, 15 May 2024 13:48:21 +0200 Subject: [PATCH] Fix #9 - Add `n_repeats` kwarg to `DiffPaSSModel.fit_bootstrap` - Tidy up `fit_bootstrap` code --- diffpass/base.py | 191 ++++++++++++++++++++++++++++++++++------------ diffpass/model.py | 2 +- nbs/base.ipynb | 191 ++++++++++++++++++++++++++++++++++------------ nbs/model.ipynb | 2 +- 4 files changed, 288 insertions(+), 98 deletions(-) diff --git a/diffpass/base.py b/diffpass/base.py index 8f9164e..b4ee242 100644 --- a/diffpass/base.py +++ b/diffpass/base.py @@ -514,6 +514,7 @@ def fit_bootstrap( int ] = None, # If ``None``, the bootstrap will end when all pairs are fixed. Otherwise, the bootstrap will end when `n_end` pairs are fixed step_size: int = 1, # Difference between the number of fixed pairings chosen at consecutive bootstrap iterations + n_repeats: int = 1, # At each bootstrap iteration, `n_repeats` runs will be performed, and the run with the lowest loss will be chosen show_pbar: bool = True, # If ``True``, show progress bar. Default: ``True`` single_fit_cfg: Optional[ dict @@ -528,72 +529,60 @@ def fit_bootstrap( and fixed for the next run. The number of pairings fixed at each iteration ranges between `n_start` (default: 1) and `n_end` (default: total number of pairs), with a step size of `step_size`. """ - self.prepare_fit(x, y) - if self.fixed_pairings is None: - initial_fixed_pairings = [[] for _ in self.group_sizes] - else: - initial_fixed_pairings = [list(fm) for fm in self.fixed_pairings] - - _single_fit_cfg = deepcopy(self.single_fit_default_cfg) - if single_fit_cfg is not None: - _single_fit_cfg.update(single_fit_cfg) - single_fit_cfg = _single_fit_cfg + ########## Preparations ########## - # Initialize DiffPaSSResults object - results = self._init_results( - record_log_alphas=single_fit_cfg["record_log_alphas"], - record_soft_perms=single_fit_cfg["record_soft_perms"], - record_soft_losses=single_fit_cfg["record_soft_losses"], - ) - available_fields = [ - field.name - for field in fields(results) - if getattr(results, field.name) is not None - ] - field_to_length_so_far = {field_name: 0 for field_name in available_fields} + # Input validation + self.prepare_fit(x, y) + # Prepare variables for indexing n_samples = len(x) n_groups = len(self.group_sizes) cumsum_group_sizes = np.cumsum([0] + list(self.group_sizes)) offsets = np.repeat(cumsum_group_sizes[:-1], repeats=self.group_sizes) group_idxs = np.repeat(np.arange(n_groups), repeats=self.group_sizes) - # First fit with initial fixed matchings - can_optimize = self._fit(x, y, results=results, **single_fit_cfg) + # Initially fixed pairings as derived from the `fixed_pairings` attribute + if self.fixed_pairings is None: + initially_fixed_pairings = [[] for _ in self.group_sizes] + else: + initially_fixed_pairings = [list(fm) for fm in self.fixed_pairings] - # Find effective initial fixed matchings - effective_initial_fixed_idxs = [] + # *Effective* initially fixed pairings as global indices (not relative to group) + # Used to exclude these pairs from the random sampling of new fixed pairings + # and to determine when the bootstrap will end + effective_initially_fixed_idxs = [] for s, efmz in zip( cumsum_group_sizes, self.permutation._effective_fixed_pairings_zip ): if efmz: - effective_initial_fixed_idxs += [ - (s + efmz_fixed) for efmz_fixed in efmz[1] + effective_initially_fixed_idxs += [ + s + efmz_fixed for efmz_fixed in efmz[1] ] - effective_initial_fixed_idxs = np.asarray(effective_initial_fixed_idxs) - nonfixed_idxs = np.setdiff1d(np.arange(n_samples), effective_initial_fixed_idxs) - n_effective_initial_fixed_pairings = len(effective_initial_fixed_idxs) - + non_initially_fixed_idxs = np.setdiff1d( + np.arange(n_samples), effective_initially_fixed_idxs + ) if n_end is None: - n_end = n_samples - n_effective_initial_fixed_pairings - 1 - - # Subsequent fits: at a given iteration we use fixed matchings chosen uniformly at - # random from the results of the previous iteration (excluding the effective initial - # fixed matchings) - pbar = list(range(n_start, n_end, step_size)) + n_end = n_samples - len(effective_initially_fixed_idxs) - 1 + # Bootstrap range and progress bar + pbar = range(n_start, n_end, step_size) pbar = tqdm(pbar) if show_pbar else pbar - n_iters_with_optimization = int(can_optimize) - for N in pbar: - latest_hard_perms = results.hard_perms[-1] - mapped_idxs = offsets + np.concatenate(latest_hard_perms) - rand_fixed_idxs = np.random.permutation(nonfixed_idxs)[:N] + + ########## End preparations ########## + + ########## Closures ########## + + def make_new_fixed_pairings( + mapped_idxs: np.ndarray, N: int + ) -> IndexPairsInGroups: + """Subroutine for randomly sampling new fixed pairings for the next bootstrap iteration.""" + rand_fixed_idxs = np.random.permutation(non_initially_fixed_idxs)[:N] rand_fixed_idxs = np.sort(rand_fixed_idxs) rand_mapped_idxs = mapped_idxs[rand_fixed_idxs] rand_group_idxs = group_idxs[rand_fixed_idxs] rand_fixed_rel_idxs = rand_fixed_idxs - offsets[rand_fixed_idxs] rand_mapped_rel_idxs = rand_mapped_idxs - offsets[rand_mapped_idxs] - # Update fixed matchings + # Update fixed pairings fixed_pairings = [[] for _ in range(n_groups)] for rand_group_idx, mapped_rel_idx, fixed_rel_idx in zip( rand_group_idxs, rand_mapped_rel_idxs, rand_fixed_rel_idxs @@ -601,22 +590,126 @@ def fit_bootstrap( pair = (mapped_rel_idx, fixed_rel_idx) fixed_pairings[rand_group_idx].append(pair) fixed_pairings = [ - initial_fixed_pairings[k] + fixed_pairings[k] for k in range(n_groups) + initially_fixed_pairings[k] + fixed_pairings[k] for k in range(n_groups) ] - self.permutation.init_fixed_pairings_and_log_alphas( - fixed_pairings, device=x.device + + return fixed_pairings + + def init_diffpassresults() -> DiffPaSSResults: + return self._init_results( + record_log_alphas=single_fit_cfg["record_log_alphas"], + record_soft_perms=single_fit_cfg["record_soft_perms"], + record_soft_losses=single_fit_cfg["record_soft_losses"], ) + def extend_results_with_lowest_loss_repeat( + results_this_iter: DiffPaSSResults, + results: DiffPaSSResults, + can_optimize: bool, + ) -> None: + """Extend the global optimization object `results` with the portion of + `results_this_iter` (from the latest bootstrap iteration) corresponding to + the repeat with the lowest hard loss.""" + if can_optimize: + # Select run with lowest hard loss, discard the rest + reshaped_hard_losses_this_repeat = np.asarray( + results_this_iter.hard_losses + ).reshape(n_repeats, -1) + min_loss_idx = np.argmin(reshaped_hard_losses_this_repeat[:, -1]) + size_each_repeat = reshaped_hard_losses_this_repeat.shape[1] + # Record complete results of the run with the lowest loss + slice_to_append = slice( + min_loss_idx * size_each_repeat, + (min_loss_idx + 1) * size_each_repeat, + ) + else: + slice_to_append = slice(None) + [ + getattr(results, field_name).extend( + getattr(results_this_iter, field_name)[slice_to_append] + ) + for field_name in available_fields + ] + + postprocess_results_after_repeats = ( + extend_results_with_lowest_loss_repeat + if n_repeats > 1 + else lambda *args: None + ) + + ########## End closures ########## + + # Configuration for each gradient descent run + _single_fit_cfg = deepcopy(self.single_fit_default_cfg) + if single_fit_cfg is not None: + _single_fit_cfg.update(single_fit_cfg) + single_fit_cfg = _single_fit_cfg + + # Initialize DiffPaSSResults object + results = init_diffpassresults() + available_fields = [ + field.name + for field in fields(results) + if getattr(results, field.name) is not None + ] + field_to_length_so_far = {field_name: 0 for field_name in available_fields} + + ########## Optimization ########## + + # First fit with initially fixed pairings + can_optimize = self._fit(x, y, results=results, **single_fit_cfg) + n_iters_with_optimization = int(can_optimize) + + # DiffPaSSResults object for each bootstrap iteration: + # new object if `n_repeats` > 1, else the existing `results` + get_results_to_use_in_each_bootstrap_iter = ( + init_diffpassresults if n_repeats > 1 else lambda: results + ) + + # Subsequent bootstrap fits: at a given iteration we use fixed pairings chosen uniformly at + # random from the results of the previous iteration (excluding the effective initially + # fixed pairings) + for N in pbar: + latest_hard_perms = results.hard_perms[-1] + mapped_idxs = offsets + np.concatenate(latest_hard_perms) + field_to_length_so_far = { field_name: len(getattr(results, field_name)) for field_name in available_fields } - can_optimize = self._fit(x, y, results=results, **single_fit_cfg) + + results_this_iter = ( + get_results_to_use_in_each_bootstrap_iter() + ) # `results` alias if `n_repeats` == 1 + for _ in range(n_repeats): + # Randomly sample N fixed pairings + fixed_pairings = make_new_fixed_pairings(mapped_idxs, N) + # Reinitialize permutation module with new fixed pairings + self.permutation.init_fixed_pairings_and_log_alphas( + fixed_pairings, device=x.device + ) + # Fit with gradient descent + can_optimize = self._fit( + x, y, results=results_this_iter, **single_fit_cfg + ) + if not can_optimize: + # If we can't fit, we break the "repeats" loop + break + + postprocess_results_after_repeats( + results_this_iter, results, can_optimize + ) # Does nothing if `n_repeats` == 1 + if can_optimize: n_iters_with_optimization += 1 else: + # If we could not fit, terminate the bootstrap break + ########## End optimization ########## + + ########## Post-processing ########## + # Reshape results according to number of iterations performed reshaped_fields = {} for field_name in available_fields: @@ -644,4 +737,6 @@ def fit_bootstrap( ) results = replace(results, **reshaped_fields) + ########## End post-processing ########## + return results diff --git a/diffpass/model.py b/diffpass/model.py index a0de76c..e423744 100644 --- a/diffpass/model.py +++ b/diffpass/model.py @@ -194,7 +194,7 @@ def hard_(self) -> None: self.mode = "hard" def _impl_fixed_pairings(self, func: callable) -> callable: - """Include fixed matchings in the Gumbel-Sinkhorn or Gumbel-matching operators.""" + """Include fixed pairings in the Gumbel-Sinkhorn or Gumbel-matching operators.""" def wrapper(gen: Iterator[torch.Tensor]) -> Iterator[torch.Tensor]: for s, mat, (row_group, col_group), mask in zip( diff --git a/nbs/base.ipynb b/nbs/base.ipynb index 4ed9cb3..933a537 100644 --- a/nbs/base.ipynb +++ b/nbs/base.ipynb @@ -582,6 +582,7 @@ " int\n", " ] = None, # If ``None``, the bootstrap will end when all pairs are fixed. Otherwise, the bootstrap will end when `n_end` pairs are fixed\n", " step_size: int = 1, # Difference between the number of fixed pairings chosen at consecutive bootstrap iterations\n", + " n_repeats: int = 1, # At each bootstrap iteration, `n_repeats` runs will be performed, and the run with the lowest loss will be chosen\n", " show_pbar: bool = True, # If ``True``, show progress bar. Default: ``True``\n", " single_fit_cfg: Optional[\n", " dict\n", @@ -596,72 +597,60 @@ " and fixed for the next run.\n", " The number of pairings fixed at each iteration ranges between `n_start` (default: 1) and `n_end` (default: total number of pairs), with a step size of `step_size`.\n", " \"\"\"\n", - " self.prepare_fit(x, y)\n", - " if self.fixed_pairings is None:\n", - " initial_fixed_pairings = [[] for _ in self.group_sizes]\n", - " else:\n", - " initial_fixed_pairings = [list(fm) for fm in self.fixed_pairings]\n", - "\n", - " _single_fit_cfg = deepcopy(self.single_fit_default_cfg)\n", - " if single_fit_cfg is not None:\n", - " _single_fit_cfg.update(single_fit_cfg)\n", - " single_fit_cfg = _single_fit_cfg\n", + " ########## Preparations ##########\n", "\n", - " # Initialize DiffPaSSResults object\n", - " results = self._init_results(\n", - " record_log_alphas=single_fit_cfg[\"record_log_alphas\"],\n", - " record_soft_perms=single_fit_cfg[\"record_soft_perms\"],\n", - " record_soft_losses=single_fit_cfg[\"record_soft_losses\"],\n", - " )\n", - " available_fields = [\n", - " field.name\n", - " for field in fields(results)\n", - " if getattr(results, field.name) is not None\n", - " ]\n", - " field_to_length_so_far = {field_name: 0 for field_name in available_fields}\n", + " # Input validation\n", + " self.prepare_fit(x, y)\n", "\n", + " # Prepare variables for indexing\n", " n_samples = len(x)\n", " n_groups = len(self.group_sizes)\n", " cumsum_group_sizes = np.cumsum([0] + list(self.group_sizes))\n", " offsets = np.repeat(cumsum_group_sizes[:-1], repeats=self.group_sizes)\n", " group_idxs = np.repeat(np.arange(n_groups), repeats=self.group_sizes)\n", "\n", - " # First fit with initial fixed matchings\n", - " can_optimize = self._fit(x, y, results=results, **single_fit_cfg)\n", + " # Initially fixed pairings as derived from the `fixed_pairings` attribute\n", + " if self.fixed_pairings is None:\n", + " initially_fixed_pairings = [[] for _ in self.group_sizes]\n", + " else:\n", + " initially_fixed_pairings = [list(fm) for fm in self.fixed_pairings]\n", "\n", - " # Find effective initial fixed matchings\n", - " effective_initial_fixed_idxs = []\n", + " # *Effective* initially fixed pairings as global indices (not relative to group)\n", + " # Used to exclude these pairs from the random sampling of new fixed pairings\n", + " # and to determine when the bootstrap will end\n", + " effective_initially_fixed_idxs = []\n", " for s, efmz in zip(\n", " cumsum_group_sizes, self.permutation._effective_fixed_pairings_zip\n", " ):\n", " if efmz:\n", - " effective_initial_fixed_idxs += [\n", - " (s + efmz_fixed) for efmz_fixed in efmz[1]\n", + " effective_initially_fixed_idxs += [\n", + " s + efmz_fixed for efmz_fixed in efmz[1]\n", " ]\n", - " effective_initial_fixed_idxs = np.asarray(effective_initial_fixed_idxs)\n", - " nonfixed_idxs = np.setdiff1d(np.arange(n_samples), effective_initial_fixed_idxs)\n", - " n_effective_initial_fixed_pairings = len(effective_initial_fixed_idxs)\n", - "\n", + " non_initially_fixed_idxs = np.setdiff1d(\n", + " np.arange(n_samples), effective_initially_fixed_idxs\n", + " )\n", " if n_end is None:\n", - " n_end = n_samples - n_effective_initial_fixed_pairings - 1\n", - "\n", - " # Subsequent fits: at a given iteration we use fixed matchings chosen uniformly at\n", - " # random from the results of the previous iteration (excluding the effective initial\n", - " # fixed matchings)\n", - " pbar = list(range(n_start, n_end, step_size))\n", + " n_end = n_samples - len(effective_initially_fixed_idxs) - 1\n", + " # Bootstrap range and progress bar\n", + " pbar = range(n_start, n_end, step_size)\n", " pbar = tqdm(pbar) if show_pbar else pbar\n", - " n_iters_with_optimization = int(can_optimize)\n", - " for N in pbar:\n", - " latest_hard_perms = results.hard_perms[-1]\n", - " mapped_idxs = offsets + np.concatenate(latest_hard_perms)\n", - " rand_fixed_idxs = np.random.permutation(nonfixed_idxs)[:N]\n", + "\n", + " ########## End preparations ##########\n", + "\n", + " ########## Closures ##########\n", + "\n", + " def make_new_fixed_pairings(\n", + " mapped_idxs: np.ndarray, N: int\n", + " ) -> IndexPairsInGroups:\n", + " \"\"\"Subroutine for randomly sampling new fixed pairings for the next bootstrap iteration.\"\"\"\n", + " rand_fixed_idxs = np.random.permutation(non_initially_fixed_idxs)[:N]\n", " rand_fixed_idxs = np.sort(rand_fixed_idxs)\n", " rand_mapped_idxs = mapped_idxs[rand_fixed_idxs]\n", " rand_group_idxs = group_idxs[rand_fixed_idxs]\n", " rand_fixed_rel_idxs = rand_fixed_idxs - offsets[rand_fixed_idxs]\n", " rand_mapped_rel_idxs = rand_mapped_idxs - offsets[rand_mapped_idxs]\n", "\n", - " # Update fixed matchings\n", + " # Update fixed pairings\n", " fixed_pairings = [[] for _ in range(n_groups)]\n", " for rand_group_idx, mapped_rel_idx, fixed_rel_idx in zip(\n", " rand_group_idxs, rand_mapped_rel_idxs, rand_fixed_rel_idxs\n", @@ -669,22 +658,126 @@ " pair = (mapped_rel_idx, fixed_rel_idx)\n", " fixed_pairings[rand_group_idx].append(pair)\n", " fixed_pairings = [\n", - " initial_fixed_pairings[k] + fixed_pairings[k] for k in range(n_groups)\n", + " initially_fixed_pairings[k] + fixed_pairings[k] for k in range(n_groups)\n", " ]\n", - " self.permutation.init_fixed_pairings_and_log_alphas(\n", - " fixed_pairings, device=x.device\n", + "\n", + " return fixed_pairings\n", + "\n", + " def init_diffpassresults() -> DiffPaSSResults:\n", + " return self._init_results(\n", + " record_log_alphas=single_fit_cfg[\"record_log_alphas\"],\n", + " record_soft_perms=single_fit_cfg[\"record_soft_perms\"],\n", + " record_soft_losses=single_fit_cfg[\"record_soft_losses\"],\n", " )\n", "\n", + " def extend_results_with_lowest_loss_repeat(\n", + " results_this_iter: DiffPaSSResults,\n", + " results: DiffPaSSResults,\n", + " can_optimize: bool,\n", + " ) -> None:\n", + " \"\"\"Extend the global optimization object `results` with the portion of\n", + " `results_this_iter` (from the latest bootstrap iteration) corresponding to\n", + " the repeat with the lowest hard loss.\"\"\"\n", + " if can_optimize:\n", + " # Select run with lowest hard loss, discard the rest\n", + " reshaped_hard_losses_this_repeat = np.asarray(\n", + " results_this_iter.hard_losses\n", + " ).reshape(n_repeats, -1)\n", + " min_loss_idx = np.argmin(reshaped_hard_losses_this_repeat[:, -1])\n", + " size_each_repeat = reshaped_hard_losses_this_repeat.shape[1]\n", + " # Record complete results of the run with the lowest loss\n", + " slice_to_append = slice(\n", + " min_loss_idx * size_each_repeat,\n", + " (min_loss_idx + 1) * size_each_repeat,\n", + " )\n", + " else:\n", + " slice_to_append = slice(None)\n", + " [\n", + " getattr(results, field_name).extend(\n", + " getattr(results_this_iter, field_name)[slice_to_append]\n", + " )\n", + " for field_name in available_fields\n", + " ]\n", + "\n", + " postprocess_results_after_repeats = (\n", + " extend_results_with_lowest_loss_repeat\n", + " if n_repeats > 1\n", + " else lambda *args: None\n", + " )\n", + "\n", + " ########## End closures ##########\n", + "\n", + " # Configuration for each gradient descent run\n", + " _single_fit_cfg = deepcopy(self.single_fit_default_cfg)\n", + " if single_fit_cfg is not None:\n", + " _single_fit_cfg.update(single_fit_cfg)\n", + " single_fit_cfg = _single_fit_cfg\n", + "\n", + " # Initialize DiffPaSSResults object\n", + " results = init_diffpassresults()\n", + " available_fields = [\n", + " field.name\n", + " for field in fields(results)\n", + " if getattr(results, field.name) is not None\n", + " ]\n", + " field_to_length_so_far = {field_name: 0 for field_name in available_fields}\n", + "\n", + " ########## Optimization ##########\n", + "\n", + " # First fit with initially fixed pairings\n", + " can_optimize = self._fit(x, y, results=results, **single_fit_cfg)\n", + " n_iters_with_optimization = int(can_optimize)\n", + "\n", + " # DiffPaSSResults object for each bootstrap iteration:\n", + " # new object if `n_repeats` > 1, else the existing `results`\n", + " get_results_to_use_in_each_bootstrap_iter = (\n", + " init_diffpassresults if n_repeats > 1 else lambda: results\n", + " )\n", + "\n", + " # Subsequent bootstrap fits: at a given iteration we use fixed pairings chosen uniformly at\n", + " # random from the results of the previous iteration (excluding the effective initially\n", + " # fixed pairings)\n", + " for N in pbar:\n", + " latest_hard_perms = results.hard_perms[-1]\n", + " mapped_idxs = offsets + np.concatenate(latest_hard_perms)\n", + "\n", " field_to_length_so_far = {\n", " field_name: len(getattr(results, field_name))\n", " for field_name in available_fields\n", " }\n", - " can_optimize = self._fit(x, y, results=results, **single_fit_cfg)\n", + "\n", + " results_this_iter = (\n", + " get_results_to_use_in_each_bootstrap_iter()\n", + " ) # `results` alias if `n_repeats` == 1\n", + " for _ in range(n_repeats):\n", + " # Randomly sample N fixed pairings\n", + " fixed_pairings = make_new_fixed_pairings(mapped_idxs, N)\n", + " # Reinitialize permutation module with new fixed pairings\n", + " self.permutation.init_fixed_pairings_and_log_alphas(\n", + " fixed_pairings, device=x.device\n", + " )\n", + " # Fit with gradient descent\n", + " can_optimize = self._fit(\n", + " x, y, results=results_this_iter, **single_fit_cfg\n", + " )\n", + " if not can_optimize:\n", + " # If we can't fit, we break the \"repeats\" loop\n", + " break\n", + "\n", + " postprocess_results_after_repeats(\n", + " results_this_iter, results, can_optimize\n", + " ) # Does nothing if `n_repeats` == 1\n", + "\n", " if can_optimize:\n", " n_iters_with_optimization += 1\n", " else:\n", + " # If we could not fit, terminate the bootstrap\n", " break\n", "\n", + " ########## End optimization ##########\n", + "\n", + " ########## Post-processing ##########\n", + "\n", " # Reshape results according to number of iterations performed\n", " reshaped_fields = {}\n", " for field_name in available_fields:\n", @@ -712,6 +805,8 @@ " )\n", " results = replace(results, **reshaped_fields)\n", "\n", + " ########## End post-processing ##########\n", + "\n", " return results" ] }, diff --git a/nbs/model.ipynb b/nbs/model.ipynb index 929cd9e..141fae4 100644 --- a/nbs/model.ipynb +++ b/nbs/model.ipynb @@ -271,7 +271,7 @@ " self.mode = \"hard\"\n", "\n", " def _impl_fixed_pairings(self, func: callable) -> callable:\n", - " \"\"\"Include fixed matchings in the Gumbel-Sinkhorn or Gumbel-matching operators.\"\"\"\n", + " \"\"\"Include fixed pairings in the Gumbel-Sinkhorn or Gumbel-matching operators.\"\"\"\n", "\n", " def wrapper(gen: Iterator[torch.Tensor]) -> Iterator[torch.Tensor]:\n", " for s, mat, (row_group, col_group), mask in zip(\n",