Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ conda activate domainlab_py39
conda install pytorch==1.12.1 torchvision==0.13.1 torchaudio==0.12.1 cudatoolkit=11.6 -c pytorch -c conda-forge
conda install torchmetrics==0.10.3
git checkout fbopt
pip install -r requirements_notorch.txt
pip install -r requirements_notorch.txt
conda install tensorboard
```

Expand All @@ -60,7 +60,7 @@ https://github.com/marrlab/DomainLab/blob/fbopt/data/script/download_pacs.py
step 2:
make a symbolic link following the example script in https://github.com/marrlab/DomainLab/blob/master/sh_pacs.sh

where `mkdir -p data/pacs` is executed under the repository directory,
where `mkdir -p data/pacs` is executed under the repository directory,

`ln -s /dir/to/yourdata/pacs/raw ./data/pacs/PACS`
will create a symbolic link under the repository directory
Expand Down
32 changes: 18 additions & 14 deletions domainlab/algos/builder_diva.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,10 @@
Builder pattern to build different component for experiment with DIVA
"""
from domainlab.algos.a_algo_builder import NodeAlgoBuilder
from domainlab.algos.msels.c_msel_val import MSelValPerf
from domainlab.algos.msels.c_msel_val_top_k import MSelValPerfTopK
from domainlab.algos.msels.c_msel_oracle import MSelOracleVisitor
from domainlab.algos.msels.c_msel_setpoint_delay import MSelSetpointDelay
from domainlab.algos.msels.c_msel_val import MSelValPerf
from domainlab.algos.msels.c_msel_val_top_k import MSelValPerfTopK
from domainlab.algos.observers.b_obvisitor import ObVisitor
from domainlab.algos.observers.c_obvisitor_cleanup import ObVisitorCleanUp
from domainlab.algos.observers.c_obvisitor_gen import ObVisitorGen
Expand Down Expand Up @@ -37,19 +37,23 @@ def init_business(self, exp):
request = RequestVAEBuilderCHW(task.isize.c, task.isize.h, task.isize.w, args)
node = VAEChainNodeGetter(request)()
task.get_list_domains_tr_te(args.tr_d, args.te_d)
model = mk_diva(str_diva_multiplier_type=args.str_diva_multiplier_type)(node,
zd_dim=args.zd_dim,
zy_dim=args.zy_dim,
zx_dim=args.zx_dim,
list_str_y=task.list_str_y,
list_d_tr=task.list_domain_tr,
gamma_d=args.gamma_d,
gamma_y=args.gamma_y,
beta_x=args.beta_x,
beta_y=args.beta_y,
beta_d=args.beta_d)
model = mk_diva(str_diva_multiplier_type=args.str_diva_multiplier_type)(
node,
zd_dim=args.zd_dim,
zy_dim=args.zy_dim,
zx_dim=args.zx_dim,
list_str_y=task.list_str_y,
list_d_tr=task.list_domain_tr,
gamma_d=args.gamma_d,
gamma_y=args.gamma_y,
beta_x=args.beta_x,
beta_y=args.beta_y,
beta_d=args.beta_d,
)
device = get_device(args)
model_sel = MSelSetpointDelay(MSelOracleVisitor(MSelValPerfTopK(max_es=args.es)))
model_sel = MSelSetpointDelay(
MSelOracleVisitor(MSelValPerfTopK(max_es=args.es))
)
if not args.gen:
observer = ObVisitor(model_sel)
else:
Expand Down
1 change: 1 addition & 0 deletions domainlab/algos/builder_fbopt_dial.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ class NodeAlgoBuilderFbOptDial(NodeAlgoBuilderDIVA):
"""
builder for feedback optimization for dial
"""

def init_business(self, exp):
"""
return trainer, model, observer
Expand Down
4 changes: 2 additions & 2 deletions domainlab/algos/builder_jigen1.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,10 @@
builder for JiGen
"""
from domainlab.algos.a_algo_builder import NodeAlgoBuilder
from domainlab.algos.msels.c_msel_val import MSelValPerf
from domainlab.algos.msels.c_msel_val_top_k import MSelValPerfTopK
from domainlab.algos.msels.c_msel_oracle import MSelOracleVisitor
from domainlab.algos.msels.c_msel_setpoint_delay import MSelSetpointDelay
from domainlab.algos.msels.c_msel_val import MSelValPerf
from domainlab.algos.msels.c_msel_val_top_k import MSelValPerfTopK
from domainlab.algos.observers.b_obvisitor import ObVisitor
from domainlab.algos.observers.c_obvisitor_cleanup import ObVisitorCleanUp
from domainlab.algos.trainers.hyper_scheduler import HyperSchedulerWarmupExponential
Expand Down
6 changes: 3 additions & 3 deletions domainlab/algos/msels/a_model_sel.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,7 +92,7 @@ def best_te_metric(self):
if self.msel is not None:
return self.msel.best_te_metric
return -1

@property
def sel_model_te_acc(self):
"""
Expand All @@ -101,9 +101,9 @@ def sel_model_te_acc(self):
if self.msel is not None:
return self.msel.sel_model_te_acc
return -1

@property
def oracle_last_setpoint_sel_te_acc(self):
if self.msel is not None:
return self.msel.oracle_last_setpoint_sel_te_acc
return -1
return -1
17 changes: 13 additions & 4 deletions domainlab/algos/msels/c_msel_setpoint_delay.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
Multiobjective Model Selection
"""
import copy

from domainlab.algos.msels.a_model_sel import AMSel
from domainlab.utils.logger import Logger

Expand All @@ -11,27 +12,35 @@ class MSelSetpointDelay(AMSel):
1. Model selection using validation performance
2. Only update if setpoint has been decreased
"""

def __init__(self, msel):
super().__init__()
# NOTE: super() has to come first always otherwise self.msel will be overwritten to be None
self.msel = msel
self._oracle_last_setpoint_sel_te_acc = 0.0

@property
def oracle_last_setpoint_sel_te_acc(self):
"""
return the last setpoint best acc
"""
return self._oracle_last_setpoint_sel_te_acc

def update(self, clear_counter=False):
"""
if the best model should be updated
"""
logger = Logger.get_logger()
logger.info(f"setpoint selected current acc {self._oracle_last_setpoint_sel_te_acc}")
logger.info(
f"setpoint selected current acc {self._oracle_last_setpoint_sel_te_acc}"
)
if clear_counter:
logger.info("setpoint msel te acc updated from {self._oracle_last_setpoint_sel_te_acc} to {self.sel_model_te_acc}")
log_message = (
f"setpoint msel te acc updated from "
f"{self._oracle_last_setpoint_sel_te_acc} to "
f"{self.sel_model_te_acc}"
)
logger.info(log_message)
self._oracle_last_setpoint_sel_te_acc = self.sel_model_te_acc
flag = self.msel.update(clear_counter)
return flag
2 changes: 1 addition & 1 deletion domainlab/algos/msels/c_msel_tr_loss.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ def reset(self):
@property
def max_es(self):
return self._max_es

def update(self, clear_counter=False):
"""
if the best model should be updated
Expand Down
2 changes: 1 addition & 1 deletion domainlab/algos/msels/c_msel_val.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ def reset(self):
@property
def sel_model_te_acc(self):
return self._sel_model_te_acc

@property
def best_val_acc(self):
"""
Expand Down
32 changes: 21 additions & 11 deletions domainlab/algos/msels/c_msel_val_top_k.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ class MSelValPerfTopK(MSelValPerf):
1. Model selection using validation performance
2. Visitor pattern to trainer
"""

def __init__(self, max_es, top_k=2):
super().__init__(max_es) # construct self.tr_obs (observer)
self.top_k = top_k
Expand All @@ -20,32 +21,41 @@ def update(self, clear_counter=False):
if the best model should be updated
"""
flag_super = super().update(clear_counter)
metric_val_current = \
self.tr_obs.metric_val[self.tr_obs.str_metric4msel]
metric_val_current = self.tr_obs.metric_val[self.tr_obs.str_metric4msel]
acc_min = min(self.list_top_k_acc)
if metric_val_current > acc_min:
# overwrite
logger = Logger.get_logger()
logger.info(f"top k validation acc: {self.list_top_k_acc} \
overwriting/reset counter")
logger.info(
f"top k validation acc: {self.list_top_k_acc} \
overwriting/reset counter"
)
self.es_c = 0 # restore counter
ind = self.list_top_k_acc.index(acc_min)
# avoid having identical values
if metric_val_current not in self.list_top_k_acc:
self.list_top_k_acc[ind] = metric_val_current
logger.info(f"top k validation acc updated: \
{self.list_top_k_acc}")
logger.info(
f"top k validation acc updated: \
{self.list_top_k_acc}"
)
# overwrite to ensure consistency
# issue #569: initially self.list_top_k_acc will be [xx, 0] and it does not matter since 0 will be overwriten by second epoch validation acc.
# issue #569: initially self.list_top_k_acc will be [xx, 0] and it does not matter since 0 will be overwriten by second epoch validation acc.
# actually, after epoch 1, most often, sefl._best_val_acc will be the higher value of self.list_top_k_acc will overwriten by min(self.list_top_k_acc)
logger.info(f"top-2 val sel: overwriting best val acc from {self._best_val_acc} to minium of {self.list_top_k_acc} which is {min(self.list_top_k_acc)} to ensure consistency")
logger.info(
f"top-2 val sel: overwriting best val acc from {self._best_val_acc} to "
f"minimum of {self.list_top_k_acc} which is {min(self.list_top_k_acc)} "
f"to ensure consistency"
)
self._best_val_acc = min(self.list_top_k_acc)
# overwrite test acc, this does not depend on if val top-k acc has been overwritten or not
metric_te_current = \
self.tr_obs.metric_te[self.tr_obs.str_metric4msel]
metric_te_current = self.tr_obs.metric_te[self.tr_obs.str_metric4msel]
if self._sel_model_te_acc != metric_te_current:
# this can only happen if the validation acc has decreased and current val acc is only bigger than min(self.list_top_k_acc} but lower than max(self.list_top_k_acc)
logger.info(f"top-2 val sel: overwriting selected model test acc from {self._sel_model_te_acc} to {metric_te_current} to ensure consistency")
logger.info(
f"top-2 val sel: overwriting selected model test acc from "
f"{self._sel_model_te_acc} to {metric_te_current} to ensure consistency"
)
self._sel_model_te_acc = metric_te_current
return True
return flag_super
8 changes: 6 additions & 2 deletions domainlab/algos/observers/b_obvisitor.py
Original file line number Diff line number Diff line change
Expand Up @@ -119,8 +119,12 @@ def after_all(self):
else:
metric_te.update({"acc_val": -1})

if hasattr(self, "model_sel") and hasattr(self.model_sel, "oracle_last_setpoint_sel_te_acc"):
metric_te.update({"acc_setpoint":self.model_sel.oracle_last_setpoint_sel_te_acc})
if hasattr(self, "model_sel") and hasattr(
self.model_sel, "oracle_last_setpoint_sel_te_acc"
):
metric_te.update(
{"acc_setpoint": self.model_sel.oracle_last_setpoint_sel_te_acc}
)
else:
metric_te.update({"acc_setpoint": -1})
self.dump_prediction(model_ld, metric_te)
Expand Down
2 changes: 1 addition & 1 deletion domainlab/algos/observers/c_obvisitor_cleanup.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ def __init__(self, observer):

def after_all(self):
self.observer.after_all()
self.observer.clean_up() # FIXME should be self.clean_up???
self.observer.clean_up() # FIXME should be self.clean_up???

def accept(self, trainer):
self.observer.accept(trainer)
Expand Down
3 changes: 1 addition & 2 deletions domainlab/algos/trainers/a_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,7 +89,6 @@ def __init__(self, successor_node=None, extend=None):
self.mu_iter_start = 0
self.flag_setpoint_updated = False


@property
def model(self):
"""
Expand Down Expand Up @@ -226,7 +225,7 @@ def get_model(self):
if "trainer" not in str(type(self._model)).lower():
return self._model
return self._model.get_model()

def as_model(self):
"""
used for decorator pattern
Expand Down
Loading