191 changes: 143 additions & 48 deletions diffpass/base.py
@@ -514,6 +514,7 @@ def fit_bootstrap(
int
] = None, # If ``None``, the bootstrap will end when all pairs are fixed. Otherwise, the bootstrap will end when `n_end` pairs are fixed
step_size: int = 1, # Difference between the number of fixed pairings chosen at consecutive bootstrap iterations
n_repeats: int = 1, # At each bootstrap iteration, `n_repeats` runs will be performed, and the run with the lowest loss will be chosen
show_pbar: bool = True, # If ``True``, show progress bar. Default: ``True``
single_fit_cfg: Optional[
dict
@@ -528,95 +529,187 @@ def fit_bootstrap(
and fixed for the next run.
The number of pairings fixed at each iteration ranges between `n_start` (default: 1) and `n_end` (default: total number of pairs), with a step size of `step_size`.
"""
self.prepare_fit(x, y)
if self.fixed_pairings is None:
initial_fixed_pairings = [[] for _ in self.group_sizes]
else:
initial_fixed_pairings = [list(fm) for fm in self.fixed_pairings]

_single_fit_cfg = deepcopy(self.single_fit_default_cfg)
if single_fit_cfg is not None:
_single_fit_cfg.update(single_fit_cfg)
single_fit_cfg = _single_fit_cfg
########## Preparations ##########

# Initialize DiffPaSSResults object
results = self._init_results(
record_log_alphas=single_fit_cfg["record_log_alphas"],
record_soft_perms=single_fit_cfg["record_soft_perms"],
record_soft_losses=single_fit_cfg["record_soft_losses"],
)
available_fields = [
field.name
for field in fields(results)
if getattr(results, field.name) is not None
]
field_to_length_so_far = {field_name: 0 for field_name in available_fields}
# Input validation
self.prepare_fit(x, y)

# Prepare variables for indexing
n_samples = len(x)
n_groups = len(self.group_sizes)
cumsum_group_sizes = np.cumsum([0] + list(self.group_sizes))
offsets = np.repeat(cumsum_group_sizes[:-1], repeats=self.group_sizes)
group_idxs = np.repeat(np.arange(n_groups), repeats=self.group_sizes)
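# Worked example (assumed values, for illustration only): with
# group_sizes = (3, 2), cumsum_group_sizes = [0, 3, 5],
# offsets = [0, 0, 0, 3, 3] and group_idxs = [0, 0, 0, 1, 1], so a global
# sample index i belongs to group group_idxs[i] and has within-group
# index i - offsets[i].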

# First fit with initial fixed matchings
can_optimize = self._fit(x, y, results=results, **single_fit_cfg)
# Initially fixed pairings as derived from the `fixed_pairings` attribute
if self.fixed_pairings is None:
initially_fixed_pairings = [[] for _ in self.group_sizes]
else:
initially_fixed_pairings = [list(fm) for fm in self.fixed_pairings]

# Find effective initial fixed matchings
effective_initial_fixed_idxs = []
# *Effective* initially fixed pairings as global indices (not relative to group)
# Used to exclude these pairs from the random sampling of new fixed pairings
# and to determine when the bootstrap will end
effective_initially_fixed_idxs = []
for s, efmz in zip(
cumsum_group_sizes, self.permutation._effective_fixed_pairings_zip
):
if efmz:
effective_initial_fixed_idxs += [
(s + efmz_fixed) for efmz_fixed in efmz[1]
effective_initially_fixed_idxs += [
s + efmz_fixed for efmz_fixed in efmz[1]
]
effective_initial_fixed_idxs = np.asarray(effective_initial_fixed_idxs)
nonfixed_idxs = np.setdiff1d(np.arange(n_samples), effective_initial_fixed_idxs)
n_effective_initial_fixed_pairings = len(effective_initial_fixed_idxs)

non_initially_fixed_idxs = np.setdiff1d(
np.arange(n_samples), effective_initially_fixed_idxs
)
if n_end is None:
n_end = n_samples - n_effective_initial_fixed_pairings - 1

# Subsequent fits: at a given iteration we use fixed matchings chosen uniformly at
# random from the results of the previous iteration (excluding the effective initial
# fixed matchings)
pbar = list(range(n_start, n_end, step_size))
n_end = n_samples - len(effective_initially_fixed_idxs) - 1
# Bootstrap range and progress bar
pbar = range(n_start, n_end, step_size)
pbar = tqdm(pbar) if show_pbar else pbar
n_iters_with_optimization = int(can_optimize)
for N in pbar:
latest_hard_perms = results.hard_perms[-1]
mapped_idxs = offsets + np.concatenate(latest_hard_perms)
rand_fixed_idxs = np.random.permutation(nonfixed_idxs)[:N]

########## End preparations ##########

########## Closures ##########

def make_new_fixed_pairings(
mapped_idxs: np.ndarray, N: int
) -> IndexPairsInGroups:
"""Subroutine for randomly sampling new fixed pairings for the next bootstrap iteration."""
rand_fixed_idxs = np.random.permutation(non_initially_fixed_idxs)[:N]
rand_fixed_idxs = np.sort(rand_fixed_idxs)
rand_mapped_idxs = mapped_idxs[rand_fixed_idxs]
rand_group_idxs = group_idxs[rand_fixed_idxs]
rand_fixed_rel_idxs = rand_fixed_idxs - offsets[rand_fixed_idxs]
rand_mapped_rel_idxs = rand_mapped_idxs - offsets[rand_mapped_idxs]

# Update fixed matchings
# Update fixed pairings
fixed_pairings = [[] for _ in range(n_groups)]
for rand_group_idx, mapped_rel_idx, fixed_rel_idx in zip(
rand_group_idxs, rand_mapped_rel_idxs, rand_fixed_rel_idxs
):
pair = (mapped_rel_idx, fixed_rel_idx)
fixed_pairings[rand_group_idx].append(pair)
fixed_pairings = [
initial_fixed_pairings[k] + fixed_pairings[k] for k in range(n_groups)
initially_fixed_pairings[k] + fixed_pairings[k] for k in range(n_groups)
]
self.permutation.init_fixed_pairings_and_log_alphas(
fixed_pairings, device=x.device

return fixed_pairings
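# Example of the structure returned by `make_new_fixed_pairings` above
# (hypothetical values): with two groups, [[(0, 2), (1, 1)], [(3, 0)]]
# fixes pairs (0, 2) and (1, 1) within group 0 and pair (3, 0) within
# group 1; each pair is (mapped_rel_idx, fixed_rel_idx), relative to its group.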

def init_diffpassresults() -> DiffPaSSResults:
return self._init_results(
record_log_alphas=single_fit_cfg["record_log_alphas"],
record_soft_perms=single_fit_cfg["record_soft_perms"],
record_soft_losses=single_fit_cfg["record_soft_losses"],
)

def extend_results_with_lowest_loss_repeat(
results_this_iter: DiffPaSSResults,
results: DiffPaSSResults,
can_optimize: bool,
) -> None:
"""Extend the global optimization object `results` with the portion of
`results_this_iter` (from the latest bootstrap iteration) corresponding to
the repeat with the lowest hard loss."""
if can_optimize:
# Select run with lowest hard loss, discard the rest
reshaped_hard_losses_this_repeat = np.asarray(
results_this_iter.hard_losses
).reshape(n_repeats, -1)
min_loss_idx = np.argmin(reshaped_hard_losses_this_repeat[:, -1])
size_each_repeat = reshaped_hard_losses_this_repeat.shape[1]
# Record complete results of the run with the lowest loss
slice_to_append = slice(
min_loss_idx * size_each_repeat,
(min_loss_idx + 1) * size_each_repeat,
)
else:
slice_to_append = slice(None)
[
getattr(results, field_name).extend(
getattr(results_this_iter, field_name)[slice_to_append]
)
for field_name in available_fields
]

postprocess_results_after_repeats = (
extend_results_with_lowest_loss_repeat
if n_repeats > 1
else lambda *args: None
)
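# Sketch of the selection in `extend_results_with_lowest_loss_repeat` above
# (hypothetical numbers): with n_repeats = 3, the recorded hard losses are
# reshaped to (3, runs_per_repeat); np.argmin over the final column picks the
# repeat whose last hard loss is smallest, and only that repeat's records are
# appended to the global `results`.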

########## End closures ##########

# Configuration for each gradient descent run
_single_fit_cfg = deepcopy(self.single_fit_default_cfg)
if single_fit_cfg is not None:
_single_fit_cfg.update(single_fit_cfg)
single_fit_cfg = _single_fit_cfg

# Initialize DiffPaSSResults object
results = init_diffpassresults()
available_fields = [
field.name
for field in fields(results)
if getattr(results, field.name) is not None
]
field_to_length_so_far = {field_name: 0 for field_name in available_fields}
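# Derived from this function's own usage (not an exhaustive list):
# `available_fields` always contains "hard_perms" and "hard_losses", which are
# read below; soft permutations, soft losses and log-alphas are presumably
# present only when the corresponding `record_*` flags are set.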

########## Optimization ##########

# First fit with initially fixed pairings
can_optimize = self._fit(x, y, results=results, **single_fit_cfg)
n_iters_with_optimization = int(can_optimize)

# DiffPaSSResults object for each bootstrap iteration:
# new object if `n_repeats` > 1, else the existing `results`
get_results_to_use_in_each_bootstrap_iter = (
init_diffpassresults if n_repeats > 1 else lambda: results
)

# Subsequent bootstrap fits: at a given iteration we use fixed pairings chosen uniformly at
# random from the results of the previous iteration (excluding the effective initially
# fixed pairings)
for N in pbar:
latest_hard_perms = results.hard_perms[-1]
mapped_idxs = offsets + np.concatenate(latest_hard_perms)

field_to_length_so_far = {
field_name: len(getattr(results, field_name))
for field_name in available_fields
}
can_optimize = self._fit(x, y, results=results, **single_fit_cfg)

results_this_iter = (
get_results_to_use_in_each_bootstrap_iter()
) # `results` alias if `n_repeats` == 1
for _ in range(n_repeats):
# Randomly sample N fixed pairings
fixed_pairings = make_new_fixed_pairings(mapped_idxs, N)
# Reinitialize permutation module with new fixed pairings
self.permutation.init_fixed_pairings_and_log_alphas(
fixed_pairings, device=x.device
)
# Fit with gradient descent
can_optimize = self._fit(
x, y, results=results_this_iter, **single_fit_cfg
)
if not can_optimize:
# If we can't fit, we break the "repeats" loop
break

postprocess_results_after_repeats(
results_this_iter, results, can_optimize
) # Does nothing if `n_repeats` == 1

if can_optimize:
n_iters_with_optimization += 1
else:
# If we could not fit, terminate the bootstrap
break

########## End optimization ##########

########## Post-processing ##########

# Reshape results according to number of iterations performed
reshaped_fields = {}
for field_name in available_fields:
@@ -644,4 +737,6 @@ def fit_bootstrap(
)
results = replace(results, **reshaped_fields)

########## End post-processing ##########

return results
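For context, a minimal usage sketch of the new `n_repeats` option (the model class, import path and tensor shapes below are assumptions for illustration, not taken from this PR):

import torch
from diffpass.train import InformationPairing  # assumed import path

# Two interaction groups of sizes 3 and 2; x and y stand in for one-hot
# encoded sequence tensors of shape (n_samples, length, alphabet_size).
group_sizes = (3, 2)
x = torch.randint(0, 2, (5, 10, 21)).float()
y = torch.randint(0, 2, (5, 10, 21)).float()

model = InformationPairing(group_sizes=group_sizes)
results = model.fit_bootstrap(
    x,
    y,
    n_repeats=3,  # new in this PR: keep the lowest-loss run of 3 per iteration
)
print(results.hard_perms[-1])  # hard permutations from the final iteration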
2 changes: 1 addition & 1 deletion diffpass/model.py
@@ -194,7 +194,7 @@ def hard_(self) -> None:
self.mode = "hard"

def _impl_fixed_pairings(self, func: callable) -> callable:
"""Include fixed matchings in the Gumbel-Sinkhorn or Gumbel-matching operators."""
"""Include fixed pairings in the Gumbel-Sinkhorn or Gumbel-matching operators."""

def wrapper(gen: Iterator[torch.Tensor]) -> Iterator[torch.Tensor]:
for s, mat, (row_group, col_group), mask in zip(
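Conceptually, a fixed pairing clamps the corresponding assignment in every soft or hard permutation produced by these operators. A toy sketch of that idea (a conceptual illustration under assumed shapes, not the actual DiffPaSS implementation):

import torch

def clamp_fixed_pairings(perm: torch.Tensor, fixed: list[tuple[int, int]]) -> torch.Tensor:
    # Overwrite each fixed (row, col) pair with a hard one-hot assignment,
    # leaving all other rows of the permutation matrix untouched.
    out = perm.clone()
    for row, col in fixed:
        out[row] = 0.0
        out[row, col] = 1.0
    return out

soft_perm = torch.softmax(torch.randn(4, 4), dim=-1)  # toy soft permutation
print(clamp_fixed_pairings(soft_perm, [(0, 2)]))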