From accd5ccc146dbd308b40dda02d18168866f7ba02 Mon Sep 17 00:00:00 2001 From: juacrumar Date: Thu, 25 Jul 2024 14:05:57 +0200 Subject: [PATCH 1/2] disable jit compilation in tf >= 2.16 --- n3fit/src/n3fit/backends/keras_backend/MetaModel.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/n3fit/src/n3fit/backends/keras_backend/MetaModel.py b/n3fit/src/n3fit/backends/keras_backend/MetaModel.py index d72208761e..2399c16a64 100644 --- a/n3fit/src/n3fit/backends/keras_backend/MetaModel.py +++ b/n3fit/src/n3fit/backends/keras_backend/MetaModel.py @@ -25,6 +25,10 @@ else: # in case of disaster _to_numpy_or_python_type = lambda ret: {k: i.numpy() for k, i in ret.items()} +# Starting with TF 2.16, a memory leak in TF https://github.com/tensorflow/tensorflow/issues/64170 +# makes jit compilation unusable on GPU. +# Before TF 2.16 `jit_compile` defaulted to `False`; from 2.16 onwards it defaults to `True` +JIT_COMPILE = False # Define in this dictionary new optimizers as well as the arguments they accept # (with default values if needed be) @@ -307,7 +311,7 @@ def compile( target_output = [target_output] self.target_tensors = target_output - super().compile(optimizer=opt, loss=loss) + super().compile(optimizer=opt, loss=loss, jit_compile=JIT_COMPILE) def set_masks_to(self, names, val=0.0): """Set all mask value to the selected value From 2120355f1b2c831054f518bd1a87ff5eaa16f1a5 Mon Sep 17 00:00:00 2001 From: juacrumar Date: Fri, 26 Jul 2024 21:40:39 +0200 Subject: [PATCH 2/2] fix bug in stopping --- n3fit/src/n3fit/stopping.py | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/n3fit/src/n3fit/stopping.py b/n3fit/src/n3fit/stopping.py index e2f11a2579..48194cdf1b 100644 --- a/n3fit/src/n3fit/stopping.py +++ b/n3fit/src/n3fit/stopping.py @@ -27,6 +27,7 @@ which will tell `Validation` that no validation set was found and that the training is to be used instead. 
""" + import logging import numpy as np @@ -345,6 +346,8 @@ def __init__( self._threshold_chi2 = threshold_chi2 self._stopping_degrees = np.zeros(self._n_replicas, dtype=int) self._counts = np.zeros(self._n_replicas, dtype=int) + # Keep track of the replicas that should not be stopped yet + self._dont_stop_me_now = np.ones(self._n_replicas, dtype=bool) self._dont_stop = dont_stop self._stop_now = False @@ -451,6 +454,8 @@ def monitor_chi2(self, training_info, epoch, print_stats=False): passes &= fitstate.vl_loss < self._best_val_chi2s # And the ones that pass positivity passes &= self._positivity(fitstate) + # Replicas that have already been stopped can no longer count as passing + passes &= self._dont_stop_me_now self._stopping_degrees += self._counts @@ -470,6 +475,7 @@ def monitor_chi2(self, training_info, epoch, print_stats=False): for i_replica in np.where(stop_replicas)[0]: self._stop_epochs[i_replica] = epoch self._counts[i_replica] = 0 + self._dont_stop_me_now[i_replica] = False # By using the stopping degree we only stop when none of the replicas are improving anymore if min(self._stopping_degrees) > self.stopping_patience: