Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 6 additions & 6 deletions domainlab/algos/msels/a_model_sel.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ def __init__(self):
trainer and tr_observer
"""
self.trainer = None
self._tr_obs = None
self._observer = None
self.msel = None
self._max_es = None

Expand All @@ -27,11 +27,11 @@ def reset(self):
self.msel.reset()

@property
def tr_obs(self):
def observer4msel(self):
"""
the observer from trainer
"""
return self._tr_obs
return self._observer

@property
def max_es(self):
Expand All @@ -44,15 +44,15 @@ def max_es(self):
return self.msel.max_es
return self._max_es

def accept(self, trainer, tr_obs):
def accept(self, trainer, observer4msel):
"""
Visitor pattern to trainer
accept trainer and tr_observer
"""
self.trainer = trainer
self._tr_obs = tr_obs
self._observer = observer4msel
if self.msel is not None:
self.msel.accept(trainer, tr_obs)
self.msel.accept(trainer, observer4msel)

@abc.abstractmethod
def update(self, clear_counter=False):
Expand Down
10 changes: 5 additions & 5 deletions domainlab/algos/msels/c_msel_oracle.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,9 +36,9 @@ def update(self, clear_counter=False):
"""
self.trainer.model.save("epoch")
flag = False
if self.tr_obs.metric_val is None:
if self.observer4msel.metric_val is None:
return super().update(clear_counter)
metric = self.tr_obs.metric_te[self.tr_obs.str_metric4msel]
metric = self.observer4msel.metric_te[self.observer4msel.str_metric4msel]
if metric > self.best_oracle_acc:
self.best_oracle_acc = metric
if self.msel is not None:
Expand All @@ -62,7 +62,7 @@ def if_stop(self):
return self.msel.if_stop()
return False

def accept(self, trainer, tr_obs):
def accept(self, trainer, observer4msel):
if self.msel is not None:
self.msel.accept(trainer, tr_obs)
super().accept(trainer, tr_obs)
self.msel.accept(trainer, observer4msel)
super().accept(trainer, observer4msel)
16 changes: 8 additions & 8 deletions domainlab/algos/msels/c_msel_val.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ class MSelValPerf(MSelTrLoss):
"""

def __init__(self, max_es):
super().__init__(max_es) # construct self.tr_obs (observer)
super().__init__(max_es) # construct self.observer4msel (observer)
self.reset()

def reset(self):
Expand Down Expand Up @@ -44,28 +44,28 @@ def update(self, clear_counter=False):
if the best model should be updated
"""
flag = True
if self.tr_obs.metric_val is None:
if self.observer4msel.metric_val is None:
return super().update(clear_counter)
metric = self.tr_obs.metric_val[self.tr_obs.str_metric4msel]
if self.tr_obs.metric_te is not None:
metric_te_current = self.tr_obs.metric_te[self.tr_obs.str_metric4msel]
metric = self.observer4msel.metric_val[self.observer4msel.str_metric4msel]
if self.observer4msel.metric_te is not None:
metric_te_current = self.observer4msel.metric_te[self.observer4msel.str_metric4msel]
self._best_te_metric = max(self._best_te_metric, metric_te_current)

if metric > self._best_val_acc: # update hat{model}
# different from loss, accuracy should be improved:
# the bigger the better
self._best_val_acc = metric
self.es_c = 0 # restore counter
if self.tr_obs.metric_te is not None:
metric_te_current = self.tr_obs.metric_te[self.tr_obs.str_metric4msel]
if self.observer4msel.metric_te is not None:
metric_te_current = self.observer4msel.metric_te[self.observer4msel.str_metric4msel]
self._sel_model_te_acc = metric_te_current

else:
self.es_c += 1
logger = Logger.get_logger()
logger.info(f"early stop counter: {self.es_c}")
logger.info(
f"val acc:{self.tr_obs.metric_val['acc']}, "
f"val acc:{self.observer4msel.metric_val['acc']}, "
+ f"best validation acc: {self.best_val_acc}, "
+ f"corresponding to test acc: \
{self.sel_model_te_acc} / {self.best_te_metric}"
Expand Down