diff --git a/README.md b/README.md
index ec7e3e88b..ed427eb7c 100644
--- a/README.md
+++ b/README.md
@@ -45,7 +45,7 @@ conda activate domainlab_py39
 conda install pytorch==1.12.1 torchvision==0.13.1 torchaudio==0.12.1 cudatoolkit=11.6 -c pytorch -c conda-forge
 conda install torchmetrics==0.10.3
 git checkout fbopt
-pip install -r requirements_notorch.txt
+pip install -r requirements_notorch.txt
 conda install tensorboard
 ```
@@ -60,7 +60,7 @@ https://github.com/marrlab/DomainLab/blob/fbopt/data/script/download_pacs.py
 step 2: make a symbolic link following the example script in
 https://github.com/marrlab/DomainLab/blob/master/sh_pacs.sh
-where `mkdir -p data/pacs` is executed under the repository directory,
+where `mkdir -p data/pacs` is executed under the repository directory,
 `ln -s /dir/to/yourdata/pacs/raw ./data/pacs/PACS` will create a symbolic
 link under the repository directory
diff --git a/domainlab/algos/builder_diva.py b/domainlab/algos/builder_diva.py
index 94bbff33b..0baf9039a 100644
--- a/domainlab/algos/builder_diva.py
+++ b/domainlab/algos/builder_diva.py
@@ -2,10 +2,10 @@
 Builder pattern to build different component for experiment with DIVA
 """
 from domainlab.algos.a_algo_builder import NodeAlgoBuilder
-from domainlab.algos.msels.c_msel_val import MSelValPerf
-from domainlab.algos.msels.c_msel_val_top_k import MSelValPerfTopK
 from domainlab.algos.msels.c_msel_oracle import MSelOracleVisitor
 from domainlab.algos.msels.c_msel_setpoint_delay import MSelSetpointDelay
+from domainlab.algos.msels.c_msel_val import MSelValPerf
+from domainlab.algos.msels.c_msel_val_top_k import MSelValPerfTopK
 from domainlab.algos.observers.b_obvisitor import ObVisitor
 from domainlab.algos.observers.c_obvisitor_cleanup import ObVisitorCleanUp
 from domainlab.algos.observers.c_obvisitor_gen import ObVisitorGen
@@ -37,19 +37,23 @@ def init_business(self, exp):
         request = RequestVAEBuilderCHW(task.isize.c, task.isize.h, task.isize.w, args)
         node = VAEChainNodeGetter(request)()
         task.get_list_domains_tr_te(args.tr_d, args.te_d)
-        model = mk_diva(str_diva_multiplier_type=args.str_diva_multiplier_type)(node,
-                        zd_dim=args.zd_dim,
-                        zy_dim=args.zy_dim,
-                        zx_dim=args.zx_dim,
-                        list_str_y=task.list_str_y,
-                        list_d_tr=task.list_domain_tr,
-                        gamma_d=args.gamma_d,
-                        gamma_y=args.gamma_y,
-                        beta_x=args.beta_x,
-                        beta_y=args.beta_y,
-                        beta_d=args.beta_d)
+        model = mk_diva(str_diva_multiplier_type=args.str_diva_multiplier_type)(
+            node,
+            zd_dim=args.zd_dim,
+            zy_dim=args.zy_dim,
+            zx_dim=args.zx_dim,
+            list_str_y=task.list_str_y,
+            list_d_tr=task.list_domain_tr,
+            gamma_d=args.gamma_d,
+            gamma_y=args.gamma_y,
+            beta_x=args.beta_x,
+            beta_y=args.beta_y,
+            beta_d=args.beta_d,
+        )
         device = get_device(args)
-        model_sel = MSelSetpointDelay(MSelOracleVisitor(MSelValPerfTopK(max_es=args.es)))
+        model_sel = MSelSetpointDelay(
+            MSelOracleVisitor(MSelValPerfTopK(max_es=args.es))
+        )
         if not args.gen:
             observer = ObVisitor(model_sel)
         else:
diff --git a/domainlab/algos/builder_fbopt_dial.py b/domainlab/algos/builder_fbopt_dial.py
index 81b5c3eae..f1faad96b 100644
--- a/domainlab/algos/builder_fbopt_dial.py
+++ b/domainlab/algos/builder_fbopt_dial.py
@@ -9,6 +9,7 @@ class NodeAlgoBuilderFbOptDial(NodeAlgoBuilderDIVA):
     """
     builder for feedback optimization for dial
     """
+
     def init_business(self, exp):
         """
         return trainer, model, observer
diff --git a/domainlab/algos/builder_jigen1.py b/domainlab/algos/builder_jigen1.py
index 4f0cae4f3..b14a3882f 100644
--- a/domainlab/algos/builder_jigen1.py
+++ b/domainlab/algos/builder_jigen1.py
@@ -2,10 +2,10 @@
 builder for JiGen
 """
 from domainlab.algos.a_algo_builder import NodeAlgoBuilder
-from domainlab.algos.msels.c_msel_val import MSelValPerf
-from domainlab.algos.msels.c_msel_val_top_k import MSelValPerfTopK
 from domainlab.algos.msels.c_msel_oracle import MSelOracleVisitor
 from domainlab.algos.msels.c_msel_setpoint_delay import MSelSetpointDelay
+from domainlab.algos.msels.c_msel_val import MSelValPerf
+from domainlab.algos.msels.c_msel_val_top_k import MSelValPerfTopK
 from domainlab.algos.observers.b_obvisitor import ObVisitor
 from domainlab.algos.observers.c_obvisitor_cleanup import ObVisitorCleanUp
 from domainlab.algos.trainers.hyper_scheduler import HyperSchedulerWarmupExponential
diff --git a/domainlab/algos/msels/a_model_sel.py b/domainlab/algos/msels/a_model_sel.py
index cb5f09f75..1cbfffeaa 100644
--- a/domainlab/algos/msels/a_model_sel.py
+++ b/domainlab/algos/msels/a_model_sel.py
@@ -92,7 +92,7 @@ def best_te_metric(self):
         if self.msel is not None:
             return self.msel.best_te_metric
         return -1
-
+
     @property
     def sel_model_te_acc(self):
         """
@@ -101,9 +101,9 @@ def sel_model_te_acc(self):
         if self.msel is not None:
             return self.msel.sel_model_te_acc
         return -1
-
+
     @property
     def oracle_last_setpoint_sel_te_acc(self):
         if self.msel is not None:
             return self.msel.oracle_last_setpoint_sel_te_acc
-        return -1
\ No newline at end of file
+        return -1
diff --git a/domainlab/algos/msels/c_msel_setpoint_delay.py b/domainlab/algos/msels/c_msel_setpoint_delay.py
index 650b57527..87c9b3c84 100644
--- a/domainlab/algos/msels/c_msel_setpoint_delay.py
+++ b/domainlab/algos/msels/c_msel_setpoint_delay.py
@@ -2,6 +2,7 @@
 Multiobjective Model Selection
 """
 import copy
+
 from domainlab.algos.msels.a_model_sel import AMSel
 from domainlab.utils.logger import Logger
 
@@ -11,27 +12,35 @@ class MSelSetpointDelay(AMSel):
     1. Model selection using validation performance
     2. Only update if setpoint has been decreased
     """
+
     def __init__(self, msel):
         super().__init__()
         # NOTE: super() has to come first always otherwise self.msel will be overwritten to be None
         self.msel = msel
         self._oracle_last_setpoint_sel_te_acc = 0.0
-
+
     @property
     def oracle_last_setpoint_sel_te_acc(self):
         """
         return the last setpoint best acc
         """
         return self._oracle_last_setpoint_sel_te_acc
-
+
     def update(self, clear_counter=False):
         """
         if the best model should be updated
         """
         logger = Logger.get_logger()
-        logger.info(f"setpoint selected current acc {self._oracle_last_setpoint_sel_te_acc}")
+        logger.info(
+            f"setpoint selected current acc {self._oracle_last_setpoint_sel_te_acc}"
+        )
         if clear_counter:
-            logger.info("setpoint msel te acc updated from {self._oracle_last_setpoint_sel_te_acc} to {self.sel_model_te_acc}")
+            log_message = (
+                f"setpoint msel te acc updated from "
+                f"{self._oracle_last_setpoint_sel_te_acc} to "
+                f"{self.sel_model_te_acc}"
+            )
+            logger.info(log_message)
             self._oracle_last_setpoint_sel_te_acc = self.sel_model_te_acc
         flag = self.msel.update(clear_counter)
         return flag
diff --git a/domainlab/algos/msels/c_msel_tr_loss.py b/domainlab/algos/msels/c_msel_tr_loss.py
index 9d7b2f5a0..7ed3f1168 100644
--- a/domainlab/algos/msels/c_msel_tr_loss.py
+++ b/domainlab/algos/msels/c_msel_tr_loss.py
@@ -27,7 +27,7 @@ def reset(self):
     @property
     def max_es(self):
         return self._max_es
-
+
     def update(self, clear_counter=False):
         """
         if the best model should be updated
diff --git a/domainlab/algos/msels/c_msel_val.py b/domainlab/algos/msels/c_msel_val.py
index 766be2d07..c1f2f5561 100644
--- a/domainlab/algos/msels/c_msel_val.py
+++ b/domainlab/algos/msels/c_msel_val.py
@@ -24,7 +24,7 @@ def reset(self):
     @property
     def sel_model_te_acc(self):
         return self._sel_model_te_acc
-
+
     @property
     def best_val_acc(self):
         """
diff --git a/domainlab/algos/msels/c_msel_val_top_k.py b/domainlab/algos/msels/c_msel_val_top_k.py
index 02687b0f5..4836a6245 100644
--- a/domainlab/algos/msels/c_msel_val_top_k.py
+++ b/domainlab/algos/msels/c_msel_val_top_k.py
@@ -10,6 +10,7 @@ class MSelValPerfTopK(MSelValPerf):
     1. Model selection using validation performance
     2. Visitor pattern to trainer
     """
+
     def __init__(self, max_es, top_k=2):
         super().__init__(max_es)  # construct self.tr_obs (observer)
         self.top_k = top_k
@@ -20,32 +21,41 @@ def update(self, clear_counter=False):
         if the best model should be updated
         """
         flag_super = super().update(clear_counter)
-        metric_val_current = \
-            self.tr_obs.metric_val[self.tr_obs.str_metric4msel]
+        metric_val_current = self.tr_obs.metric_val[self.tr_obs.str_metric4msel]
         acc_min = min(self.list_top_k_acc)
         if metric_val_current > acc_min:  # overwrite
             logger = Logger.get_logger()
-            logger.info(f"top k validation acc: {self.list_top_k_acc} \
-                overwriting/reset counter")
+            logger.info(
+                f"top k validation acc: {self.list_top_k_acc} \
+                overwriting/reset counter"
+            )
             self.es_c = 0  # restore counter
             ind = self.list_top_k_acc.index(acc_min)
             # avoid having identical values
             if metric_val_current not in self.list_top_k_acc:
                 self.list_top_k_acc[ind] = metric_val_current
-                logger.info(f"top k validation acc updated: \
-                    {self.list_top_k_acc}")
+                logger.info(
+                    f"top k validation acc updated: \
+                    {self.list_top_k_acc}"
+                )
             # overwrite to ensure consistency
-            # issue #569: initially self.list_top_k_acc will be [xx, 0] and it does not matter since 0 will be overwriten by second epoch validation acc.
+            # issue #569: initially self.list_top_k_acc will be [xx, 0] and it does not matter since 0 will be overwritten by the second epoch validation acc.
             # actually, after epoch 1, most often, self._best_val_acc will be the higher value of self.list_top_k_acc and will be overwritten by min(self.list_top_k_acc)
-            logger.info(f"top-2 val sel: overwriting best val acc from {self._best_val_acc} to minium of {self.list_top_k_acc} which is {min(self.list_top_k_acc)} to ensure consistency")
+            logger.info(
+                f"top-2 val sel: overwriting best val acc from {self._best_val_acc} to "
+                f"minimum of {self.list_top_k_acc} which is {min(self.list_top_k_acc)} "
+                f"to ensure consistency"
+            )
             self._best_val_acc = min(self.list_top_k_acc)
             # overwrite test acc, this does not depend on if val top-k acc has been overwritten or not
-            metric_te_current = \
-                self.tr_obs.metric_te[self.tr_obs.str_metric4msel]
+            metric_te_current = self.tr_obs.metric_te[self.tr_obs.str_metric4msel]
             if self._sel_model_te_acc != metric_te_current:
                 # this can only happen if the validation acc has decreased and the current val acc is only bigger than min(self.list_top_k_acc) but lower than max(self.list_top_k_acc)
-                logger.info(f"top-2 val sel: overwriting selected model test acc from {self._sel_model_te_acc} to {metric_te_current} to ensure consistency")
+                logger.info(
+                    f"top-2 val sel: overwriting selected model test acc from "
+                    f"{self._sel_model_te_acc} to {metric_te_current} to ensure consistency"
+                )
                 self._sel_model_te_acc = metric_te_current
             return True
         return flag_super
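Aside: the top-k selection rule that `MSelValPerfTopK.update` implements above is hard to read inside a diff; the following is a minimal standalone sketch of the idea (hypothetical names, not the DomainLab class itself):

```python
# Hedged sketch: keep the k best validation accuracies; "best" is the
# minimum of that list, which makes selection robust to one lucky epoch.
class TopKValSelector:
    def __init__(self, top_k=2):
        self.list_top_k_acc = [0.0] * top_k
        self.best_val_acc = 0.0

    def update(self, val_acc):
        acc_min = min(self.list_top_k_acc)
        if val_acc > acc_min:
            # replace the weakest stored value, avoiding duplicates
            if val_acc not in self.list_top_k_acc:
                self.list_top_k_acc[self.list_top_k_acc.index(acc_min)] = val_acc
            self.best_val_acc = min(self.list_top_k_acc)
            return True
        return False
```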
diff --git a/domainlab/algos/observers/b_obvisitor.py b/domainlab/algos/observers/b_obvisitor.py
index 457ab874c..6dce73c28 100644
--- a/domainlab/algos/observers/b_obvisitor.py
+++ b/domainlab/algos/observers/b_obvisitor.py
@@ -119,8 +119,12 @@ def after_all(self):
         else:
             metric_te.update({"acc_val": -1})
 
-        if hasattr(self, "model_sel") and hasattr(self.model_sel, "oracle_last_setpoint_sel_te_acc"):
-            metric_te.update({"acc_setpoint":self.model_sel.oracle_last_setpoint_sel_te_acc})
+        if hasattr(self, "model_sel") and hasattr(
+            self.model_sel, "oracle_last_setpoint_sel_te_acc"
+        ):
+            metric_te.update(
+                {"acc_setpoint": self.model_sel.oracle_last_setpoint_sel_te_acc}
+            )
         else:
             metric_te.update({"acc_setpoint": -1})
         self.dump_prediction(model_ld, metric_te)
diff --git a/domainlab/algos/observers/c_obvisitor_cleanup.py b/domainlab/algos/observers/c_obvisitor_cleanup.py
index 2632ea46f..4de3ef6b4 100644
--- a/domainlab/algos/observers/c_obvisitor_cleanup.py
+++ b/domainlab/algos/observers/c_obvisitor_cleanup.py
@@ -12,7 +12,7 @@ def __init__(self, observer):
 
     def after_all(self):
         self.observer.after_all()
-        self.observer.clean_up() # FIXME should be self.clean_up???
+        self.observer.clean_up()  # FIXME should be self.clean_up???
 
     def accept(self, trainer):
         self.observer.accept(trainer)
diff --git a/domainlab/algos/trainers/a_trainer.py b/domainlab/algos/trainers/a_trainer.py
index d9f80f331..a82c53d99 100644
--- a/domainlab/algos/trainers/a_trainer.py
+++ b/domainlab/algos/trainers/a_trainer.py
@@ -89,7 +89,6 @@ def __init__(self, successor_node=None, extend=None):
         self.mu_iter_start = 0
         self.flag_setpoint_updated = False
 
-
     @property
     def model(self):
         """
@@ -226,7 +225,7 @@ def get_model(self):
         if "trainer" not in str(type(self._model)).lower():
             return self._model
         return self._model.get_model()
-
+
     def as_model(self):
         """
         used for decorator pattern
diff --git a/domainlab/algos/trainers/args_fbopt.py b/domainlab/algos/trainers/args_fbopt.py
index 2039ce2d4..53719e05f 100644
--- a/domainlab/algos/trainers/args_fbopt.py
+++ b/domainlab/algos/trainers/args_fbopt.py
@@ -8,67 +8,119 @@ def add_args2parser_fbopt(parser):
     append hyper-parameters to the main argparser
     """
-    parser.add_argument('--k_i_gain', type=float, default=0.001,
-                        help='PID control gain for integrator')
-
-    parser.add_argument('--k_i_gain_ratio', type=float, default=None,
-                        help='set k_i_gain to be ratio of \
-                        initial saturation k_i_gain')
-
-    parser.add_argument('--mu_clip', type=float, default=1e4,
-                        help='maximum value of mu')
-
-    parser.add_argument('--mu_min', type=float, default=1e-6,
-                        help='minimum value of mu')
-
-    parser.add_argument('--mu_init', type=float, default=0.001,
-                        help='initial beta for multiplication')
-
-    parser.add_argument('--coeff_ma', type=float, default=0.5,
-                        help='exponential moving average')
-
-    parser.add_argument('--coeff_ma_output_state', type=float, default=0.1,
-                        help='state exponential moving average of \
-                        reguarlization loss')
-
-    parser.add_argument('--coeff_ma_setpoint', type=float, default=0.9,
-                        help='setpoint average coeff for previous setpoint')
-
-    parser.add_argument('--exp_shoulder_clip', type=float, default=5,
-                        help='clip before exponential operation')
-
-    parser.add_argument('--ini_setpoint_ratio', type=float, default=0.99,
-                        help='before training start, evaluate reg loss, \
-                        setpoint will be 0.9 of this loss')
-
-    parser.add_argument('--force_feedforward', action='store_true',
-                        default=False,
-                        help='use feedforward scheduler')
-
-    parser.add_argument('--force_setpoint_change_once', action='store_true',
-                        default=False,
-                        help='train until the setpoint changed at least once \
-                        up to maximum epos specified')
-
-    parser.add_argument('--no_tensorboard', action='store_true', default=False,
-                        help='disable tensorboard')
-
-    parser.add_argument('--no_setpoint_update', action='store_true',
-                        default=False,
-                        help='disable setpoint update')
-
-    parser.add_argument('--tr_with_init_mu', action='store_true',
-                        default=False,
-                        help='disable setpoint update')
-
-    parser.add_argument('--overshoot_rewind', type=str, default="yes",
-                        help='overshoot_rewind, for benchmark, use yes or no')
-
-    parser.add_argument('--setpoint_rewind', type=str, default="no",
-                        help='setpoing_rewind, for benchmark, use yes or no')
-
-    parser.add_argument('--str_diva_multiplier_type', type=str,
-                        default="gammad_recon",
-                        help='which penalty to tune')
+    parser.add_argument(
+        "--k_i_gain", type=float, default=0.001, help="PID control gain for integrator"
+    )
+
+    parser.add_argument(
+        "--k_i_gain_ratio",
+        type=float,
+        default=None,
+        help="set k_i_gain to be ratio of \
+        initial saturation k_i_gain",
+    )
+
+    parser.add_argument(
+        "--mu_clip", type=float, default=1e4, help="maximum value of mu"
+    )
+
+    parser.add_argument(
+        "--mu_min", type=float, default=1e-6, help="minimum value of mu"
+    )
+
+    parser.add_argument(
+        "--mu_init", type=float, default=0.001, help="initial beta for multiplication"
+    )
+
+    parser.add_argument(
+        "--coeff_ma", type=float, default=0.5, help="exponential moving average"
+    )
+
+    parser.add_argument(
+        "--coeff_ma_output_state",
+        type=float,
+        default=0.1,
+        help="state exponential moving average of \
+        regularization loss",
+    )
+
+    parser.add_argument(
+        "--coeff_ma_setpoint",
+        type=float,
+        default=0.9,
+        help="setpoint average coeff for previous setpoint",
+    )
+
+    parser.add_argument(
+        "--exp_shoulder_clip",
+        type=float,
+        default=5,
+        help="clip before exponential operation",
+    )
+
+    parser.add_argument(
+        "--ini_setpoint_ratio",
+        type=float,
+        default=0.99,
+        help="before training starts, evaluate reg loss, \
+        setpoint will be this ratio of that loss",
+    )
+
+    parser.add_argument(
+        "--force_feedforward",
+        action="store_true",
+        default=False,
+        help="use feedforward scheduler",
+    )
+
+    parser.add_argument(
+        "--force_setpoint_change_once",
+        action="store_true",
+        default=False,
+        help="train until the setpoint changed at least once \
+        up to maximum epos specified",
+    )
+
+    parser.add_argument(
+        "--no_tensorboard",
+        action="store_true",
+        default=False,
+        help="disable tensorboard",
+    )
+
+    parser.add_argument(
+        "--no_setpoint_update",
+        action="store_true",
+        default=False,
+        help="disable setpoint update",
+    )
+
+    parser.add_argument(
+        "--tr_with_init_mu",
+        action="store_true",
+        default=False,
+        help="train with initial mu",
+    )
+
+    parser.add_argument(
+        "--overshoot_rewind",
+        type=str,
+        default="yes",
+        help="overshoot_rewind, for benchmark, use yes or no",
+    )
+
+    parser.add_argument(
+        "--setpoint_rewind",
+        type=str,
+        default="no",
+        help="setpoint_rewind, for benchmark, use yes or no",
+    )
+
+    parser.add_argument(
+        "--str_diva_multiplier_type",
+        type=str,
+        default="gammad_recon",
+        help="which penalty to tune",
+    )
     return parser
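Aside: a hedged usage sketch of the argument group defined above, mirroring how `domainlab/arg_parser.py` wires it in (the flag values passed here are illustrative):

```python
# Standalone demo: attach the fbopt hyper-parameters to an argparse parser.
import argparse

from domainlab.algos.trainers.args_fbopt import add_args2parser_fbopt

parser = argparse.ArgumentParser(description="fbopt args demo")
group = parser.add_argument_group("fbopt")
add_args2parser_fbopt(group)

args = parser.parse_args(["--k_i_gain", "0.01", "--mu_init", "0.001"])
print(args.k_i_gain, args.mu_clip)  # 0.01 and the default 1e4
```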
"--mu_min", type=float, default=1e-6, help="minimum value of mu" + ) + + parser.add_argument( + "--mu_init", type=float, default=0.001, help="initial beta for multiplication" + ) + + parser.add_argument( + "--coeff_ma", type=float, default=0.5, help="exponential moving average" + ) + + parser.add_argument( + "--coeff_ma_output_state", + type=float, + default=0.1, + help="state exponential moving average of \ + reguarlization loss", + ) + + parser.add_argument( + "--coeff_ma_setpoint", + type=float, + default=0.9, + help="setpoint average coeff for previous setpoint", + ) + + parser.add_argument( + "--exp_shoulder_clip", + type=float, + default=5, + help="clip before exponential operation", + ) + + parser.add_argument( + "--ini_setpoint_ratio", + type=float, + default=0.99, + help="before training start, evaluate reg loss, \ + setpoint will be 0.9 of this loss", + ) + + parser.add_argument( + "--force_feedforward", + action="store_true", + default=False, + help="use feedforward scheduler", + ) + + parser.add_argument( + "--force_setpoint_change_once", + action="store_true", + default=False, + help="train until the setpoint changed at least once \ + up to maximum epos specified", + ) + + parser.add_argument( + "--no_tensorboard", + action="store_true", + default=False, + help="disable tensorboard", + ) + + parser.add_argument( + "--no_setpoint_update", + action="store_true", + default=False, + help="disable setpoint update", + ) + + parser.add_argument( + "--tr_with_init_mu", + action="store_true", + default=False, + help="disable setpoint update", + ) + + parser.add_argument( + "--overshoot_rewind", + type=str, + default="yes", + help="overshoot_rewind, for benchmark, use yes or no", + ) + + parser.add_argument( + "--setpoint_rewind", + type=str, + default="no", + help="setpoing_rewind, for benchmark, use yes or no", + ) + + parser.add_argument( + "--str_diva_multiplier_type", + type=str, + default="gammad_recon", + help="which penalty to tune", + ) return parser diff --git a/domainlab/algos/trainers/compos/matchdg_match.py b/domainlab/algos/trainers/compos/matchdg_match.py index 78e67abde..8c6b46c90 100644 --- a/domainlab/algos/trainers/compos/matchdg_match.py +++ b/domainlab/algos/trainers/compos/matchdg_match.py @@ -16,6 +16,7 @@ class MatchPair: """ match different input """ + @store_args def __init__( self, diff --git a/domainlab/algos/trainers/fbopt_mu_controller.py b/domainlab/algos/trainers/fbopt_mu_controller.py index 083295e45..9f8e02971 100644 --- a/domainlab/algos/trainers/fbopt_mu_controller.py +++ b/domainlab/algos/trainers/fbopt_mu_controller.py @@ -4,15 +4,17 @@ import os import warnings +import numpy as np from torch.utils.tensorboard import SummaryWriter -import numpy as np +from domainlab.algos.trainers.fbopt_setpoint_ada import ( + FbOptSetpointController, + if_list_sign_agree, +) from domainlab.utils.logger import Logger -from domainlab.algos.trainers.fbopt_setpoint_ada import FbOptSetpointController -from domainlab.algos.trainers.fbopt_setpoint_ada import if_list_sign_agree -class StubSummaryWriter(): +class StubSummaryWriter: """ # stub writer for tensorboard that ignores all messages """ @@ -28,10 +30,12 @@ def add_scalars(self, *args, **kwargs): """ -class HyperSchedulerFeedback(): +class HyperSchedulerFeedback: + # pylint: disable=too-many-instance-attributes """ design $\\mu$$ sequence based on state of penalized loss """ + def __init__(self, trainer, **kwargs): """ kwargs is a dictionary with key the hyper-parameter name and its value @@ -44,8 +48,7 @@ def 
__init__(self, trainer, **kwargs): self.mmu = kwargs # force initial value of mu self.mmu = {key: self.init_mu for key, val in self.mmu.items()} - self.set_point_controller = FbOptSetpointController( - args=self.trainer.aconf) + self.set_point_controller = FbOptSetpointController(args=self.trainer.aconf) self.k_i_control = trainer.aconf.k_i_gain self.k_i_gain_ratio = None @@ -62,7 +65,7 @@ def __init__(self, trainer, **kwargs): if trainer.aconf.no_tensorboard: self.writer = StubSummaryWriter() else: - str_job_id = os.environ.get('SLURM_JOB_ID', '') + str_job_id = os.environ.get("SLURM_JOB_ID", "") self.writer = SummaryWriter(comment=str_job_id) def set_k_i_gain(self, epo_reg_loss): @@ -75,14 +78,17 @@ def set_k_i_gain(self, epo_reg_loss): delta_epsilon_r = [a - b for a, b in zip(epo_reg_loss, list_setpoint)] # to calculate self.delta_epsilon_r - k_i_gain_saturate = [a / b for a, b in - zip(self.activation_clip, delta_epsilon_r)] + k_i_gain_saturate = [ + a / b for a, b in zip(self.activation_clip, delta_epsilon_r) + ] k_i_gain_saturate_min = min(k_i_gain_saturate) # NOTE: here we override the commandline arguments specification # for k_i_control, so k_i_control is not a hyperparameter anymore self.k_i_control = self.k_i_gain_ratio * k_i_gain_saturate_min - warnings.warn(f"hyperparameter k_i_gain disabled! \ - replace with {self.k_i_control}") + warnings.warn( + f"hyperparameter k_i_gain disabled! \ + replace with {self.k_i_control}" + ) # FIXME: change this to 1-self.ini_setpoint_ratio, i.e. the more # difficult the initial setpoint is, the bigger the k_i_gain should be @@ -112,38 +118,42 @@ def cal_delta4control(self, list1, list_setpoint): # self.delta_epsilon_r is the previous time step. # delta_epsilon_r is the current time step self.delta_epsilon_r = self.cal_delta_integration( - self.delta_epsilon_r, delta_epsilon_r, self.coeff_ma) + self.delta_epsilon_r, delta_epsilon_r, self.coeff_ma + ) def cal_delta_integration(self, list_old, list_new, coeff): """ ma of delta """ - return [(1 - coeff) * a + coeff * b - for a, b in zip(list_old, list_new)] + return [(1 - coeff) * a + coeff * b for a, b in zip(list_old, list_new)] - def tackle_overshoot(self, - activation, epo_reg_loss, list_str_multiplier_na): + def tackle_overshoot(self, activation, epo_reg_loss, list_str_multiplier_na): """ tackle overshoot """ - list_overshoot = [i if (a - b) * (self.delta_epsilon_r[i]) < 0 - else None - for i, (a, b) in - enumerate( - zip(epo_reg_loss, - self.set_point_controller.setpoint4R))] + list_overshoot = [ + i if (a - b) * (self.delta_epsilon_r[i]) < 0 else None + for i, (a, b) in enumerate( + zip(epo_reg_loss, self.set_point_controller.setpoint4R) + ) + ] for ind in list_overshoot: if ind is not None: logger = Logger.get_logger( - logger_name='main_out_logger', loglevel="INFO") + logger_name="main_out_logger", loglevel="INFO" + ) logger.info(f"delta integration: {self.delta_epsilon_r}") - logger.info(f"overshooting at pos \ - {ind} of activation: {activation}") + logger.info( + f"overshooting at pos \ + {ind} of activation: {activation}" + ) logger.info(f"name reg loss:{list_str_multiplier_na}") if self.overshoot_rewind: activation[ind] = 0.0 - logger.info(f"PID controller set to zero now, \ - new activation: {activation}") + logger.info( + f"PID controller set to zero now, \ + new activation: {activation}" + ) return activation def cal_activation(self): @@ -151,26 +161,30 @@ def cal_activation(self): calculate activation on exponential shoulder """ setpoint = self.get_setpoint4r() - activation = 
[self.k_i_control * val if setpoint[i] > 0 - else self.k_i_control * (-val) for i, val - in enumerate(self.delta_epsilon_r)] + activation = [ + self.k_i_control * val if setpoint[i] > 0 else self.k_i_control * (-val) + for i, val in enumerate(self.delta_epsilon_r) + ] if self.activation_clip is not None: - activation = [np.clip(val, - a_min=-1 * self.activation_clip, - a_max=self.activation_clip) - for val in activation] + activation = [ + np.clip( + val, a_min=-1 * self.activation_clip, a_max=self.activation_clip + ) + for val in activation + ] return activation - def search_mu(self, epo_reg_loss, epo_task_loss, epo_loss_tr, - list_str_multiplier_na, miter): + def search_mu( + self, epo_reg_loss, epo_task_loss, epo_loss_tr, list_str_multiplier_na, miter + ): + # pylint: disable=too-many-locals, too-many-arguments """ start from parameter dictionary dict_theta: {"layer":tensor}, enlarge mu w.r.t. its current value to see if the criteria is met $$\\mu^{k+1}=mu^{k}exp(rate_mu*[R(\\theta^{k})-ref_R])$$ """ - logger = Logger.get_logger( - logger_name='main_out_logger', loglevel="INFO") + logger = Logger.get_logger(logger_name="main_out_logger", loglevel="INFO") logger.info(f"before controller: current mu: {self.mmu}") logger.info(f"epo reg loss: {epo_reg_loss}") logger.info(f"name reg loss:{list_str_multiplier_na}") @@ -178,39 +192,43 @@ def search_mu(self, epo_reg_loss, epo_task_loss, epo_loss_tr, activation = self.cal_activation() # overshoot handling activation = self.tackle_overshoot( - activation, epo_reg_loss, list_str_multiplier_na) + activation, epo_reg_loss, list_str_multiplier_na + ) list_gain = np.exp(activation) dict_gain = dict(zip(list_str_multiplier_na, list_gain)) target = self.dict_multiply(self.mmu, dict_gain) self.mmu = self.dict_clip(target) - logger = Logger.get_logger( - logger_name='main_out_logger', loglevel="INFO") + logger = Logger.get_logger(logger_name="main_out_logger", loglevel="INFO") logger.info(f"after contoller: current mu: {self.mmu}") for key, val in self.mmu.items(): - self.writer.add_scalar(f'dyn_mu/{key}', val, miter) - self.writer.add_scalar( - f'controller_gain/{key}', dict_gain[key], miter) + self.writer.add_scalar(f"dyn_mu/{key}", val, miter) + self.writer.add_scalar(f"controller_gain/{key}", dict_gain[key], miter) ind = list_str_multiplier_na.index(key) + self.writer.add_scalar(f"delta/{key}", self.delta_epsilon_r[ind], miter) + for i, (reg_dyn, reg_set) in enumerate( + zip(epo_reg_loss, self.get_setpoint4r()) + ): self.writer.add_scalar( - f'delta/{key}', self.delta_epsilon_r[ind], miter) - for i, (reg_dyn, reg_set) in \ - enumerate(zip(epo_reg_loss, self.get_setpoint4r())): + f"lossrd/dyn_{list_str_multiplier_na[i]}", reg_dyn, miter + ) self.writer.add_scalar( - f'lossrd/dyn_{list_str_multiplier_na[i]}', reg_dyn, miter) - self.writer.add_scalar( - f'lossrs/setpoint_{list_str_multiplier_na[i]}', reg_set, miter) + f"lossrs/setpoint_{list_str_multiplier_na[i]}", reg_set, miter + ) self.writer.add_scalars( - f'loss_rds/loss_{list_str_multiplier_na[i]}_w_setpoint', - {f'lossr/loss_{list_str_multiplier_na[i]}': reg_dyn, - f'lossr/setpoint_{list_str_multiplier_na[i]}': reg_set, - }, miter) + f"loss_rds/loss_{list_str_multiplier_na[i]}_w_setpoint", + { + f"lossr/loss_{list_str_multiplier_na[i]}": reg_dyn, + f"lossr/setpoint_{list_str_multiplier_na[i]}": reg_set, + }, + miter, + ) self.writer.add_scalar( - f'x_ell_y_r/loss_{list_str_multiplier_na[i]}', - reg_dyn, epo_task_loss) - self.writer.add_scalar('loss_task/penalized', epo_loss_tr, miter) - 
self.writer.add_scalar('loss_task/ell', epo_task_loss, miter) + f"x_ell_y_r/loss_{list_str_multiplier_na[i]}", reg_dyn, epo_task_loss + ) + self.writer.add_scalar("loss_task/penalized", epo_loss_tr, miter) + self.writer.add_scalar("loss_task/ell", epo_task_loss, miter) acc_te = 0 acc_val = 0 acc_sel = 0 @@ -220,8 +238,7 @@ def search_mu(self, epo_reg_loss, epo_task_loss, epo_loss_tr, acc_te = self.trainer.observer.metric_te["acc"] acc_val = self.trainer.observer.metric_val["acc"] acc_sel = self.trainer.observer.model_sel.sel_model_te_acc - acc_set = \ - self.trainer.observer.model_sel.oracle_last_setpoint_sel_te_acc + acc_set = self.trainer.observer.model_sel.oracle_last_setpoint_sel_te_acc self.writer.add_scalar("acc/te", acc_te, miter) self.writer.add_scalar("acc/val", acc_val, miter) self.writer.add_scalar("acc/sel", acc_sel, miter) @@ -231,8 +248,10 @@ def dict_clip(self, dict_base): """ clip each entry of the mu according to pre-set self.mu_clip """ - return {key: np.clip(val, a_min=self.mu_min, a_max=self.mu_clip) - for key, val in dict_base.items()} + return { + key: np.clip(val, a_min=self.mu_min, a_max=self.mu_clip) + for key, val in dict_base.items() + } def dict_is_zero(self, dict_mu): """ @@ -247,8 +266,7 @@ def dict_multiply(self, dict_base, dict_multiplier): """ multiply a float to each element of a dictionary """ - return { - key: val * dict_multiplier[key] for key, val in dict_base.items()} + return {key: val * dict_multiplier[key] for key, val in dict_base.items()} def update_setpoint(self, epo_reg_loss, epo_task_loss): """ diff --git a/domainlab/algos/trainers/fbopt_setpoint_ada.py b/domainlab/algos/trainers/fbopt_setpoint_ada.py index 2004a1ee2..c3c0193ce 100644 --- a/domainlab/algos/trainers/fbopt_setpoint_ada.py +++ b/domainlab/algos/trainers/fbopt_setpoint_ada.py @@ -2,6 +2,7 @@ update hyper-parameters during training """ import numpy as np + from domainlab.utils.logger import Logger @@ -41,8 +42,9 @@ def is_less_list_any(list1, list2): judge if one list is less than the other """ if_list_sign_agree(list1, list2) - list_comparison = [a < b if a >= 0 and b >= 0 else a > b - for a, b in zip(list1, list2)] + list_comparison = [ + a < b if a >= 0 and b >= 0 else a > b for a, b in zip(list1, list2) + ] return any(list_comparison), list_true(list_comparison) @@ -51,11 +53,13 @@ def is_less_list_all(list1, list2, flag_eq=False): judge if one list is less than the other """ if_list_sign_agree(list1, list2) - list_comparison = [a < b if a >= 0 and b >= 0 else a > b - for a, b in zip(list1, list2)] + list_comparison = [ + a < b if a >= 0 and b >= 0 else a > b for a, b in zip(list1, list2) + ] if flag_eq: - list_comparison = [a <= b if a >= 0 and b >= 0 else a >= b - for a, b in zip(list1, list2)] + list_comparison = [ + a <= b if a >= 0 and b >= 0 else a >= b for a, b in zip(list1, list2) + ] return all(list_comparison) @@ -63,15 +67,15 @@ def list_ma(list_state, list_input, coeff): """ moving average of list """ - return [a * coeff + b * (1 - coeff) for a, b in \ - zip(list_state, list_input)] + return [a * coeff + b * (1 - coeff) for a, b in zip(list_state, list_input)] -class SetpointRewinder(): +class SetpointRewinder: """ rewind setpoint if current loss exponential moving average is bigger than setpoint """ + def __init__(self, host): self.host = host self.counter = None @@ -95,10 +99,10 @@ def observe(self, epo_reg_loss): if self.ref is None: self.reset(epo_reg_loss) self.epo_ma = list_ma(self.epo_ma, epo_reg_loss, self.coeff_ma) - list_comparison_increase = \ - [a < b 
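Aside: stripped of logging and bookkeeping, the controller above reduces to a multiplicative update mu^{k+1} = clip(mu^k * exp(clip(k_i * [R(theta^k) - setpoint]))). A minimal standalone sketch with hypothetical names (not the `HyperSchedulerFeedback` class itself):

```python
# Hedged sketch of the PID-style integrator update on the multipliers.
import numpy as np


def update_mu(mmu, delta_epsilon_r, k_i_gain, shoulder_clip=5.0,
              mu_min=1e-6, mu_clip=1e4):
    """mmu: dict multiplier-name -> value; delta_epsilon_r: list of
    (reg loss - setpoint) values, one per multiplier, in dict order."""
    # clip on the "exponential shoulder" so one bad epoch cannot explode mu
    activation = [np.clip(k_i_gain * d, -shoulder_clip, shoulder_clip)
                  for d in delta_epsilon_r]
    gains = np.exp(activation)  # gain > 1 while the loss sits above its setpoint
    return {key: float(np.clip(val * gain, mu_min, mu_clip))
            for (key, val), gain in zip(mmu.items(), gains)}


# example: reg loss 0.3 above its setpoint -> multiplier grows by exp(0.15)
print(update_mu({"gamma_d": 1.0}, [0.3], k_i_gain=0.5))  # {'gamma_d': ~1.162}
```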
diff --git a/domainlab/algos/trainers/fbopt_setpoint_ada.py b/domainlab/algos/trainers/fbopt_setpoint_ada.py
index 2004a1ee2..c3c0193ce 100644
--- a/domainlab/algos/trainers/fbopt_setpoint_ada.py
+++ b/domainlab/algos/trainers/fbopt_setpoint_ada.py
@@ -2,6 +2,7 @@
 update hyper-parameters during training
 """
 import numpy as np
+
 from domainlab.utils.logger import Logger
 
 
@@ -41,8 +42,9 @@ def is_less_list_any(list1, list2):
     judge if one list is less than the other
     """
     if_list_sign_agree(list1, list2)
-    list_comparison = [a < b if a >= 0 and b >= 0 else a > b
-                       for a, b in zip(list1, list2)]
+    list_comparison = [
+        a < b if a >= 0 and b >= 0 else a > b for a, b in zip(list1, list2)
+    ]
     return any(list_comparison), list_true(list_comparison)
 
 
@@ -51,11 +53,13 @@ def is_less_list_all(list1, list2, flag_eq=False):
     judge if one list is less than the other
     """
     if_list_sign_agree(list1, list2)
-    list_comparison = [a < b if a >= 0 and b >= 0 else a > b
-                       for a, b in zip(list1, list2)]
+    list_comparison = [
+        a < b if a >= 0 and b >= 0 else a > b for a, b in zip(list1, list2)
+    ]
     if flag_eq:
-        list_comparison = [a <= b if a >= 0 and b >= 0 else a >= b
-                           for a, b in zip(list1, list2)]
+        list_comparison = [
+            a <= b if a >= 0 and b >= 0 else a >= b for a, b in zip(list1, list2)
+        ]
     return all(list_comparison)
 
 
@@ -63,15 +67,15 @@ def list_ma(list_state, list_input, coeff):
     """
     moving average of list
     """
-    return [a * coeff + b * (1 - coeff) for a, b in \
-            zip(list_state, list_input)]
+    return [a * coeff + b * (1 - coeff) for a, b in zip(list_state, list_input)]
 
 
-class SetpointRewinder():
+class SetpointRewinder:
    """
    rewind setpoint if current loss exponential moving average is
    bigger than setpoint
    """
+
     def __init__(self, host):
         self.host = host
         self.counter = None
@@ -95,10 +99,10 @@ def observe(self, epo_reg_loss):
         if self.ref is None:
             self.reset(epo_reg_loss)
         self.epo_ma = list_ma(self.epo_ma, epo_reg_loss, self.coeff_ma)
-        list_comparison_increase = \
-            [a < b for a, b in zip(self.ref, self.epo_ma)]
-        list_comparison_above_setpoint = \
-            [a < b for a, b in zip(self.host.setpoint4R, self.epo_ma)]
+        list_comparison_increase = [a < b for a, b in zip(self.ref, self.epo_ma)]
+        list_comparison_above_setpoint = [
+            a < b for a, b in zip(self.host.setpoint4R, self.epo_ma)
+        ]
         flag_increase = any(list_comparison_increase)
         flag_above_setpoint = any(list_comparison_above_setpoint)
         if flag_increase and flag_above_setpoint:
@@ -114,9 +118,11 @@ def observe(self, epo_reg_loss):
             list_pos = list_true(list_comparison_above_setpoint)
             print(f"\n\n\n!!!!!!!setpoint too low at {list_pos}!\n\n\n")
             for pos in list_pos:
-                print(f"\n\n\n!!!!!!!rewinding setpoint at pos {pos} \
+                print(
+                    f"\n\n\n!!!!!!!rewinding setpoint at pos {pos} \
                     from {self.host.setpoint4R[pos]} to \
-                    {self.epo_ma[pos]}!\n\n\n")
+                    {self.epo_ma[pos]}!\n\n\n"
+                )
                 self.host.setpoint4R[pos] = self.epo_ma[pos]
 
         if self.counter > 3:
@@ -124,10 +130,12 @@ def observe(self, epo_reg_loss):
             self.counter = np.inf  # FIXME
 
 
-class FbOptSetpointController():
+class FbOptSetpointController:
+    # pylint: disable=too-many-instance-attributes
     """
     update setpoint for mu
     """
+
     def __init__(self, state=None, args=None):
         """
         kwargs is a dictionary with key the hyper-parameter name and its value
@@ -141,7 +149,9 @@ def __init__(self, state=None, args=None):
         self.flag_setpoint_rewind = args.setpoint_rewind == "yes"
         self.setpoint_rewinder = SetpointRewinder(self)
         self.state_task_loss = 0.0
-        self.state_epo_reg_loss = [0.0 for _ in range(10)]  # FIXME: 10 is the maximum number losses here
+        self.state_epo_reg_loss = [
+            0.0 for _ in range(10)
+        ]  # FIXME: 10 is the maximum number of losses here
         self.coeff_ma_setpoint = args.coeff_ma_setpoint
         self.coeff_ma_output = args.coeff_ma_output_state
         # initial value will be set via trainer
@@ -160,31 +170,34 @@ def update_setpoint_ma(self, list_target, list_pos):
         """
         using moving average
         """
-        target_ma = [self.coeff_ma_setpoint * a +
-                     (1 - self.coeff_ma_setpoint) * b
-                     for a, b in zip(self.setpoint4R, list_target)]
-        self.setpoint4R = [target_ma[i] if i in list_pos else
-                           self.setpoint4R[i] for i in range(len(target_ma))]
+        target_ma = [
+            self.coeff_ma_setpoint * a + (1 - self.coeff_ma_setpoint) * b
+            for a, b in zip(self.setpoint4R, list_target)
+        ]
+        self.setpoint4R = [
+            target_ma[i] if i in list_pos else self.setpoint4R[i]
+            for i in range(len(target_ma))
+        ]
 
     def observe(self, epo_reg_loss, epo_task_loss):
         """
         read current epo_reg_loss continuously
         """
-        self.state_epo_reg_loss = [self.coeff_ma_output*a + \
-                                   (1-self.coeff_ma_output)*b
-                                   if a != 0.0 else b
-                                   for a, b in zip(
-                                       self.state_epo_reg_loss, epo_reg_loss)]
+        self.state_epo_reg_loss = [
+            self.coeff_ma_output * a + (1 - self.coeff_ma_output) * b if a != 0.0 else b
+            for a, b in zip(self.state_epo_reg_loss, epo_reg_loss)
+        ]
         if self.state_task_loss == 0.0:
             self.state_task_loss = epo_task_loss
-        self.state_task_loss = self.coeff_ma_output * self.state_task_loss + \
-            (1 - self.coeff_ma_output) * epo_task_loss
+        self.state_task_loss = (
+            self.coeff_ma_output * self.state_task_loss
+            + (1 - self.coeff_ma_output) * epo_task_loss
+        )
         self.setpoint_rewinder.observe(self.state_epo_reg_loss)
         flag_update, list_pos = self.state_updater.update_setpoint()
         if flag_update:
             self.setpoint_rewinder.reset(self.state_epo_reg_loss)
-            logger = Logger.get_logger(
-                logger_name='main_out_logger', loglevel="INFO")
+            logger = Logger.get_logger(logger_name="main_out_logger", loglevel="INFO")
             logger.info(f"!!!!!set point old value {self.setpoint4R}!")
             self.update_setpoint_ma(self.state_epo_reg_loss, list_pos)
             logger.info(f"!!!!!set point updated to {self.setpoint4R}!")
@@ -192,13 +205,14 @@ def observe(self, epo_reg_loss, epo_task_loss):
         return False
 
 
-class FbOptSetpointControllerState():
+class FbOptSetpointControllerState:
+    # pylint: disable=too-few-public-methods
     """
     abstract state pattern
     """
+
     def __init__(self):
-        """
-        """
+        """ """
         self.host = None
 
     def accept(self, controller):
@@ -212,6 +226,7 @@ class FixedSetpoint(FbOptSetpointControllerState):
     """
     do not update setpoint
     """
+
     def update_setpoint(self):
         """
         always return False so setpoint no update
@@ -223,19 +238,23 @@ class SliderAllComponent(FbOptSetpointControllerState):
     """
     concrete state pattern
     """
+
     def update_setpoint(self):
         """
         all components of R decrease regardless of whether ell decreases or not
         """
-        logger = Logger.get_logger(
-            logger_name='main_out_logger', loglevel="INFO")
+        logger = Logger.get_logger(logger_name="main_out_logger", loglevel="INFO")
         logger.info(
             f"comparing output vs setpoint: \n \
             {self.host.state_epo_reg_loss} \n \
-            {self.host.setpoint4R}")
-        if is_less_list_all(self.host.state_epo_reg_loss,
-                            self.host.setpoint4R, flag_eq=True):
-            logger.info("!!!!!!!!!In SliderAllComponent: R current value better than current setpoint!")
+            {self.host.setpoint4R}"
+        )
+        if is_less_list_all(
+            self.host.state_epo_reg_loss, self.host.setpoint4R, flag_eq=True
+        ):
+            logger.info(
+                "!!!!!!!!!In SliderAllComponent: R current value better than current setpoint!"
+            )
             return True, list(range(len(self.host.setpoint4R)))
         return False, None
@@ -244,12 +263,14 @@ class SliderAnyComponent(FbOptSetpointControllerState):
     """
     concrete state pattern
     """
+
     def update_setpoint(self):
         """
         if any component of R has decreased regardless of whether ell decreases
         """
         flag, list_pos = is_less_list_any(
-            self.host.state_epo_reg_loss, self.host.setpoint4R)
+            self.host.state_epo_reg_loss, self.host.setpoint4R
+        )
         return flag, list_pos
 
     def transit(self):
@@ -260,6 +281,7 @@ class DominateAnyComponent(SliderAnyComponent):
     """
     concrete state pattern
     """
+
     def update_setpoint(self):
         """
         if any of the components of R loss has decreased together with ell loss
@@ -275,6 +297,7 @@ class DominateAllComponent(SliderAllComponent):
     """
     concrete state pattern
     """
+
     def update_setpoint(self):
         """
         if each component of R loss has decreased and ell loss also decreased
@@ -282,10 +305,10 @@ def update_setpoint(self):
         flag1, list_pos = super().update_setpoint()
         flag2 = self.host.state_task_loss < self.host.setpoint4ell
         if flag2:
-            logger = Logger.get_logger(
-                logger_name='main_out_logger', loglevel="INFO")
+            logger = Logger.get_logger(logger_name="main_out_logger", loglevel="INFO")
             logger.info(
                 f"best ell loss: from {self.host.setpoint4ell} to \
-                {self.host.state_task_loss}")
+                {self.host.state_task_loss}"
+            )
             self.host.setpoint4ell = self.host.state_task_loss
         return flag1 & flag2, list_pos
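Aside: a minimal standalone sketch of the moving-average rule in `update_setpoint_ma` above (hypothetical names; only positions whose loss component improved, `list_pos`, move toward the observed loss):

```python
# Hedged sketch of the setpoint moving-average update.
def update_setpoint_ma(setpoint, observed, list_pos, coeff_ma_setpoint=0.9):
    # blend every component toward the observed loss ...
    target_ma = [coeff_ma_setpoint * s + (1 - coeff_ma_setpoint) * o
                 for s, o in zip(setpoint, observed)]
    # ... but only commit the blend at the improved positions
    return [target_ma[i] if i in list_pos else setpoint[i]
            for i in range(len(setpoint))]


# only index 0 is updated; index 1 keeps its old setpoint
print(update_setpoint_ma([1.0, 2.0], [0.5, 1.5], list_pos=[0]))  # [0.95, 2.0]
```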
diff --git a/domainlab/algos/trainers/train_basic.py b/domainlab/algos/trainers/train_basic.py
index ce8eb20f2..02f8e02fe 100644
--- a/domainlab/algos/trainers/train_basic.py
+++ b/domainlab/algos/trainers/train_basic.py
@@ -38,10 +38,10 @@ def before_epoch(self):
 
     def tr_epoch(self, epoch, flag_info=False):
         self.before_epoch()
-        for ind_batch, (tensor_x, tensor_y, tensor_d, *others) in \
-                enumerate(self.loader_tr):
-            self.tr_batch(tensor_x, tensor_y, tensor_d, others,
-                          ind_batch, epoch)
+        for ind_batch, (tensor_x, tensor_y, tensor_d, *others) in enumerate(
+            self.loader_tr
+        ):
+            self.tr_batch(tensor_x, tensor_y, tensor_d, others, ind_batch, epoch)
         return self.after_epoch(epoch, flag_info)
 
     def after_epoch(self, epoch, flag_info):
diff --git a/domainlab/algos/trainers/train_dial.py b/domainlab/algos/trainers/train_dial.py
index 1faa4d317..76f2dff02 100644
--- a/domainlab/algos/trainers/train_dial.py
+++ b/domainlab/algos/trainers/train_dial.py
@@ -51,7 +51,6 @@ def _cal_reg_loss(self, tensor_x, tensor_y, tensor_d, others=None):
         loss_dial = self.model.cal_task_loss(tensor_x_batch_adv_no_grad, tensor_y)
         return [loss_dial], [self.aconf.gamma_reg]
 
-
     def hyper_init(self, functor_scheduler, trainer):
         """
         initialize both trainer's multiplier and model's multiplier
diff --git a/domainlab/algos/trainers/train_fbopt_b.py b/domainlab/algos/trainers/train_fbopt_b.py
index 21bb43981..7258453ba 100644
--- a/domainlab/algos/trainers/train_fbopt_b.py
+++ b/domainlab/algos/trainers/train_fbopt_b.py
@@ -2,10 +2,12 @@
 update hyper-parameters during training
 """
 from operator import add
+
 import torch
-from domainlab.algos.trainers.train_basic import TrainerBasic
+
 from domainlab.algos.trainers.fbopt_mu_controller import HyperSchedulerFeedback
 from domainlab.algos.trainers.hyper_scheduler import HyperSchedulerWarmupLinear
+from domainlab.algos.trainers.train_basic import TrainerBasic
 from domainlab.utils.logger import Logger
 
 
@@ -16,10 +18,12 @@ def list_divide(list_val, scalar):
     return [ele / scalar for ele in list_val]
 
 
-class HyperSetter():
+class HyperSetter:
+    # pylint: disable=too-few-public-methods
     """
     mock object to force hyper-parameter in the model
     """
+
     def __init__(self, dict_hyper):
         self.dict_hyper = dict_hyper
 
@@ -31,6 +35,7 @@ class TrainerFbOpt(TrainerBasic):
     """
     TrainerHyperScheduler
     """
+
     def set_scheduler(self, scheduler):
         """
         Args:
@@ -53,24 +58,37 @@ def eval_r_loss(self):
         epo_p_loss = 0
         counter = 0.0
         with torch.no_grad():
-            for _, (tensor_x, vec_y, vec_d, *others) in enumerate(self.loader_tr_no_drop):
-                tensor_x, vec_y, vec_d = \
-                    tensor_x.to(self.device), vec_y.to(self.device), vec_d.to(self.device)
+            for _, (tensor_x, vec_y, vec_d, *others) in enumerate(
+                self.loader_tr_no_drop
+            ):
+                tensor_x, vec_y, vec_d = (
+                    tensor_x.to(self.device),
+                    vec_y.to(self.device),
+                    vec_d.to(self.device),
+                )
                 tuple_reg_loss = self.model.cal_reg_loss(tensor_x, vec_y, vec_d, others)
                 p_loss, *_ = self.model.cal_loss(tensor_x, vec_y, vec_d, others)
                 # NOTE: first [0] extract the loss, second [0] get the list
                 list_b_reg_loss = tuple_reg_loss[0]
-                list_b_reg_loss_sumed = [ele.sum().detach().item() for ele in list_b_reg_loss]
+                list_b_reg_loss_sumed = [
+                    ele.sum().detach().item() for ele in list_b_reg_loss
+                ]
                 if len(epo_reg_loss) == 0:
                     epo_reg_loss = list_b_reg_loss_sumed
                 else:
                     epo_reg_loss = list(map(add, epo_reg_loss, list_b_reg_loss_sumed))
-                b_task_loss = self.model.cal_task_loss(tensor_x, vec_y).sum().detach().item()
+                b_task_loss = (
+                    self.model.cal_task_loss(tensor_x, vec_y).sum().detach().item()
+                )
                 # sum will kill the dimension of the mini batch
                 epo_task_loss += b_task_loss
                 epo_p_loss += p_loss.sum().detach().item()
                 counter += 1.0
-        return list_divide(epo_reg_loss, counter), epo_task_loss / counter, epo_p_loss / counter
+        return (
+            list_divide(epo_reg_loss, counter),
+            epo_task_loss / counter,
+            epo_p_loss / counter,
+        )
 
     def before_batch(self, epoch, ind_batch):
         """
@@ -79,7 +97,9 @@ def before_batch(self, epoch, ind_batch):
         """
         if self.flag_update_hyper_per_batch:
             # NOTE: if not update per_batch, then not updated
-            self.model.hyper_update(epoch * self.num_batches + ind_batch, self.hyper_scheduler)
+            self.model.hyper_update(
+                epoch * self.num_batches + ind_batch, self.hyper_scheduler
+            )
         return super().after_batch(epoch, ind_batch)
 
     def before_tr(self):
@@ -93,11 +113,20 @@ def before_tr(self):
         if self.aconf.tr_with_init_mu:
             self.tr_with_init_mu()
-        self.epo_reg_loss_tr, self.epo_task_loss_tr, self.epo_loss_tr = self.eval_r_loss()
+        (
+            self.epo_reg_loss_tr,
+            self.epo_task_loss_tr,
+            self.epo_loss_tr,
+        ) = self.eval_r_loss()
         self.hyper_scheduler.set_setpoint(
-            [ele * self.aconf.ini_setpoint_ratio if ele > 0 else
-             ele / self.aconf.ini_setpoint_ratio for ele in self.epo_reg_loss_tr],
-            self.epo_task_loss_tr)  # setpoing w.r.t. random initialization of neural network
+            [
+                ele * self.aconf.ini_setpoint_ratio
+                if ele > 0
+                else ele / self.aconf.ini_setpoint_ratio
+                for ele in self.epo_reg_loss_tr
+            ],
+            self.epo_task_loss_tr,
+        )  # setpoint w.r.t. random initialization of neural network
         self.hyper_scheduler.set_k_i_gain(self.epo_reg_loss_tr)
 
     @property
@@ -117,7 +146,9 @@ def set_model_with_mu(self):
         """
         set model multipliers
         """
-        self.model.hyper_update(epoch=None, fun_scheduler=HyperSetter(self.hyper_scheduler.mmu))
+        self.model.hyper_update(
+            epoch=None, fun_scheduler=HyperSetter(self.hyper_scheduler.mmu)
+        )
 
     def tr_epoch(self, epoch, flag_info=False):
         """
@@ -128,7 +159,8 @@ def tr_epoch(self, epoch, flag_info=False):
             self.epo_task_loss_tr,
             self.epo_loss_tr,
             self.list_str_multiplier_na,
-            miter=epoch)
+            miter=epoch,
+        )
         self.set_model_with_mu()
         if hasattr(self.model, "dict_multiplier"):
             logger = Logger.get_logger()
@@ -137,5 +169,6 @@ def tr_epoch(self, epoch, flag_info=False):
         flag = super().tr_epoch(epoch, self.flag_setpoint_updated)
         # is it good to update setpoint after we know the new value of each loss?
         self.flag_setpoint_updated = self.hyper_scheduler.update_setpoint(
-            self.epo_reg_loss_tr, self.epo_task_loss_tr)
+            self.epo_reg_loss_tr, self.epo_task_loss_tr
+        )
         return flag
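Aside: a minimal standalone sketch of the setpoint initialization performed in `before_tr` above; the sign handling follows the list comprehension in the diff (shrink positive reg-loss components, push negative ones further down, by `ini_setpoint_ratio`):

```python
# Hedged sketch: derive the initial setpoint from the reg losses measured
# on the randomly initialized network.
def init_setpoint(epo_reg_loss_tr, ini_setpoint_ratio=0.99):
    return [ele * ini_setpoint_ratio if ele > 0 else ele / ini_setpoint_ratio
            for ele in epo_reg_loss_tr]


print(init_setpoint([10.0, -4.0]))  # [9.9, -4.0404...]
```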
diff --git a/domainlab/algos/trainers/zoo_trainer.py b/domainlab/algos/trainers/zoo_trainer.py
index 9368cc7e7..92a965a65 100644
--- a/domainlab/algos/trainers/zoo_trainer.py
+++ b/domainlab/algos/trainers/zoo_trainer.py
@@ -3,11 +3,10 @@
 """
 from domainlab.algos.trainers.train_basic import TrainerBasic
 from domainlab.algos.trainers.train_dial import TrainerDIAL
+from domainlab.algos.trainers.train_fbopt_b import TrainerFbOpt
 from domainlab.algos.trainers.train_hyper_scheduler import TrainerHyperScheduler
 from domainlab.algos.trainers.train_matchdg import TrainerMatchDG
 from domainlab.algos.trainers.train_mldg import TrainerMLDG
-from domainlab.algos.trainers.train_hyper_scheduler import TrainerHyperScheduler
-from domainlab.algos.trainers.train_fbopt_b import TrainerFbOpt
 
 
 class TrainerChainNodeGetter(object):
diff --git a/domainlab/algos/zoo_algos.py b/domainlab/algos/zoo_algos.py
index 95c4e0a6d..4b8240387 100644
--- a/domainlab/algos/zoo_algos.py
+++ b/domainlab/algos/zoo_algos.py
@@ -5,8 +5,8 @@
 from domainlab.algos.builder_dann import NodeAlgoBuilderDANN
 from domainlab.algos.builder_diva import NodeAlgoBuilderDIVA
 from domainlab.algos.builder_erm import NodeAlgoBuilderERM
-from domainlab.algos.builder_hduva import NodeAlgoBuilderHDUVA
 from domainlab.algos.builder_fbopt_dial import NodeAlgoBuilderFbOptDial
+from domainlab.algos.builder_hduva import NodeAlgoBuilderHDUVA
 from domainlab.algos.builder_jigen1 import NodeAlgoBuilderJiGen
 from domainlab.utils.u_import import import_path
diff --git a/domainlab/arg_parser.py b/domainlab/arg_parser.py
index cc281950c..61a2ea97a 100644
--- a/domainlab/arg_parser.py
+++ b/domainlab/arg_parser.py
@@ -18,77 +18,117 @@ def mk_parser_main():
     """
     Args for command line definition
     """
-    parser = argparse.ArgumentParser(description='DomainLab')
+    parser = argparse.ArgumentParser(description="DomainLab")
 
-    parser.add_argument('-c', "--config", default=None,
-                        help="load YAML configuration", dest="config_file",
-                        type=argparse.FileType(mode='r'))
+    parser.add_argument(
+        "-c",
+        "--config",
+        default=None,
+        help="load YAML configuration",
+        dest="config_file",
+        type=argparse.FileType(mode="r"),
+    )
 
-    parser.add_argument('--lr', type=float, default=1e-4,
-                        help='learning rate')
+    parser.add_argument("--lr", type=float, default=1e-4, help="learning rate")
 
-    parser.add_argument('--gamma_reg', type=float, default=0.1,
-                        help='weight of regularization loss')
+    parser.add_argument(
+        "--gamma_reg", type=float, default=0.1, help="weight of regularization loss"
+    )
 
-    parser.add_argument('--es', type=int, default=1,
-                        help='early stop steps')
+    parser.add_argument("--es", type=int, default=1, help="early stop steps")
 
-    parser.add_argument('--seed', type=int, default=0,
-                        help='random seed (default: 0)')
+    parser.add_argument("--seed", type=int, default=0, help="random seed (default: 0)")
 
-    parser.add_argument('--nocu', action='store_true', default=False,
-                        help='disables CUDA')
+    parser.add_argument(
+        "--nocu", action="store_true", default=False, help="disables CUDA"
+    )
 
-    parser.add_argument('--device', type=str, default=None,
-                        help='device name default None')
+    parser.add_argument(
+        "--device", type=str, default=None, help="device name default None"
+    )
 
-    parser.add_argument('--gen', action='store_true', default=False,
-                        help='save generated images')
+    parser.add_argument(
+        "--gen", action="store_true", default=False, help="save generated images"
+    )
 
-    parser.add_argument('--keep_model', action='store_true', default=False,
-                        help='do not delete model at the end of training')
+    parser.add_argument(
+        "--keep_model",
+        action="store_true",
+        default=False,
+        help="do not delete model at the end of training",
+    )
 
-    parser.add_argument('--epos', default=2, type=int,
-                        help='maximum number of epochs')
+    parser.add_argument("--epos", default=2, type=int, help="maximum number of epochs")
 
-    parser.add_argument('--epos_min', default=0, type=int,
-                        help='maximum number of epochs')
+    parser.add_argument(
+        "--epos_min", default=0, type=int, help="minimum number of epochs"
+    )
 
-    parser.add_argument('--epo_te', default=1, type=int,
-                        help='test performance per {} epochs')
+    parser.add_argument(
+        "--epo_te", default=1, type=int, help="test performance per {} epochs"
+    )
 
-    parser.add_argument('-w', '--warmup', type=int, default=100,
-                        help='number of epochs for hyper-parameter warm-up. \
-                        Set to 0 to turn warmup off.')
+    parser.add_argument(
+        "-w",
+        "--warmup",
+        type=int,
+        default=100,
+        help="number of epochs for hyper-parameter warm-up. \
+        Set to 0 to turn warmup off.",
+    )
 
-    parser.add_argument('--debug', action='store_true', default=False)
-    parser.add_argument('--dmem', action='store_true', default=False)
-    parser.add_argument('--no_dump', action='store_true', default=False,
-                        help='suppress saving the confusion matrix')
+    parser.add_argument("--debug", action="store_true", default=False)
+    parser.add_argument("--dmem", action="store_true", default=False)
+    parser.add_argument(
+        "--no_dump",
+        action="store_true",
+        default=False,
+        help="suppress saving the confusion matrix",
+    )
 
-    parser.add_argument('--trainer', type=str, default=None,
-                        help='specify which trainer to use')
+    parser.add_argument(
+        "--trainer", type=str, default=None, help="specify which trainer to use"
+    )
 
-    parser.add_argument('--out', type=str, default="zoutput",
-                        help='absolute directory to store outputs')
+    parser.add_argument(
+        "--out", type=str, default="zoutput", help="absolute directory to store outputs"
+    )
 
-    parser.add_argument('--dpath', type=str, default="zdpath",
-                        help="path for storing downloaded dataset")
+    parser.add_argument(
+        "--dpath",
+        type=str,
+        default="zdpath",
+        help="path for storing downloaded dataset",
+    )
 
-    parser.add_argument('--tpath', type=str, default=None,
-                        help="path for custom task, should implement \
-                        get_task function")
+    parser.add_argument(
+        "--tpath",
+        type=str,
+        default=None,
+        help="path for custom task, should implement \
+        get_task function",
+    )
 
-    parser.add_argument('--npath', type=str, default=None,
-                        help="path of custom neural network for feature \
-                        extraction")
+    parser.add_argument(
+        "--npath",
+        type=str,
+        default=None,
+        help="path of custom neural network for feature \
+        extraction",
+    )
 
-    parser.add_argument('--npath_dom', type=str, default=None,
-                        help="path of custom neural network for feature \
-                        extraction")
+    parser.add_argument(
+        "--npath_dom",
+        type=str,
+        default=None,
+        help="path of custom neural network for feature \
+        extraction",
+    )
 
-    parser.add_argument('--npath_argna2val', action='append',
-                        help="specify new arguments and their value \
+    parser.add_argument(
+        "--npath_argna2val",
+        action="append",
+        help="specify new arguments and their value \
         e.g. '--npath_argna2val my_custom_arg_na \
         --npath_argna2val xx/yy/zz.py', additional \
         pairs can be appended",
@@ -146,65 +186,103 @@ def mk_parser_main():
                         dest="bm_dir",
                         help="Aggregates and plots partial data of a snakemake \
                         benchmark. Requires the benchmark config file. \
-                        Other arguments will be ignored.")
-
-    parser.add_argument('--gen_plots', type=str,
-                        default=None, dest="plot_data",
-                        help="plots the data of a snakemake benchmark. "
-                             "Requires the results.csv file"
-                             "and an output file (specify by --outp_file,"
-                             "default is zoutput/benchmarks/shell_benchmark). "
-                             "Other arguments will be ignored.")
-
-    parser.add_argument('--outp_dir', type=str,
-                        default='zoutput/benchmarks/shell_benchmark', dest="outp_dir",
-                        help="outpus file for the plots when creating them"
-                             "using --gen_plots. "
-                             "Default is zoutput/benchmarks/shell_benchmark")
-    parser.add_argument('--param_idx', type=bool,
-                        default=True, dest="param_idx",
-                        help="True: parameter index is used in the "
-                             "pots generated with --gen_plots."
-                             "False: parameter name is used."
-                             "Default is True.")
-
-    parser.add_argument('--msel', choices=['val', 'loss_tr', 'last'], default="val",
-                        help='model selection for early stop: val, loss_tr, recon, the \
+        Other arguments will be ignored.",
+    )
+
+    parser.add_argument(
+        "--gen_plots",
+        type=str,
+        default=None,
+        dest="plot_data",
+        help="plots the data of a snakemake benchmark. "
+        "Requires the results.csv file "
+        "and an output file (specify by --outp_file, "
+        "default is zoutput/benchmarks/shell_benchmark). "
+        "Other arguments will be ignored.",
+    )
+
+    parser.add_argument(
+        "--outp_dir",
+        type=str,
+        default="zoutput/benchmarks/shell_benchmark",
+        dest="outp_dir",
+        help="output file for the plots when creating them "
+        "using --gen_plots. "
+        "Default is zoutput/benchmarks/shell_benchmark",
+    )
+    parser.add_argument(
+        "--param_idx",
+        type=bool,
+        default=True,
+        dest="param_idx",
+        help="True: parameter index is used in the "
+        "plots generated with --gen_plots. "
+        "False: parameter name is used. "
+        "Default is True.",
+    )
+
+    parser.add_argument(
+        "--msel",
+        choices=["val", "loss_tr", "last"],
+        default="val",
+        help="model selection for early stop: val, loss_tr, recon, the \
         elbo and recon only make sense for vae models,\
-            will be ignored by other methods')
+        will be ignored by other methods",
+    )
 
-    parser.add_argument('--msel_tr_loss', choices=['reg', 'task'], default="task",
-                        help='model selection for tr loss')
+    parser.add_argument(
+        "--msel_tr_loss",
+        choices=["reg", "task"],
+        default="task",
+        help="model selection for tr loss",
+    )
 
-    parser.add_argument('--model', metavar="an", type=str,
-                        default=None,
-                        help='algorithm name')
+    parser.add_argument(
+        "--model", metavar="an", type=str, default=None, help="algorithm name"
+    )
 
-    parser.add_argument('--acon', metavar="ac", type=str, default=None,
-                        help='algorithm configuration name, (default None)')
+    parser.add_argument(
+        "--acon",
+        metavar="ac",
+        type=str,
+        default=None,
+        help="algorithm configuration name, (default None)",
+    )
 
-    parser.add_argument('--task', metavar="ta", type=str,
-                        help='task name')
+    parser.add_argument("--task", metavar="ta", type=str, help="task name")
 
-    arg_group_task = parser.add_argument_group('task args')
+    arg_group_task = parser.add_argument_group("task args")
 
-    arg_group_task.add_argument('--bs', type=int, default=100,
-                                help='loader batch size for mixed domains')
+    arg_group_task.add_argument(
+        "--bs", type=int, default=100, help="loader batch size for mixed domains"
+    )
 
-    arg_group_task.add_argument('--split', type=float, default=0,
-                                help='proportion of training, a value between \
-                                0 and 1, 0 means no train-validation split')
+    arg_group_task.add_argument(
+        "--split",
+        type=float,
+        default=0,
+        help="proportion of training, a value between \
+        0 and 1, 0 means no train-validation split",
+    )
 
-    arg_group_task.add_argument('--te_d', nargs='*', default=None,
-                                help='test domain names separated by single space, \
-                                will be parsed to be list of strings')
+    arg_group_task.add_argument(
+        "--te_d",
+        nargs="*",
+        default=None,
+        help="test domain names separated by single space, \
+        will be parsed to be list of strings",
+    )
 
-    arg_group_task.add_argument('--tr_d', nargs='*', default=None,
-                                help='training domain names separated by \
-                                single space, will be parsed to be list of \
-                                strings; if not provided then all available \
-                                domains that are not assigned to \
-                                the test set will be used as training domains')
+    arg_group_task.add_argument(
+        "--tr_d",
+        nargs="*",
+        default=None,
+        help="training domain names separated by \
+        single space, will be parsed to be list of \
+        strings; if not provided then all available \
+        domains that are not assigned to \
+        the test set will be used as training domains",
+    )
 
     arg_group_task.add_argument(
         "--san_check",
@@ -233,7 +311,7 @@ def mk_parser_main():
     arg_group_jigen = add_args2parser_jigen(arg_group_jigen)
     args_group_dial = parser.add_argument_group("dial")
     args_group_dial = add_args2parser_dial(args_group_dial)
-    args_group_fbopt = parser.add_argument_group('fbopt')
+    args_group_fbopt = parser.add_argument_group("fbopt")
     args_group_fbopt = add_args2parser_fbopt(args_group_fbopt)
 
     return parser
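Aside: a hedged end-to-end sketch of building the full parser defined above and inspecting the fbopt defaults; the task, model, and trainer values below are illustrative and assume the fbopt branch:

```python
# Hedged sketch: the full DomainLab CLI parser in action.
from domainlab.arg_parser import mk_parser_main

parser = mk_parser_main()
args = parser.parse_args(
    ["--te_d", "sketch", "--task", "mini_vlcs", "--model", "diva",
     "--trainer", "fbopt", "--k_i_gain_ratio", "0.5"]
)
print(args.mu_init, args.ini_setpoint_ratio, args.str_diva_multiplier_type)
# expected defaults: 0.001 0.99 gammad_recon
```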
@@ -57,11 +59,22 @@ class ModelDIVA(parent_class): """ @store_args - def __init__(self, chain_node_builder, - zd_dim, zy_dim, zx_dim, - list_str_y, list_d_tr, - gamma_d, gamma_y, - beta_d, beta_x, beta_y, mu_recon=1.0): + def __init__( + self, + chain_node_builder, + zd_dim, + zy_dim, + zx_dim, + list_str_y, + list_d_tr, + gamma_d, + gamma_y, + beta_d, + beta_x, + beta_y, + mu_recon=1.0, + ): + # pylint: disable=too-many-arguments, unused-argument """ gamma: classification loss coefficient """ @@ -118,12 +131,14 @@ def dict_multiplier(self): """ list of multipliers name, which correspond to cal_reg_loss """ - return {"mu_recon": self.mu_recon, - "beta_d": self.beta_d, - "beta_x": self.beta_x, - "beta_y": self.beta_y, - "gamma_d": self.gamma_d} - + return { + "mu_recon": self.mu_recon, + "beta_d": self.beta_d, + "beta_x": self.beta_x, + "beta_y": self.beta_y, + "gamma_d": self.gamma_d, + } + def _cal_reg_loss(self, tensor_x, tensor_y, tensor_d, others=None): q_zd, zd_q, q_zx, zx_q, q_zy, zy_q = self.encoder(tensor_x) logit_d = self.net_classif_d(zd_q) @@ -157,8 +172,13 @@ def _cal_reg_loss(self, tensor_x, tensor_y, tensor_d, others=None): _, d_target = tensor_d.max(dim=1) lc_d = F.cross_entropy(logit_d, d_target, reduction=g_str_cross_entropy_agg) - return [loss_recon_x, zd_p_minus_zd_q, zx_p_minus_zx_q, zy_p_minus_zy_q, lc_d], \ - [self.mu_recon, -self.beta_d, -self.beta_x, -self.beta_y, self.gamma_d] + return [ + loss_recon_x, + zd_p_minus_zd_q, + zx_p_minus_zx_q, + zy_p_minus_zy_q, + lc_d, + ], [self.mu_recon, -self.beta_d, -self.beta_x, -self.beta_y, self.gamma_d] class ModelDIVAGammadRecon(ModelDIVA): def hyper_update(self, epoch, fun_scheduler): @@ -174,7 +194,6 @@ def hyper_update(self, epoch, fun_scheduler): self.gamma_d = dict_rst["gamma_d"] self.mu_recon = dict_rst["mu_recon"] - def hyper_init(self, functor_scheduler, trainer=None): """ initiate a scheduler object via class name and things inside this model @@ -187,23 +206,32 @@ def hyper_init(self, functor_scheduler, trainer=None): beta_y=self.beta_y, beta_x=self.beta_x, gamma_d=self.gamma_d, - mu_recon=self.mu_recon + mu_recon=self.mu_recon, ) class ModelDIVAGammadReconPerPixel(ModelDIVAGammadRecon): def cal_reg_loss(self, tensor_x, tensor_y, tensor_d, others=None): - [loss_recon_x, zd_p_minus_zd_q, zx_p_minus_zx_q, zy_p_minus_zy_q, lc_d], [mu_recon, minus_beta_d, minus_beta_x, minus_beta_y, gamma_d] = super().cal_reg_loss(tensor_x, tensor_y, tensor_d, others) - - return [torch.div(loss_recon_x, - tensor_x.shape[2] * tensor_x.shape[3]), - zd_p_minus_zd_q, zx_p_minus_zx_q, zy_p_minus_zy_q, lc_d], \ - [mu_recon, minus_beta_d, minus_beta_x, - minus_beta_y, gamma_d] + [loss_recon_x, zd_p_minus_zd_q, zx_p_minus_zx_q, zy_p_minus_zy_q, lc_d], [ + mu_recon, + minus_beta_d, + minus_beta_x, + minus_beta_y, + gamma_d, + ] = super().cal_reg_loss(tensor_x, tensor_y, tensor_d, others) + + return [ + torch.div(loss_recon_x, tensor_x.shape[2] * tensor_x.shape[3]), + zd_p_minus_zd_q, + zx_p_minus_zx_q, + zy_p_minus_zy_q, + lc_d, + ], [mu_recon, minus_beta_d, minus_beta_x, minus_beta_y, gamma_d] class ModelDIVAGammad(ModelDIVA): """ only adjust gammad and beta """ + def hyper_update(self, epoch, fun_scheduler): """hyper_update. 
@@ -234,6 +262,7 @@ class ModelDIVADefault(ModelDIVA):
         """
         mock
         """
+
     if str_diva_multiplier_type == "gammad_recon":
         return ModelDIVAGammadRecon
     if str_diva_multiplier_type == "gammad_recon_per_pixel":
@@ -244,4 +273,5 @@ class ModelDIVADefault(ModelDIVA):
         return ModelDIVADefault
     raise RuntimeError(
         "unsupported argument candidates for str_diva_multiplier_type: \
-        allowed: default, gammad_recon, gammad_recon_per_pixel, gammad")
+        allowed: default, gammad_recon, gammad_recon_per_pixel, gammad"
+    )
diff --git a/domainlab/models/model_hduva.py b/domainlab/models/model_hduva.py
index fa33a315e..a86c4e405 100644
--- a/domainlab/models/model_hduva.py
+++ b/domainlab/models/model_hduva.py
@@ -85,24 +85,30 @@ def hyper_init(self, functor_scheduler, trainer=None):
             beta_d=self.beta_d,
             beta_y=self.beta_y,
             beta_x=self.beta_x,
-            beta_t=self.beta_t)
+            beta_t=self.beta_t,
+        )
 
     @store_args
-    def __init__(self, chain_node_builder,
-                 zy_dim, zd_dim,
-                 list_str_y,
-                 gamma_d, gamma_y,
-                 beta_d, beta_x, beta_y,
-                 beta_t,
-                 device,
-                 zx_dim=0,
-                 topic_dim=3,
-                 mu_recon=1.0):
-        """
-        """
-        super().__init__(chain_node_builder,
-                         zd_dim, zy_dim, zx_dim,
-                         list_str_y)
+    def __init__(
+        self,
+        chain_node_builder,
+        zy_dim,
+        zd_dim,
+        list_str_y,
+        gamma_d,
+        gamma_y,
+        beta_d,
+        beta_x,
+        beta_y,
+        beta_t,
+        device,
+        zx_dim=0,
+        topic_dim=3,
+        mu_recon=1.0,
+    ):
+        # pylint: disable=too-many-arguments, unused-argument
+        """ """
+        super().__init__(chain_node_builder, zd_dim, zy_dim, zx_dim, list_str_y)
         # topic to zd follows Gaussian distribution
         self.add_module(
             "net_p_zd",
@@ -189,8 +195,13 @@ def _cal_reg_loss(self, tensor_x, tensor_y, tensor_d=None, others=None):
         # reconstruction
         z_concat = self.decoder.concat_ytdx(zy_q, topic_q, zd_q, zx_q)
         loss_recon_x, _, _ = self.decoder(z_concat, tensor_x)
-        return [loss_recon_x, zx_p_minus_q, zy_p_minus_zy_q, zd_p_minus_q, topic_p_minus_q], \
-            [self.mu_recon, -self.beta_x, -self.beta_y, -self.beta_d, -self.beta_t]
+        return [
+            loss_recon_x,
+            zx_p_minus_q,
+            zy_p_minus_zy_q,
+            zd_p_minus_q,
+            topic_p_minus_q,
+        ], [self.mu_recon, -self.beta_x, -self.beta_y, -self.beta_d, -self.beta_t]
 
     @property
     def list_str_multiplier_na(self):
@@ -204,11 +215,13 @@ def dict_multiplier(self):
         """
         dictionary of multiplier names
         """
-        return {"mu_recon": self.mu_recon,
-                "beta_d": self.beta_d,
-                "beta_x": self.beta_x,
-                "beta_y": self.beta_y,
-                "beta_t": self.beta_t}
+        return {
+            "mu_recon": self.mu_recon,
+            "beta_d": self.beta_d,
+            "beta_x": self.beta_x,
+            "beta_y": self.beta_y,
+            "beta_t": self.beta_t,
+        }
 
     def extract_semantic_feat(self, tensor_x):
         """
diff --git a/domainlab/tasks/b_task.py b/domainlab/tasks/b_task.py
index fcbbe473e..e64222984 100644
--- a/domainlab/tasks/b_task.py
+++ b/domainlab/tasks/b_task.py
@@ -54,7 +54,9 @@ def init_business(self, args, trainer=None):
             self.dict_dset_val.update({na_domain: ddset_val})
         ddset_mix = ConcatDataset(tuple(self.dict_dset_tr.values()))
         self._loader_tr = mk_loader(ddset_mix, args.bs)
-        self._loader_tr_no_drop = mk_loader(ddset_mix, args.bs, drop_last=False, shuffle=False)
+        self._loader_tr_no_drop = mk_loader(
+            ddset_mix, args.bs, drop_last=False, shuffle=False
+        )
         ddset_mix_val = ConcatDataset(tuple(self.dict_dset_val.values()))
         self._loader_val = mk_loader(
diff --git a/domainlab/utils/generate_fbopt_phase_portrait.py b/domainlab/utils/generate_fbopt_phase_portrait.py
index b58d2ff83..2b9baae07 100644
--- a/domainlab/utils/generate_fbopt_phase_portrait.py
+++ b/domainlab/utils/generate_fbopt_phase_portrait.py
@@ -1,24 +1,36 @@
+"""
+This file is used for generating phase portraits from tensorboard event files.
+"""
+
+import argparse
 import glob
 import os
-import numpy as np
-import argparse
 import matplotlib.pyplot as plt
-from tensorboard.backend.event_processing.event_accumulator \
-    import EventAccumulator
-
-
-def get_xy_from_event_file(event_file, plot1, plot2=None,
-                           tf_size_guidance=None,
-                           sanity_check=False, verbose=True):
+import numpy as np
+from tensorboard.backend.event_processing.event_accumulator import EventAccumulator
+
+
+# pylint: disable=too-many-arguments
+def get_xy_from_event_file(
+    event_file,
+    plot1,
+    plot2=None,
+    tf_size_guidance=None,
+    sanity_check=False,
+    verbose=True,
+):
+    """
+    extract x and y values from a tensorboard event file
+    """
     if tf_size_guidance is None:
         # settings for which/how much data is loaded from the
         # tensorboard event files
         tf_size_guidance = {
-            'compressedHistograms': 0,
-            'images': 0,
-            'scalars': 1e10,  # keep unlimited number
-            'histograms': 0
+            "compressedHistograms": 0,
+            "images": 0,
+            "scalars": 1e10,  # keep unlimited number
+            "histograms": 0,
         }
     # load event file
     event = EventAccumulator(event_file, tf_size_guidance)
@@ -26,7 +38,7 @@ def get_xy_from_event_file(event_file, plot1, plot2=None,
     # print names of available plots
     if verbose:
         print(f"Event file {event_file} -- available plots:")
-        print(event.Tags()['scalars'])
+        print(event.Tags()["scalars"])
     if plot2:
         # extract the plot2 values (e.g., reg/dyn0)
         y_event = event.Scalars(plot2)
@@ -47,9 +59,17 @@ def get_xy_from_event_file(event_file, plot1, plot2=None,
     return x, y
 
 
-def phase_portrait_combined(event_files, colors, plot1, plot2,
-                            legend1=None, legend2=None, plot_len=None,
-                            output_dir="."):
+# pylint: disable=too-many-arguments, too-many-locals, redefined-outer-name, unused-argument
+def phase_portrait_combined(
+    event_files,
+    colors,
+    plot1,
+    plot2,
+    legend1=None,
+    legend2=None,
+    plot_len=None,
+    output_dir=".",
+):
     """
     combined phase portrait for multiple (at least one) Tensorboard event
     files in the same plot
@@ -57,8 +77,7 @@ def phase_portrait_combined(event_files, colors, plot1, plot2,
     """
     plt.figure()
     for event_i in range(len(event_files)):
-        x, y = get_xy_from_event_file(event_files[event_i],
-                                      plot1=plot1, plot2=plot2)
+        x, y = get_xy_from_event_file(event_files[event_i], plot1=plot1, plot2=plot2)
         assert len(x) == len(y)
 
         if plot_len is None:
@@ -68,25 +87,35 @@
         head_w_glob = min((max(x) - min(x)) / 100.0, (max(y) - min(y)) / 100.0)
         for i in range(plot_len - 1):
-            xy_dist = np.sqrt((x[i + 1] - x[i])**2 + (y[i + 1] - y[i])**2)
+            xy_dist = np.sqrt((x[i + 1] - x[i]) ** 2 + (y[i + 1] - y[i]) ** 2)
             head_l = xy_dist / 30.0
             head_w = min(head_l, head_w_glob)
-            plt.arrow(x[i], y[i], (x[i + 1] - x[i]), (y[i + 1] - y[i]),
-                      head_width=head_w, head_length=head_l,
-                      length_includes_head=True,
-                      fc=colors[event_i], ec=colors[event_i], alpha=0.8)
+            plt.arrow(
+                x[i],
+                y[i],
+                (x[i + 1] - x[i]),
+                (y[i + 1] - y[i]),
+                head_width=head_w,
+                head_length=head_l,
+                length_includes_head=True,
+                fc=colors[event_i],
+                ec=colors[event_i],
+                alpha=0.8,
+            )
             # the combination of head_width and head_length makes the arrow
             # more visible.
             # length_includes_head=False makes the arrow stick too far out
            # beyond the point, so True is used.
-    colors = ['red', 'green', 'blue', 'yellow', 'purple']
-    plt.plot(x[0], y[0], 'ko')
+    colors = ["red", "green", "blue", "yellow", "purple"]
+    plt.plot(x[0], y[0], "ko")
     list_color = [colors[i % len(colors)] for i, h in enumerate(x)]
     plt.scatter(x, y, s=1, c=np.array(list_color))
-    if legend1 is None: legend1=plot1
-    if legend2 is None: legend2=plot2
+    if legend1 is None:
+        legend1 = plot1
+    if legend2 is None:
+        legend2 = plot2
     plt.xlabel(legend1)
     plt.ylabel(legend2)
     plt.title("phase portrait")
@@ -94,27 +123,37 @@
     if not os.path.exists(output_dir):
         os.makedirs(output_dir)
     legend22 = legend2.split(os.sep)[-1]
-    plt.savefig(os.path.join(output_dir,
-                             f'phase_portrait_combined_{legend22}.png'), dpi=300)
-
-
-def two_curves_combined(event_files, colors, plot1, plot2,
-                        legend1=None, legend2=None, output_dir=".", title=None):
+    plt.savefig(
+        os.path.join(output_dir, f"phase_portrait_combined_{legend22}.png"), dpi=300
+    )
+
+
+def two_curves_combined(
+    event_files,
+    colors,
+    plot1,
+    plot2,
+    legend1=None,
+    legend2=None,
+    output_dir=".",
+    title=None,
+):
     """
     FIXME: colors parameter is not used
     """
     plt.figure()
     for event_i in range(len(event_files)):
-        x, y = get_xy_from_event_file(event_files[event_i],
-                                      plot1=plot1, plot2=plot2)
+        x, y = get_xy_from_event_file(event_files[event_i], plot1=plot1, plot2=plot2)
         plt.plot(x, color="blue")
         plt.plot(y, color="red")
     plt.xlabel("epoch")
     # plt.ylabel("loss")
     if title is not None:
         plt.title(title)
-    if legend1 is None: legend1=plot1
-    if legend2 is None: legend2=plot2
+    if legend1 is None:
+        legend1 = plot1
+    if legend2 is None:
+        legend2 = plot2
     plt.legend([legend1, legend2])
     legend11 = legend1.replace(os.sep, "_")
@@ -122,8 +161,9 @@
     if not os.path.exists(output_dir):
         os.makedirs(output_dir)
-    plt.savefig(os.path.join(output_dir,
-                             f'timecourse_{legend11}_{legend22}.png'), dpi=300)
+    plt.savefig(
+        os.path.join(output_dir, f"timecourse_{legend11}_{legend22}.png"), dpi=300
+    )
 
 
 def plot_single_curve(event_files, colors, plot1, legend1=None, output_dir="."):
@@ -135,7 +175,8 @@
         x, _ = get_xy_from_event_file(event_files[event_i], plot1=plot1)
         plt.plot(x)
     plt.xlabel("time")
-    if legend1 is None: legend1=plot1
+    if legend1 is None:
+        legend1 = plot1
     plt.ylabel(legend1)
     # plt.title("timecourse")
@@ -143,47 +184,68 @@
     if not os.path.exists(output_dir):
         os.makedirs(output_dir)
-    plt.savefig(os.path.join(output_dir,
-                             f'timecourse_{legend11}.png'), dpi=300)
+    plt.savefig(os.path.join(output_dir, f"timecourse_{legend11}.png"), dpi=300)
 
 
 if __name__ == "__main__":
-    parser = argparse.ArgumentParser(description='plot')
-    parser.add_argument('-plot1', "--plot1", default=None, type=str)
-    parser.add_argument('-plot2', "--plot2", default=None, type=str)
-    parser.add_argument('-legend1', "--legend1", default=None, type=str)
-    parser.add_argument('-legend2', "--legend2", default=None, type=str)
-    parser.add_argument('-plot_len', "--plot_len", default=None, type=int)
-    parser.add_argument('-title', "--title", default=None, type=str)
-    parser.add_argument('--output_dir', default='.', type=str)
-    parser.add_argument('--phase_portrait', action='store_true',
-                        help="if True plots a phase portrait,\
-                        otherwise a curve (default)")
+    parser = argparse.ArgumentParser(description="plot")
+    parser.add_argument("-plot1", "--plot1", default=None, type=str)
+    parser.add_argument("-plot2", "--plot2", default=None, type=str)
+    parser.add_argument("-legend1", "--legend1", default=None, type=str)
+    parser.add_argument("-legend2", "--legend2", default=None, type=str)
+    parser.add_argument("-plot_len", "--plot_len", default=None, type=int)
+    parser.add_argument("-title", "--title", default=None, type=str)
+    parser.add_argument("--output_dir", default=".", type=str)
+    parser.add_argument(
+        "--phase_portrait",
+        action="store_true",
+        help="if True plots a phase portrait,\
+            otherwise a curve (default)",
+    )
     args = parser.parse_args()
     # get event files from all available runs
     event_files = glob.glob("runs/*/events*")
-    print("Using the following tensorboard event files:\n{}".format(
-        "\n".join(event_files)))
+    print(
+        "Using the following tensorboard event files:\n{}".format(
+            "\n".join(event_files)
+        )
+    )
     # Different colors for the different runs
-    cmap = plt.get_cmap('tab10')  # Choose a colormap
+    cmap = plt.get_cmap("tab10")  # Choose a colormap
     colors = [cmap(i) for i in range(len(event_files))]
     if args.phase_portrait:
-        phase_portrait_combined(event_files, colors,
-                                plot1=args.plot1, plot2=args.plot2,
-                                legend1=args.legend1, legend2=args.legend2,
-                                plot_len=args.plot_len, output_dir=args.output_dir)
+        phase_portrait_combined(
+            event_files,
+            colors,
+            plot1=args.plot1,
+            plot2=args.plot2,
+            legend1=args.legend1,
+            legend2=args.legend2,
+            plot_len=args.plot_len,
+            output_dir=args.output_dir,
+        )
     else:
         if args.plot2:
             # two curves per plot
-            two_curves_combined(event_files, colors,
-                                plot1=args.plot1, plot2=args.plot2,
-                                legend1=args.legend1, legend2=args.legend2,
-                                output_dir=args.output_dir, title=args.title)
+            two_curves_combined(
+                event_files,
+                colors,
+                plot1=args.plot1,
+                plot2=args.plot2,
+                legend1=args.legend1,
+                legend2=args.legend2,
+                output_dir=args.output_dir,
+                title=args.title,
+            )
         else:
             # one curve per plot
-            plot_single_curve(event_files, colors,
-                              plot1=args.plot1, legend1=args.legend1,
-                              output_dir=args.output_dir)
+            plot_single_curve(
+                event_files,
+                colors,
+                plot1=args.plot1,
+                legend1=args.legend1,
+                output_dir=args.output_dir,
+            )
diff --git a/examples/benchmark/mnist_dann_fbopt.yaml b/examples/benchmark/mnist_dann_fbopt.yaml
index 9417fbdd3..8bdbe444c 100644
--- a/examples/benchmark/mnist_dann_fbopt.yaml
+++ b/examples/benchmark/mnist_dann_fbopt.yaml
@@ -44,7 +44,7 @@ Shared params:
   mu_init:
     min: 0.000001
-    max: 0.00001
+    max: 0.00001
     num: 2
     distribution: uniform
diff --git a/examples/benchmark/mnist_diva_fbopt_alone.yaml b/examples/benchmark/mnist_diva_fbopt_alone.yaml
index 7ab074716..c483b0e68 100644
--- a/examples/benchmark/mnist_diva_fbopt_alone.yaml
+++ b/examples/benchmark/mnist_diva_fbopt_alone.yaml
@@ -65,7 +65,7 @@ Shared params:
     num: 3
     distribution: loguniform
-  mu_clip:
+  mu_clip:
     distribution: categorical
     datatype: float
     values:
diff --git a/examples/benchmark/mnist_diva_fbopt_and_baselines.yaml b/examples/benchmark/mnist_diva_fbopt_and_baselines.yaml
index 8cd4008fb..b687b69f4 100644
--- a/examples/benchmark/mnist_diva_fbopt_and_baselines.yaml
+++ b/examples/benchmark/mnist_diva_fbopt_and_baselines.yaml
@@ -71,7 +71,7 @@ Shared params:
     num: 3
     distribution: loguniform
-  mu_clip:
+  mu_clip:
     distribution: categorical
     datatype: float
     values:
diff --git a/examples/benchmark/mnist_jigen_fbopt_and_others.yaml b/examples/benchmark/mnist_jigen_fbopt_and_others.yaml
index aa03a3040..bd4857610 100644
--- a/examples/benchmark/mnist_jigen_fbopt_and_others.yaml
+++ b/examples/benchmark/mnist_jigen_fbopt_and_others.yaml
@@ -63,7 +63,7 @@ jigen_feedback:
   shared:
     - k_i_gain
     - mu_clip
-
+
 jigen_feedforward:
   model: jigen
   trainer: hyperscheduler
diff --git a/examples/benchmark/pacs_diva_fbopt_alone_es1.yaml b/examples/benchmark/pacs_diva_fbopt_alone_es1.yaml
index 0d5bfd072..8c87ddafe 100644
--- a/examples/benchmark/pacs_diva_fbopt_alone_es1.yaml
+++ b/examples/benchmark/pacs_diva_fbopt_alone_es1.yaml
@@ -87,7 +87,7 @@ Shared params:
 
 
-# Test fbopt with different hyperparameter configurations, no noeed to tune mu_clip since this is the job of KI gain when mu_init is small
+# Test fbopt with different hyperparameter configurations, no need to tune mu_clip since this is the job of KI gain when mu_init is small
 
 diva_fbopt_full:
   model: diva
diff --git a/examples/benchmark/pacs_diva_fbopt_alone_es1_autoki.yaml b/examples/benchmark/pacs_diva_fbopt_alone_es1_autoki.yaml
index 129e392cf..22b5ee4d8 100644
--- a/examples/benchmark/pacs_diva_fbopt_alone_es1_autoki.yaml
+++ b/examples/benchmark/pacs_diva_fbopt_alone_es1_autoki.yaml
@@ -93,7 +93,7 @@ Shared params:
 
 
-# Test fbopt with different hyperparameter configurations, no noeed to tune mu_clip since this is the job of KI gain when mu_init is small
+# Test fbopt with different hyperparameter configurations, no need to tune mu_clip since this is the job of KI gain when mu_init is small
 
 diva_fbopt_full:
   model: diva
diff --git a/examples/benchmark/pacs_diva_fbopt_alone_es1_autoki_output_ma_9.yaml b/examples/benchmark/pacs_diva_fbopt_alone_es1_autoki_output_ma_9.yaml
index 8e21dbfe7..36fd10554 100644
--- a/examples/benchmark/pacs_diva_fbopt_alone_es1_autoki_output_ma_9.yaml
+++ b/examples/benchmark/pacs_diva_fbopt_alone_es1_autoki_output_ma_9.yaml
@@ -94,7 +94,7 @@ Shared params:
 
 
-# Test fbopt with different hyperparameter configurations, no noeed to tune mu_clip since this is the job of KI gain when mu_init is small
+# Test fbopt with different hyperparameter configurations, no need to tune mu_clip since this is the job of KI gain when mu_init is small
 
 diva_fbopt_full:
   model: diva
diff --git a/examples/benchmark/pacs_diva_fbopt_alone_fixed.yaml b/examples/benchmark/pacs_diva_fbopt_alone_fixed.yaml
index 730696f69..e2a78230a 100644
--- a/examples/benchmark/pacs_diva_fbopt_alone_fixed.yaml
+++ b/examples/benchmark/pacs_diva_fbopt_alone_fixed.yaml
@@ -86,7 +86,7 @@ Shared params:
 
 
-# Test fbopt with different hyperparameter configurations, no noeed to tune mu_clip since this is the job of KI gain when mu_init is small
+# Test fbopt with different hyperparameter configurations, no need to tune mu_clip since this is the job of KI gain when mu_init is small
 
 diva_fbopt_full:
   model: diva
diff --git a/examples/benchmark/pacs_hduva_baselines.yaml b/examples/benchmark/pacs_hduva_baselines.yaml
index e40759681..cbdb704eb 100644
--- a/examples/benchmark/pacs_hduva_baselines.yaml
+++ b/examples/benchmark/pacs_hduva_baselines.yaml
@@ -94,7 +94,7 @@ Shared params:
 
 
-# Test fbopt with different hyperparameter configurations, no noeed to tune mu_clip since this is the job of KI gain when mu_init is small
+# Test fbopt with different hyperparameter configurations, no need to tune mu_clip since this is the job of KI gain when mu_init is small
 
 hduva_beta_warmup:
   model: hduva
diff --git a/examples/benchmark/pacs_hduva_fbopt_alone_es1_autoki_aug.yaml b/examples/benchmark/pacs_hduva_fbopt_alone_es1_autoki_aug.yaml
index 1422a302c..d773cb25b 100644
--- a/examples/benchmark/pacs_hduva_fbopt_alone_es1_autoki_aug.yaml
+++ b/examples/benchmark/pacs_hduva_fbopt_alone_es1_autoki_aug.yaml
@@ -94,7 +94,7 @@ Shared params:
 
 
-# Test fbopt with different hyperparameter configurations, no noeed to tune mu_clip since this is the job of KI gain when mu_init is small
+# Test fbopt with different hyperparameter configurations, no need to tune mu_clip since this is the job of KI gain when mu_init is small
 
 diva_fbopt_full:
   model: hduva
diff --git a/examples/benchmark/pacs_hduva_matchdg.yaml b/examples/benchmark/pacs_hduva_matchdg.yaml
index 1b12d0143..f8c99d6d3 100644
--- a/examples/benchmark/pacs_hduva_matchdg.yaml
+++ b/examples/benchmark/pacs_hduva_matchdg.yaml
@@ -102,7 +102,7 @@ Shared params:
 
 
-# Test fbopt with different hyperparameter configurations, no noeed to tune mu_clip since this is the job of KI gain when mu_init is small
+# Test fbopt with different hyperparameter configurations, no need to tune mu_clip since this is the job of KI gain when mu_init is small
 
 match_duva:
   model: matchhduva
diff --git a/examples/benchmark/pacs_jigen_baslines4fbopt.yaml b/examples/benchmark/pacs_jigen_baslines4fbopt.yaml
index b28be8632..8c4d99d3d 100644
--- a/examples/benchmark/pacs_jigen_baslines4fbopt.yaml
+++ b/examples/benchmark/pacs_jigen_baslines4fbopt.yaml
@@ -55,7 +55,7 @@ Shared params:
 
 # Test fbopt with different hyperparameter configurations
-
+
 jigen_feedforward:
   model: jigen
   trainer: hyperscheduler
diff --git a/examples/benchmark/pacs_jigen_fbopt_alone_autoki.yaml b/examples/benchmark/pacs_jigen_fbopt_alone_autoki.yaml
index 0405e6731..3c70d07b6 100644
--- a/examples/benchmark/pacs_jigen_fbopt_alone_autoki.yaml
+++ b/examples/benchmark/pacs_jigen_fbopt_alone_autoki.yaml
@@ -27,7 +27,7 @@ domainlab_args:
   zx_dim: 0
   zy_dim: 64
   zd_dim: 64
-  pperm: 0.1
+  pperm: 0.1
   # pperm corresponds to 1-bias_wholeimage in https://github.com/fmcarlucci/JigenDG
diff --git a/examples/benchmark/pacs_jigen_fbopt_and_baselines.yaml b/examples/benchmark/pacs_jigen_fbopt_and_baselines.yaml
index 5fac934ce..1421913b3 100644
--- a/examples/benchmark/pacs_jigen_fbopt_and_baselines.yaml
+++ b/examples/benchmark/pacs_jigen_fbopt_and_baselines.yaml
@@ -73,7 +73,7 @@ jigen_feedback:
   shared:
     - k_i_gain
    - mu_clip
-
+
 jigen_feedforward:
   model: jigen
   trainer: hyperscheduler
diff --git a/examples/benchmark/pacs_jigen_fbopt_and_baselines_aug.yaml b/examples/benchmark/pacs_jigen_fbopt_and_baselines_aug.yaml
index 4fb01259b..3b0f8dba6 100644
--- a/examples/benchmark/pacs_jigen_fbopt_and_baselines_aug.yaml
+++ b/examples/benchmark/pacs_jigen_fbopt_and_baselines_aug.yaml
@@ -86,7 +86,7 @@ jigen_feedback:
     - k_i_gain_ratio
     - mu_clip
     - lr
-
+
 jigen_feedforward:
   model: jigen
   trainer: hyperscheduler
diff --git a/fbopt_mnist_diva_pixel.sh b/fbopt_mnist_diva_pixel.sh
index 6440f90c8..bac129db9 100644
--- a/fbopt_mnist_diva_pixel.sh
+++ b/fbopt_mnist_diva_pixel.sh
@@ -1,5 +1,5 @@
 #!/bin/bash
-# export CUDA_VISIBLE_DEVICES=""
+# export CUDA_VISIBLE_DEVICES=""
 # although garbage collector has been explicitly called, sometimes there is still CUDA out of memory error
 # so it is better not to use GPU to do the pytest to ensure every time there is no CUDA out of memory error occurring
 # pytest -s tests/test_fbopt.py
diff --git a/run_benchmark_slurm.sh b/run_benchmark_slurm.sh
index 8f6b9f66e..91e06fa62 100755
--- a/run_benchmark_slurm.sh
+++ b/run_benchmark_slurm.sh
@@ -35,5 +35,4 @@
 echo "verbose log: $logfile"
 
 rm -f -R .snakemake
-snakemake --profile "examples/yaml/slurm" --keep-going --keep-incomplete --notemp --cores 3 -s "domainlab/exp_protocol/benchmark.smk" --configfile "$CONFIGFILE" 2>&1 | tee "$logfile"
-
+snakemake --profile "examples/yaml/slurm" --keep-going --keep-incomplete --notemp --cores 3 -s "domainlab/exp_protocol/benchmark.smk" --configfile "$CONFIGFILE" 2>&1 | tee "$logfile"
diff --git a/run_fbopt_diva.sh b/run_fbopt_diva.sh
index 02cfbd6d9..dc48bce9b 100644
--- a/run_fbopt_diva.sh
+++ b/run_fbopt_diva.sh
@@ -1,5 +1,5 @@
 #!/bin/bash
-# export CUDA_VISIBLE_DEVICES=""
+# export CUDA_VISIBLE_DEVICES=""
 # although garbage collector has been explicitly called, sometimes there is still CUDA out of memory error
 # so it is better not to use GPU to do the pytest to ensure every time there is no CUDA out of memory error occurring
 # pytest -s tests/test_fbopt.py
diff --git a/run_fbopt_diva_cpu.sh b/run_fbopt_diva_cpu.sh
index 2fa5665f5..59d0c592a 100644
--- a/run_fbopt_diva_cpu.sh
+++ b/run_fbopt_diva_cpu.sh
@@ -1,5 +1,5 @@
 #!/bin/bash
-export CUDA_VISIBLE_DEVICES=""
+export CUDA_VISIBLE_DEVICES=""
 # although garbage collector has been explicitly called, sometimes there is still CUDA out of memory error
 # so it is better not to use GPU to do the pytest to ensure every time there is no CUDA out of memory error occurring
 # pytest -s tests/test_fbopt.py
diff --git a/run_fbopt_mnist.sh b/run_fbopt_mnist.sh
index 7f0f7fcf9..2e3edc424 100644
--- a/run_fbopt_mnist.sh
+++ b/run_fbopt_mnist.sh
@@ -1,5 +1,5 @@
 #!/bin/bash
-# export CUDA_VISIBLE_DEVICES=""
+# export CUDA_VISIBLE_DEVICES=""
 # although garbage collector has been explicitly called, sometimes there is still CUDA out of memory error
 # so it is better not to use GPU to do the pytest to ensure every time there is no CUDA out of memory error occurring
 # pytest -s tests/test_fbopt.py
diff --git a/run_fbopt_mnist_diva.sh b/run_fbopt_mnist_diva.sh
index 85caf8695..fd5c2b8cf 100644
--- a/run_fbopt_mnist_diva.sh
+++ b/run_fbopt_mnist_diva.sh
@@ -1,5 +1,5 @@
 #!/bin/bash
-# export CUDA_VISIBLE_DEVICES=""
+# export CUDA_VISIBLE_DEVICES=""
 # although garbage collector has been explicitly called, sometimes there is still CUDA out of memory error
 # so it is better not to use GPU to do the pytest to ensure every time there is no CUDA out of memory error occurring
 # pytest -s tests/test_fbopt.py
diff --git a/run_fbopt_mnist_diva_autoki.sh b/run_fbopt_mnist_diva_autoki.sh
index 6423ff283..64c19e102 100644
--- a/run_fbopt_mnist_diva_autoki.sh
+++ b/run_fbopt_mnist_diva_autoki.sh
@@ -1,5 +1,5 @@
 #!/bin/bash
-# export CUDA_VISIBLE_DEVICES=""
+# export CUDA_VISIBLE_DEVICES=""
 # although garbage collector has been explicitly called, sometimes there is still CUDA out of memory error
 # so it is better not to use GPU to do the pytest to ensure every time there is no CUDA out of memory error occurring
 # pytest -s tests/test_fbopt.py
diff --git a/run_fbopt_mnist_feedforward.sh b/run_fbopt_mnist_feedforward.sh
index b7b1139a7..b04819c61 100644
--- a/run_fbopt_mnist_feedforward.sh
+++ b/run_fbopt_mnist_feedforward.sh
@@ -1,5 +1,5 @@
 #!/bin/bash
-# export CUDA_VISIBLE_DEVICES=""
+# export CUDA_VISIBLE_DEVICES=""
 # although garbage collector has been explicitly called, sometimes there is still CUDA out of memory error
 # so it is better not to use GPU to do the pytest to ensure every time there is no CUDA out of memory error occurring
 # pytest -s tests/test_fbopt.py
diff --git a/run_fbopt_mnist_jigen_autoki.sh b/run_fbopt_mnist_jigen_autoki.sh
index 24d7eb059..8b346e011 100644
--- a/run_fbopt_mnist_jigen_autoki.sh
+++ b/run_fbopt_mnist_jigen_autoki.sh
@@ -1,5 +1,5 @@
 #!/bin/bash
-# export CUDA_VISIBLE_DEVICES=""
+# export CUDA_VISIBLE_DEVICES=""
 # although garbage collector has been explicitly called, sometimes there is still CUDA out of memory error
 # so it is better not to use GPU to do the pytest to ensure every time there is no CUDA out of memory error occurring
 # pytest -s tests/test_fbopt.py
diff --git a/run_fbopt_small_pacs.sh b/run_fbopt_small_pacs.sh
index 3583837bc..fc3ab6bc7 100644
--- a/run_fbopt_small_pacs.sh
+++ b/run_fbopt_small_pacs.sh
@@ -1,5 +1,5 @@
 #!/bin/bash
-# export CUDA_VISIBLE_DEVICES=""
+# export CUDA_VISIBLE_DEVICES=""
 # although garbage collector has been explicitly called, sometimes there is still CUDA out of memory error
 # so it is better not to use GPU to do the pytest to ensure every time there is no CUDA out of memory error occurring
 # pytest -s tests/test_fbopt.py
diff --git a/run_mnist_jigen.sh b/run_mnist_jigen.sh
index 2b06392a4..0bc854c5e 100644
--- a/run_mnist_jigen.sh
+++ b/run_mnist_jigen.sh
@@ -1,8 +1,7 @@
 #!/bin/bash
-# export CUDA_VISIBLE_DEVICES=""
+# export CUDA_VISIBLE_DEVICES=""
 # although garbage collector has been explicitly called, sometimes there is still CUDA out of memory error
 # so it is better not to use GPU to do the pytest to ensure every time there is no CUDA out of memory error occurring
 # pytest -s tests/test_fbopt.py
-python main_out.py --te_d=caltech --task=mini_vlcs --bs=16 --model=jigen --trainer=fbopt --nname=alexnet --epos=200 --es=200 --mu_init=1.0 --coeff_ma_output=0 --coeff_ma_setpoint=0 --coeff_ma_output=0
-
+python main_out.py --te_d=caltech --task=mini_vlcs --bs=16 --model=jigen --trainer=fbopt --nname=alexnet --epos=200 --es=200 --mu_init=1.0 --coeff_ma_output=0 --coeff_ma_setpoint=0 --coeff_ma_output=0
diff --git a/run_pacs_diva_fbopt.sh b/run_pacs_diva_fbopt.sh
index 0157b21a8..74d1f0cd3 100644
--- a/run_pacs_diva_fbopt.sh
+++ b/run_pacs_diva_fbopt.sh
@@ -1,5 +1,5 @@
 #!/bin/bash
-# export CUDA_VISIBLE_DEVICES=""
+# export CUDA_VISIBLE_DEVICES=""
 # although garbage collector has been explicitly called, sometimes there is still CUDA out of memory error
 # so it is better not to use GPU to do the pytest to ensure every time there is no CUDA out of memory error occurring
 # pytest -s tests/test_fbopt.py
diff --git a/run_pacs_jigen_fbopt.sh b/run_pacs_jigen_fbopt.sh
index 70be6b7fe..99663ee61 100644
--- a/run_pacs_jigen_fbopt.sh
+++ b/run_pacs_jigen_fbopt.sh
@@ -1,5 +1,5 @@
 #!/bin/bash
-# export CUDA_VISIBLE_DEVICES=""
+# export CUDA_VISIBLE_DEVICES=""
 # although garbage collector has been explicitly called, sometimes there is still CUDA out of memory error
 # so it is better not to use GPU to do the pytest to ensure every time there is no CUDA out of memory error occurring
 # pytest -s tests/test_fbopt.py
diff --git a/script_jigen_plot.sh b/script_jigen_plot.sh
index 6d770a266..5c47a68f8 100755
--- a/script_jigen_plot.sh
+++ b/script_jigen_plot.sh
@@ -1,4 +1,4 @@
-python domainlab/utils/generate_fbopt_phase_portrait.py --plot2="lossrd/dyn_alpha" --plot1="loss_task/ell" --legend2="regularization loss jigen" --legend1="classification loss" --output_dir="." --phase_portrait
+python domainlab/utils/generate_fbopt_phase_portrait.py --plot2="lossrd/dyn_alpha" --plot1="loss_task/ell" --legend2="regularization loss jigen" --legend1="classification loss" --output_dir="." --phase_portrait
 
 python domainlab/utils/generate_fbopt_phase_portrait.py --plot1="lossrs/setpoint_alpha" --plot2="lossrd/dyn_alpha" --legend2="regularization loss jigen" --legend1="setpoint" --output_dir="."
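The two invocations above drive `generate_fbopt_phase_portrait.py` through its command line. A minimal sketch of producing the same phase portrait programmatically, mirroring the `__main__` block of that utility; the scalar-tag and legend strings are the ones used in the first command above, and the event-file glob is the one the utility itself uses:

import glob

import matplotlib.pyplot as plt

from domainlab.utils.generate_fbopt_phase_portrait import phase_portrait_combined

event_files = glob.glob("runs/*/events*")  # one tensorboard event file per run
cmap = plt.get_cmap("tab10")  # one color per run
colors = [cmap(i) for i in range(len(event_files))]
phase_portrait_combined(
    event_files,
    colors,
    plot1="loss_task/ell",  # x axis: classification loss
    plot2="lossrd/dyn_alpha",  # y axis: jigen regularization loss
    legend1="classification loss",
    legend2="regularization loss jigen",
    output_dir=".",
)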
diff --git a/test_fbopt_dial.sh b/test_fbopt_dial.sh
index 9583fa080..4bf0c669b 100644
--- a/test_fbopt_dial.sh
+++ b/test_fbopt_dial.sh
@@ -1,2 +1,2 @@
-export CUDA_VISIBLE_DEVICES=""
+export CUDA_VISIBLE_DEVICES=""
 python main_out.py --te_d=caltech --task=mini_vlcs --bs=16 --model=fboptdial --trainer=dial --nname=alexnet --nname_dom=alexnet --gamma_y=1e6 --gamma_d=1e6
diff --git a/test_match_duva.sh b/test_match_duva.sh
index c67c93ea2..9f3e9951e 100644
--- a/test_match_duva.sh
+++ b/test_match_duva.sh
@@ -2,5 +2,3 @@ python main_out.py --te_d 0 1 2 --tr_d 3 7 --task=mnistcolor10 --debug --bs=2 --
     --epochs_ctr=3 --epos=6 --nname=conv_bn_pool_2 --gamma_y=7e5 \
     --nname_encoder_x2topic_h=conv_bn_pool_2 \
     --nname_encoder_sandwich_x2h4zd=conv_bn_pool_2
-
-
diff --git a/tests/test_fbopt.py b/tests/test_fbopt.py
index 5d13404c6..1e2859291 100644
--- a/tests/test_fbopt.py
+++ b/tests/test_fbopt.py
@@ -19,6 +19,7 @@ def test_jigen_fbopt():
     args = "--te_d=caltech --task=mini_vlcs --debug --bs=2 --model=jigen --trainer=fbopt --nname=alexnet --epos=3"
     utils_test_algo(args)
+
 def test_diva_fbopt():
     """
     diva
@@ -26,10 +27,10 @@ def test_diva_fbopt():
     args = "--te_d=caltech --task=mini_vlcs --debug --bs=2 --model=diva --gamma_y=1.0 --trainer=fbopt --nname=alexnet --epos=3"
     utils_test_algo(args)
+
 def test_forcesetpoint_fbopt():
     """
     jigen
     """
     args = "--te_d=0 --tr_d 1 2 --task=mnistcolor10 --bs=16 --model=jigen --trainer=fbopt --nname=conv_bn_pool_2 --epos=10 --es=0 --mu_init=0.00001 --coeff_ma_setpoint=0.5 --coeff_ma_output_state=0.99 --force_setpoint_change_once"
     utils_test_algo(args)
-
diff --git a/tests/test_fbopt_setpoint_ada.py b/tests/test_fbopt_setpoint_ada.py
index 3eeca4d0a..4b8029056 100644
--- a/tests/test_fbopt_setpoint_ada.py
+++ b/tests/test_fbopt_setpoint_ada.py
@@ -1,7 +1,9 @@
 from domainlab.algos.trainers.fbopt_setpoint_ada import is_less_list_all
+
+
 def test_less_than():
-    a = [3, 4, -9, -8]
-    b = [1, 0.5, -1, -0.5]
-    c = [0.5, 0.25, -0.5, -0.25]
-    assert not is_less_list_all(a, b)
-    assert is_less_list_all(c, b)
+    a = [3, 4, -9, -8]
+    b = [1, 0.5, -1, -0.5]
+    c = [0.5, 0.25, -0.5, -0.25]
+    assert not is_less_list_all(a, b)
+    assert is_less_list_all(c, b)
diff --git a/tests/test_fbopt_setpoint_rewind.py b/tests/test_fbopt_setpoint_rewind.py
index e1b797c4d..3c1011bab 100644
--- a/tests/test_fbopt_setpoint_rewind.py
+++ b/tests/test_fbopt_setpoint_rewind.py
@@ -3,6 +3,7 @@
 """
 from tests.utils_test import utils_test_algo
+
 def test_jigen_fbopt():
     """
     jigen