diff --git a/domainlab/algos/msels/a_model_sel.py b/domainlab/algos/msels/a_model_sel.py index 0902733ab..1f56811ce 100644 --- a/domainlab/algos/msels/a_model_sel.py +++ b/domainlab/algos/msels/a_model_sel.py @@ -18,6 +18,7 @@ def __init__(self): self._observer = None self.msel = None self._max_es = None + self._model_selection_epoch = None def reset(self): """ @@ -54,8 +55,18 @@ def accept(self, trainer, observer4msel): if self.msel is not None: self.msel.accept(trainer, observer4msel) + def update(self, epoch, clear_counter=False): + """ + level above the observer + visitor pattern to get information about the epoch + """ + update = self.base_update(clear_counter) + if update: + self._model_selection_epoch = epoch + + return update + @abc.abstractmethod - def update(self, clear_counter=False): + def base_update(self, clear_counter=False): """ observer + visitor pattern to trainer if the best model should be updated @@ -101,3 +112,12 @@ def sel_model_te_acc(self): if self.msel is not None: return self.msel.sel_model_te_acc return -1 + + @property + def model_selection_epoch(self): + """ + the epoch when the model was selected + """ + if self._model_selection_epoch is not None: + return self._model_selection_epoch + return -1 diff --git a/domainlab/algos/msels/c_msel_oracle.py b/domainlab/algos/msels/c_msel_oracle.py index 95be69f38..35943c295 100644 --- a/domainlab/algos/msels/c_msel_oracle.py +++ b/domainlab/algos/msels/c_msel_oracle.py @@ -30,14 +30,14 @@ def oracle_last_setpoint_sel_te_acc(self): return self.msel.oracle_last_setpoint_sel_te_acc return -1 - def update(self, clear_counter=False): + def base_update(self, clear_counter=False): """ if the best model should be updated """ self.trainer.model.save("epoch") flag = False if self.observer4msel.metric_val is None: - return super().update(clear_counter) + return super().base_update(clear_counter) metric = self.observer4msel.metric_te[self.observer4msel.str_metric4msel] if metric > self.best_oracle_acc: self.best_oracle_acc = metric @@ -49,7 +49,7 @@ def update(self, clear_counter=False): logger.info("new oracle model saved") flag = True if self.msel is not None: - return self.msel.update(clear_counter) + return self.msel.base_update(clear_counter) return flag def if_stop(self): diff --git a/domainlab/algos/msels/c_msel_tr_loss.py b/domainlab/algos/msels/c_msel_tr_loss.py index 58ca03e21..9c4ee2c54 100644 --- a/domainlab/algos/msels/c_msel_tr_loss.py +++ b/domainlab/algos/msels/c_msel_tr_loss.py @@ -28,7 +28,7 @@ def reset(self): def max_es(self): return self._max_es - def update(self, clear_counter=False): + def base_update(self, clear_counter=False): """ if the best model should be updated """ diff --git a/domainlab/algos/msels/c_msel_val.py b/domainlab/algos/msels/c_msel_val.py index 054ca56ad..581f3a9b6 100644 --- a/domainlab/algos/msels/c_msel_val.py +++ b/domainlab/algos/msels/c_msel_val.py @@ -39,13 +39,13 @@ def best_te_metric(self): """ return self._best_te_metric - def update(self, clear_counter=False): + def base_update(self, clear_counter=False): """ if the best model should be updated """ flag = True if self.observer4msel.metric_val is None: - return super().update(clear_counter) + return super().base_update(clear_counter) metric = self.observer4msel.metric_val[self.observer4msel.str_metric4msel] if self.observer4msel.metric_te is not None: metric_te_current = self.observer4msel.metric_te[self.observer4msel.str_metric4msel] diff --git a/domainlab/algos/observers/b_obvisitor.py b/domainlab/algos/observers/b_obvisitor.py index 04231a917..68ef82632 100644 --- a/domainlab/algos/observers/b_obvisitor.py +++ b/domainlab/algos/observers/b_obvisitor.py @@ -53,7 +53,7 @@ def update(self, epoch): self.loader_te, self.device ) self.metric_te = metric_te - if self.model_sel.update(): + if self.model_sel.update(epoch): logger.info("better model found") self.host_trainer.model.save() logger.info("persisted") @@ -102,8 +102,10 @@ def after_all(self): metric_te.update({"acc_oracle": -1}) if hasattr(self, "model_sel"): metric_te.update({"acc_val": self.model_sel.best_val_acc}) + metric_te.update({"model_selection_epoch": self.model_sel.model_selection_epoch}) else: metric_te.update({"acc_val": -1}) + metric_te.update({"model_selection_epoch": -1}) self.dump_prediction(model_ld, metric_te) # save metric to one line in csv result file self.host_trainer.model.visitor(metric_te) diff --git a/tests/test_msel_oracle.py b/tests/test_msel_oracle.py index 3dfdd6fbe..8a28567b0 100644 --- a/tests/test_msel_oracle.py +++ b/tests/test_msel_oracle.py @@ -66,7 +66,7 @@ def mk_exp( model_sel.msel._best_val_acc = 1.0 observer = ObVisitor(model_sel) exp = Exp(conf, task, model=model, observer=observer) - model_sel.update(clear_counter=True) + model_sel.update(epoch=1, clear_counter=True) return exp @@ -142,7 +142,7 @@ def test_msel_oracle1(): ) exp.execute(num_epochs=2) - exp.trainer.observer.model_sel.msel.update(clear_counter=True) + exp.trainer.observer.model_sel.msel.update(epoch=1, clear_counter=True) del exp @@ -261,5 +261,5 @@ def test_msel_oracle4(): ) exp.execute(num_epochs=2) exp.trainer.observer.model_sel.msel.best_loss = 0 - exp.trainer.observer.model_sel.msel.update(clear_counter=True) + exp.trainer.observer.model_sel.msel.update(epoch = 1, clear_counter=True) del exp