From 51fe0bdaf8b7939e7706da7508c3ef6f01e4d798 Mon Sep 17 00:00:00 2001 From: matteowohlrapp Date: Fri, 22 Mar 2024 11:59:19 +0100 Subject: [PATCH 1/6] Added functionality to log in which epoch the model was selected to the results file. Integrated in the model selection process --- domainlab/algos/msels/a_model_sel.py | 22 +++++++++++++++++++++- domainlab/algos/msels/c_msel_oracle.py | 6 +++--- domainlab/algos/msels/c_msel_tr_loss.py | 2 +- domainlab/algos/msels/c_msel_val.py | 4 ++-- domainlab/algos/observers/b_obvisitor.py | 4 +++- domainlab/exp/exp_utils.py | 1 + 6 files changed, 31 insertions(+), 8 deletions(-) diff --git a/domainlab/algos/msels/a_model_sel.py b/domainlab/algos/msels/a_model_sel.py index 0902733ab..a177ebd1d 100644 --- a/domainlab/algos/msels/a_model_sel.py +++ b/domainlab/algos/msels/a_model_sel.py @@ -18,6 +18,7 @@ def __init__(self): self._observer = None self.msel = None self._max_es = None + self._model_selection_epoch = None def reset(self): """ @@ -54,8 +55,18 @@ def accept(self, trainer, observer4msel): if self.msel is not None: self.msel.accept(trainer, observer4msel) + def update(self, epoch, clear_counter=False): + """ + level above the observer + visitor pattern to get information about the epoch + """ + update = self._update(clear_counter) + if update: + self._model_selection_epoch = epoch + + return update + @abc.abstractmethod - def update(self, clear_counter=False): + def _update(self, clear_counter=False): """ observer + visitor pattern to trainer if the best model should be updated @@ -101,3 +112,12 @@ def sel_model_te_acc(self): if self.msel is not None: return self.msel.sel_model_te_acc return -1 + + @property + def model_selection_epoch(self): + """ + the epoch when the model was selected + """ + if self._model_selection_epoch is not None: + return self._model_selection_epoch + return -1 \ No newline at end of file diff --git a/domainlab/algos/msels/c_msel_oracle.py b/domainlab/algos/msels/c_msel_oracle.py index 95be69f38..21d51fc54 100644 --- a/domainlab/algos/msels/c_msel_oracle.py +++ b/domainlab/algos/msels/c_msel_oracle.py @@ -30,14 +30,14 @@ def oracle_last_setpoint_sel_te_acc(self): return self.msel.oracle_last_setpoint_sel_te_acc return -1 - def update(self, clear_counter=False): + def _update(self, clear_counter=False): """ if the best model should be updated """ self.trainer.model.save("epoch") flag = False if self.observer4msel.metric_val is None: - return super().update(clear_counter) + return super()._update(clear_counter) metric = self.observer4msel.metric_te[self.observer4msel.str_metric4msel] if metric > self.best_oracle_acc: self.best_oracle_acc = metric @@ -49,7 +49,7 @@ def update(self, clear_counter=False): logger.info("new oracle model saved") flag = True if self.msel is not None: - return self.msel.update(clear_counter) + return self.msel._update(clear_counter) return flag def if_stop(self): diff --git a/domainlab/algos/msels/c_msel_tr_loss.py b/domainlab/algos/msels/c_msel_tr_loss.py index 58ca03e21..13e3e0085 100644 --- a/domainlab/algos/msels/c_msel_tr_loss.py +++ b/domainlab/algos/msels/c_msel_tr_loss.py @@ -28,7 +28,7 @@ def reset(self): def max_es(self): return self._max_es - def update(self, clear_counter=False): + def _update(self, clear_counter=False): """ if the best model should be updated """ diff --git a/domainlab/algos/msels/c_msel_val.py b/domainlab/algos/msels/c_msel_val.py index 054ca56ad..7f25a782c 100644 --- a/domainlab/algos/msels/c_msel_val.py +++ b/domainlab/algos/msels/c_msel_val.py @@ -39,13 +39,13 @@ def best_te_metric(self): """ return self._best_te_metric - def update(self, clear_counter=False): + def _update(self, clear_counter=False): """ if the best model should be updated """ flag = True if self.observer4msel.metric_val is None: - return super().update(clear_counter) + return super()._update(clear_counter) metric = self.observer4msel.metric_val[self.observer4msel.str_metric4msel] if self.observer4msel.metric_te is not None: metric_te_current = self.observer4msel.metric_te[self.observer4msel.str_metric4msel] diff --git a/domainlab/algos/observers/b_obvisitor.py b/domainlab/algos/observers/b_obvisitor.py index 04231a917..68ef82632 100644 --- a/domainlab/algos/observers/b_obvisitor.py +++ b/domainlab/algos/observers/b_obvisitor.py @@ -53,7 +53,7 @@ def update(self, epoch): self.loader_te, self.device ) self.metric_te = metric_te - if self.model_sel.update(): + if self.model_sel.update(epoch): logger.info("better model found") self.host_trainer.model.save() logger.info("persisted") @@ -102,8 +102,10 @@ def after_all(self): metric_te.update({"acc_oracle": -1}) if hasattr(self, "model_sel"): metric_te.update({"acc_val": self.model_sel.best_val_acc}) + metric_te.update({"model_selection_epoch": self.model_sel.model_selection_epoch}) else: metric_te.update({"acc_val": -1}) + metric_te.update({"model_selection_epoch": -1}) self.dump_prediction(model_ld, metric_te) # save metric to one line in csv result file self.host_trainer.model.visitor(metric_te) diff --git a/domainlab/exp/exp_utils.py b/domainlab/exp/exp_utils.py index 2af681731..d467285d6 100644 --- a/domainlab/exp/exp_utils.py +++ b/domainlab/exp/exp_utils.py @@ -169,6 +169,7 @@ def get_cols(self): # algorithm configuration for instance "mname": "mname_" + self.model_name, "commit": "commit_" + self.git_tag, + "model_selection_epoch" : None } return dict_cols, epos_name From 2c7901d40f9295865c9b10da7e1a0aa7c0fa49e9 Mon Sep 17 00:00:00 2001 From: matteowohlrapp Date: Fri, 22 Mar 2024 17:18:11 +0100 Subject: [PATCH 2/6] removed unnesseary line for the dictionary columns, because it is already contained in the dictionary beeing passed to the writer --- domainlab/exp/exp_utils.py | 1 - 1 file changed, 1 deletion(-) diff --git a/domainlab/exp/exp_utils.py b/domainlab/exp/exp_utils.py index d467285d6..2af681731 100644 --- a/domainlab/exp/exp_utils.py +++ b/domainlab/exp/exp_utils.py @@ -169,7 +169,6 @@ def get_cols(self): # algorithm configuration for instance "mname": "mname_" + self.model_name, "commit": "commit_" + self.git_tag, - "model_selection_epoch" : None } return dict_cols, epos_name From 82d6b5cdcad164ccdd0447d01e01f83ea1fbae53 Mon Sep 17 00:00:00 2001 From: matteowohlrapp Date: Fri, 22 Mar 2024 18:00:43 +0100 Subject: [PATCH 3/6] renamed method to not conflict with protected declaration --- domainlab/algos/msels/a_model_sel.py | 4 ++-- domainlab/algos/msels/c_msel_oracle.py | 6 +++--- domainlab/algos/msels/c_msel_tr_loss.py | 2 +- domainlab/algos/msels/c_msel_val.py | 4 ++-- 4 files changed, 8 insertions(+), 8 deletions(-) diff --git a/domainlab/algos/msels/a_model_sel.py b/domainlab/algos/msels/a_model_sel.py index a177ebd1d..1830ff5fd 100644 --- a/domainlab/algos/msels/a_model_sel.py +++ b/domainlab/algos/msels/a_model_sel.py @@ -59,14 +59,14 @@ def update(self, epoch, clear_counter=False): """ level above the observer + visitor pattern to get information about the epoch """ - update = self._update(clear_counter) + update = self.base_update(clear_counter) if update: self._model_selection_epoch = epoch return update @abc.abstractmethod - def _update(self, clear_counter=False): + def base_update(self, clear_counter=False): """ observer + visitor pattern to trainer if the best model should be updated diff --git a/domainlab/algos/msels/c_msel_oracle.py b/domainlab/algos/msels/c_msel_oracle.py index 21d51fc54..35943c295 100644 --- a/domainlab/algos/msels/c_msel_oracle.py +++ b/domainlab/algos/msels/c_msel_oracle.py @@ -30,14 +30,14 @@ def oracle_last_setpoint_sel_te_acc(self): return self.msel.oracle_last_setpoint_sel_te_acc return -1 - def _update(self, clear_counter=False): + def base_update(self, clear_counter=False): """ if the best model should be updated """ self.trainer.model.save("epoch") flag = False if self.observer4msel.metric_val is None: - return super()._update(clear_counter) + return super().base_update(clear_counter) metric = self.observer4msel.metric_te[self.observer4msel.str_metric4msel] if metric > self.best_oracle_acc: self.best_oracle_acc = metric @@ -49,7 +49,7 @@ def _update(self, clear_counter=False): logger.info("new oracle model saved") flag = True if self.msel is not None: - return self.msel._update(clear_counter) + return self.msel.base_update(clear_counter) return flag def if_stop(self): diff --git a/domainlab/algos/msels/c_msel_tr_loss.py b/domainlab/algos/msels/c_msel_tr_loss.py index 13e3e0085..9c4ee2c54 100644 --- a/domainlab/algos/msels/c_msel_tr_loss.py +++ b/domainlab/algos/msels/c_msel_tr_loss.py @@ -28,7 +28,7 @@ def reset(self): def max_es(self): return self._max_es - def _update(self, clear_counter=False): + def base_update(self, clear_counter=False): """ if the best model should be updated """ diff --git a/domainlab/algos/msels/c_msel_val.py b/domainlab/algos/msels/c_msel_val.py index 7f25a782c..581f3a9b6 100644 --- a/domainlab/algos/msels/c_msel_val.py +++ b/domainlab/algos/msels/c_msel_val.py @@ -39,13 +39,13 @@ def best_te_metric(self): """ return self._best_te_metric - def _update(self, clear_counter=False): + def base_update(self, clear_counter=False): """ if the best model should be updated """ flag = True if self.observer4msel.metric_val is None: - return super()._update(clear_counter) + return super().base_update(clear_counter) metric = self.observer4msel.metric_val[self.observer4msel.str_metric4msel] if self.observer4msel.metric_te is not None: metric_te_current = self.observer4msel.metric_te[self.observer4msel.str_metric4msel] From 1da5f885c242752118cb831d5da4d7d3d3da8bd2 Mon Sep 17 00:00:00 2001 From: matteowohlrapp Date: Fri, 22 Mar 2024 18:21:12 +0100 Subject: [PATCH 4/6] adjusted test to fit new signature --- tests/test_msel_oracle.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/test_msel_oracle.py b/tests/test_msel_oracle.py index 3dfdd6fbe..8819d59f0 100644 --- a/tests/test_msel_oracle.py +++ b/tests/test_msel_oracle.py @@ -142,7 +142,7 @@ def test_msel_oracle1(): ) exp.execute(num_epochs=2) - exp.trainer.observer.model_sel.msel.update(clear_counter=True) + exp.trainer.observer.model_sel.msel.update(epoch=1, clear_counter=True) del exp @@ -261,5 +261,5 @@ def test_msel_oracle4(): ) exp.execute(num_epochs=2) exp.trainer.observer.model_sel.msel.best_loss = 0 - exp.trainer.observer.model_sel.msel.update(clear_counter=True) + exp.trainer.observer.model_sel.msel.update(epoch = 1, clear_counter=True) del exp From d5c03c891a84c824d87019500e61cfbfa44a0c91 Mon Sep 17 00:00:00 2001 From: matteowohlrapp Date: Mon, 25 Mar 2024 14:01:34 +0100 Subject: [PATCH 5/6] Added epoch to update to fix method --- tests/test_msel_oracle.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/test_msel_oracle.py b/tests/test_msel_oracle.py index 8819d59f0..8a28567b0 100644 --- a/tests/test_msel_oracle.py +++ b/tests/test_msel_oracle.py @@ -66,7 +66,7 @@ def mk_exp( model_sel.msel._best_val_acc = 1.0 observer = ObVisitor(model_sel) exp = Exp(conf, task, model=model, observer=observer) - model_sel.update(clear_counter=True) + model_sel.update(epoch=1, clear_counter=True) return exp From 1ec8891d493ec80808931f6a7008f462bb3b1aeb Mon Sep 17 00:00:00 2001 From: matteowohlrapp Date: Wed, 27 Mar 2024 12:53:33 +0100 Subject: [PATCH 6/6] codacity fix --- domainlab/algos/msels/a_model_sel.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/domainlab/algos/msels/a_model_sel.py b/domainlab/algos/msels/a_model_sel.py index 1830ff5fd..1f56811ce 100644 --- a/domainlab/algos/msels/a_model_sel.py +++ b/domainlab/algos/msels/a_model_sel.py @@ -60,13 +60,13 @@ def update(self, epoch, clear_counter=False): level above the observer + visitor pattern to get information about the epoch """ update = self.base_update(clear_counter) - if update: + if update: self._model_selection_epoch = epoch return update @abc.abstractmethod - def base_update(self, clear_counter=False): + def base_update(self, clear_counter=False): """ observer + visitor pattern to trainer if the best model should be updated @@ -112,7 +112,7 @@ def sel_model_te_acc(self): if self.msel is not None: return self.msel.sel_model_te_acc return -1 - + @property def model_selection_epoch(self): """ @@ -120,4 +120,4 @@ def model_selection_epoch(self): """ if self._model_selection_epoch is not None: return self._model_selection_epoch - return -1 \ No newline at end of file + return -1