Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 5 additions & 5 deletions n3fit/src/n3fit/stopping.py
Original file line number Diff line number Diff line change
Expand Up @@ -356,8 +356,8 @@ def __init__(

self._dont_stop = dont_stop
self._stop_now = False
self._stopping_patience = stopping_patience
self._total_epochs = total_epochs
self.stopping_patience = stopping_patience
self.total_epochs = total_epochs

self._stop_epochs = [total_epochs - 1] * self._n_replicas
self._best_epochs = [None] * self._n_replicas
Expand Down Expand Up @@ -472,13 +472,13 @@ def monitor_chi2(self, training_info, epoch, print_stats=False):
self._stopping_degrees[i_replica] = 0
self._counts[i_replica] = 1

stop_replicas = self._counts & (self._stopping_degrees > self._stopping_patience)
stop_replicas = self._counts & (self._stopping_degrees > self.stopping_patience)
for i_replica in np.where(stop_replicas)[0]:
self._stop_epochs[i_replica] = epoch
self._counts[i_replica] = 0

# By using the stopping degree we only stop when none of the replicas are improving anymore
if min(self._stopping_degrees) > self._stopping_patience:
if min(self._stopping_degrees) > self.stopping_patience:
self.make_stop()
return True

Expand All @@ -501,7 +501,7 @@ def print_current_stats(self, epoch, fitstate):
epoch_index = epoch + 1
tr_chi2 = fitstate.total_tr_chi2()
vl_chi2 = fitstate.total_vl_chi2()
total_str = f"At epoch {epoch_index}/{self._total_epochs}, total chi2: {tr_chi2}\n"
total_str = f"At epoch {epoch_index}/{self.total_epochs}, total chi2: {tr_chi2}\n"

# The partial chi2 makes no sense for more than one replica at once:
if self._n_replicas == 1:
Expand Down