From 2880421e373957e44ff38d1ff6bf751b41011da7 Mon Sep 17 00:00:00 2001 From: smilesun Date: Fri, 8 Mar 2024 10:38:22 +0100 Subject: [PATCH] . --- domainlab/algos/msels/a_model_sel.py | 12 ++++++------ domainlab/algos/msels/c_msel_oracle.py | 10 +++++----- domainlab/algos/msels/c_msel_val.py | 16 ++++++++-------- 3 files changed, 19 insertions(+), 19 deletions(-) diff --git a/domainlab/algos/msels/a_model_sel.py b/domainlab/algos/msels/a_model_sel.py index e2e63c993..0902733ab 100644 --- a/domainlab/algos/msels/a_model_sel.py +++ b/domainlab/algos/msels/a_model_sel.py @@ -15,7 +15,7 @@ def __init__(self): trainer and tr_observer """ self.trainer = None - self._tr_obs = None + self._observer = None self.msel = None self._max_es = None @@ -27,11 +27,11 @@ def reset(self): self.msel.reset() @property - def tr_obs(self): + def observer4msel(self): """ the observer from trainer """ - return self._tr_obs + return self._observer @property def max_es(self): @@ -44,15 +44,15 @@ def max_es(self): return self.msel.max_es return self._max_es - def accept(self, trainer, tr_obs): + def accept(self, trainer, observer4msel): """ Visitor pattern to trainer accept trainer and tr_observer """ self.trainer = trainer - self._tr_obs = tr_obs + self._observer = observer4msel if self.msel is not None: - self.msel.accept(trainer, tr_obs) + self.msel.accept(trainer, observer4msel) @abc.abstractmethod def update(self, clear_counter=False): diff --git a/domainlab/algos/msels/c_msel_oracle.py b/domainlab/algos/msels/c_msel_oracle.py index e232b1e78..95be69f38 100644 --- a/domainlab/algos/msels/c_msel_oracle.py +++ b/domainlab/algos/msels/c_msel_oracle.py @@ -36,9 +36,9 @@ def update(self, clear_counter=False): """ self.trainer.model.save("epoch") flag = False - if self.tr_obs.metric_val is None: + if self.observer4msel.metric_val is None: return super().update(clear_counter) - metric = self.tr_obs.metric_te[self.tr_obs.str_metric4msel] + metric = self.observer4msel.metric_te[self.observer4msel.str_metric4msel] if metric > self.best_oracle_acc: self.best_oracle_acc = metric if self.msel is not None: @@ -62,7 +62,7 @@ def if_stop(self): return self.msel.if_stop() return False - def accept(self, trainer, tr_obs): + def accept(self, trainer, observer4msel): if self.msel is not None: - self.msel.accept(trainer, tr_obs) - super().accept(trainer, tr_obs) + self.msel.accept(trainer, observer4msel) + super().accept(trainer, observer4msel) diff --git a/domainlab/algos/msels/c_msel_val.py b/domainlab/algos/msels/c_msel_val.py index c1f2f5561..054ca56ad 100644 --- a/domainlab/algos/msels/c_msel_val.py +++ b/domainlab/algos/msels/c_msel_val.py @@ -12,7 +12,7 @@ class MSelValPerf(MSelTrLoss): """ def __init__(self, max_es): - super().__init__(max_es) # construct self.tr_obs (observer) + super().__init__(max_es) # construct self.observer4msel (observer) self.reset() def reset(self): @@ -44,11 +44,11 @@ def update(self, clear_counter=False): if the best model should be updated """ flag = True - if self.tr_obs.metric_val is None: + if self.observer4msel.metric_val is None: return super().update(clear_counter) - metric = self.tr_obs.metric_val[self.tr_obs.str_metric4msel] - if self.tr_obs.metric_te is not None: - metric_te_current = self.tr_obs.metric_te[self.tr_obs.str_metric4msel] + metric = self.observer4msel.metric_val[self.observer4msel.str_metric4msel] + if self.observer4msel.metric_te is not None: + metric_te_current = self.observer4msel.metric_te[self.observer4msel.str_metric4msel] self._best_te_metric = max(self._best_te_metric, metric_te_current) if metric > self._best_val_acc: # update hat{model} @@ -56,8 +56,8 @@ def update(self, clear_counter=False): # the bigger the better self._best_val_acc = metric self.es_c = 0 # restore counter - if self.tr_obs.metric_te is not None: - metric_te_current = self.tr_obs.metric_te[self.tr_obs.str_metric4msel] + if self.observer4msel.metric_te is not None: + metric_te_current = self.observer4msel.metric_te[self.observer4msel.str_metric4msel] self._sel_model_te_acc = metric_te_current else: @@ -65,7 +65,7 @@ def update(self, clear_counter=False): logger = Logger.get_logger() logger.info(f"early stop counter: {self.es_c}") logger.info( - f"val acc:{self.tr_obs.metric_val['acc']}, " + f"val acc:{self.observer4msel.metric_val['acc']}, " + f"best validation acc: {self.best_val_acc}, " + f"corresponding to test acc: \ {self.sel_model_te_acc} / {self.best_te_metric}"