diff --git a/domainlab/algos/trainers/fbopt_mu_controller.py b/domainlab/algos/trainers/fbopt_mu_controller.py index 9f8e02971..824638461 100644 --- a/domainlab/algos/trainers/fbopt_mu_controller.py +++ b/domainlab/algos/trainers/fbopt_mu_controller.py @@ -206,27 +206,32 @@ def search_mu( self.writer.add_scalar(f"controller_gain/{key}", dict_gain[key], miter) ind = list_str_multiplier_na.index(key) self.writer.add_scalar(f"delta/{key}", self.delta_epsilon_r[ind], miter) - for i, (reg_dyn, reg_set) in enumerate( - zip(epo_reg_loss, self.get_setpoint4r()) - ): - self.writer.add_scalar( - f"lossrd/dyn_{list_str_multiplier_na[i]}", reg_dyn, miter - ) - self.writer.add_scalar( - f"lossrs/setpoint_{list_str_multiplier_na[i]}", reg_set, miter - ) - self.writer.add_scalars( - f"loss_rds/loss_{list_str_multiplier_na[i]}_w_setpoint", - { - f"lossr/loss_{list_str_multiplier_na[i]}": reg_dyn, - f"lossr/setpoint_{list_str_multiplier_na[i]}": reg_set, - }, - miter, - ) - self.writer.add_scalar( - f"x_ell_y_r/loss_{list_str_multiplier_na[i]}", reg_dyn, epo_task_loss - ) + if list_str_multiplier_na: + for i, (reg_dyn, reg_set) in enumerate( + zip(epo_reg_loss, self.get_setpoint4r()) + ): + + self.writer.add_scalar( + f"lossrd/dyn_{list_str_multiplier_na[i]}", reg_dyn, miter + ) + self.writer.add_scalar( + f"lossrs/setpoint_{list_str_multiplier_na[i]}", reg_set, miter + ) + + self.writer.add_scalars( + f"loss_rds/loss_{list_str_multiplier_na[i]}_w_setpoint", + { + f"lossr/loss_{list_str_multiplier_na[i]}": reg_dyn, + f"lossr/setpoint_{list_str_multiplier_na[i]}": reg_set, + }, + miter, + ) + self.writer.add_scalar( + f"x_ell_y_r/loss_{list_str_multiplier_na[i]}", reg_dyn, epo_task_loss + ) + else: + logger.info("No multiplier provided") self.writer.add_scalar("loss_task/penalized", epo_loss_tr, miter) self.writer.add_scalar("loss_task/ell", epo_task_loss, miter) acc_te = 0 diff --git a/domainlab/algos/trainers/train_fishr.py b/domainlab/algos/trainers/train_fishr.py index 3580a0721..250b4109d 100644 --- a/domainlab/algos/trainers/train_fishr.py +++ b/domainlab/algos/trainers/train_fishr.py @@ -26,7 +26,7 @@ class TrainerFishr(TrainerBasic): "Fishr: Invariant gradient variances for out-of-distribution generalization." International Conference on Machine Learning. PMLR, 2022. """ - def tr_epoch(self, epoch): + def tr_epoch(self, epoch, flag_info=False): list_loaders = list(self.dict_loader_tr.values()) loaders_zip = zip(*list_loaders) self.model.train() @@ -46,7 +46,7 @@ def tr_epoch(self, epoch): self.epo_loss_tr += loss.detach().item() self.after_batch(epoch, ind_batch) - flag_stop = self.observer.update(epoch) # notify observer + flag_stop = self.observer.update(epoch, flag_info) # notify observer return flag_stop def var_grads_and_loss(self, tuple_data_domains_batch): @@ -161,10 +161,6 @@ def cal_dict_variance_grads(self, tensor_x, vec_y): inputs=list(self.model.parameters()), retain_graph=True, create_graph=True ) - for name, param in self.model.named_parameters(): - print(name) - print(".grad.shape: ", param.variance.shape) - dict_variance = OrderedDict( [(name, weights.variance.clone()) for name, weights in self.model.named_parameters() diff --git a/domainlab/algos/trainers/train_hyper_scheduler.py b/domainlab/algos/trainers/train_hyper_scheduler.py index 2e60bf5e8..9eda2e494 100644 --- a/domainlab/algos/trainers/train_hyper_scheduler.py +++ b/domainlab/algos/trainers/train_hyper_scheduler.py @@ -55,7 +55,7 @@ def before_tr(self): flag_update_epoch=True, ) - def tr_epoch(self, epoch): + def tr_epoch(self, epoch, flag_info=False): """ update hyper-parameters only per epoch """ diff --git a/domainlab/algos/trainers/train_irm.py b/domainlab/algos/trainers/train_irm.py index 0da2d0bde..9e91f1edc 100644 --- a/domainlab/algos/trainers/train_irm.py +++ b/domainlab/algos/trainers/train_irm.py @@ -19,7 +19,7 @@ class TrainerIRM(TrainerBasic): For more details, see section 3.2 and Appendix D of : Arjovsky et al., “Invariant Risk Minimization.” """ - def tr_epoch(self, epoch): + def tr_epoch(self, epoch, flag_info=False): list_loaders = list(self.dict_loader_tr.values()) loaders_zip = zip(*list_loaders) self.model.train() @@ -46,7 +46,7 @@ def tr_epoch(self, epoch): self.epo_loss_tr += loss.detach().item() self.after_batch(epoch, ind_batch) - flag_stop = self.observer.update(epoch) # notify observer + flag_stop = self.observer.update(epoch, flag_info) # notify observer return flag_stop def _cal_phi(self, tensor_x): diff --git a/domainlab/algos/trainers/train_matchdg.py b/domainlab/algos/trainers/train_matchdg.py index e127baa12..ef4b1c862 100644 --- a/domainlab/algos/trainers/train_matchdg.py +++ b/domainlab/algos/trainers/train_matchdg.py @@ -95,7 +95,7 @@ def tr_epoch(self, epoch, flag_info=False): logger.info("\n\nPhase erm+ctr \n\n") self.flag_erm = True - flag_stop = self.observer.update(epoch) # notify observer + flag_stop = self.observer.update(epoch, flag_info) # notify observer return flag_stop def tr_batch(self, epoch, batch_idx, x_e, y_e, d_e, others=None): diff --git a/domainlab/algos/trainers/train_mldg.py b/domainlab/algos/trainers/train_mldg.py index e91310adf..9e13ff58b 100644 --- a/domainlab/algos/trainers/train_mldg.py +++ b/domainlab/algos/trainers/train_mldg.py @@ -51,7 +51,7 @@ def prepare_ziped_loader(self): ddset_mix = DsetZip(ddset_source, ddset_target) self.loader_tr_source_target = mk_loader(ddset_mix, self.aconf.bs) - def tr_epoch(self, epoch): + def tr_epoch(self, epoch, flag_info=False): self.model.train() self.epo_loss_tr = 0 self.prepare_ziped_loader() @@ -117,5 +117,5 @@ def tr_epoch(self, epoch): self.optimizer.step() self.epo_loss_tr += loss.detach().item() self.after_batch(epoch, ind_batch) - flag_stop = self.observer.update(epoch) # notify observer + flag_stop = self.observer.update(epoch, flag_info) # notify observer return flag_stop diff --git a/domainlab/models/a_model_classif.py b/domainlab/models/a_model_classif.py index 1917f752e..1f72eec0a 100644 --- a/domainlab/models/a_model_classif.py +++ b/domainlab/models/a_model_classif.py @@ -25,7 +25,6 @@ loss_cross_entropy_extended = extend(nn.CrossEntropyLoss(reduction="none")) - class AModelClassif(AModel, metaclass=abc.ABCMeta): """ operations that all classification model should have diff --git a/domainlab/models/model_erm.py b/domainlab/models/model_erm.py index 6ee9c23f9..4ccec7a50 100644 --- a/domainlab/models/model_erm.py +++ b/domainlab/models/model_erm.py @@ -10,7 +10,6 @@ except: backpack = None - def mk_erm(parent_class=AModelClassif, **kwargs): """ Instantiate a Deepall (ERM) model @@ -53,4 +52,30 @@ def convert4backpack(self): """ self._net_invar_feat = extend(self._net_invar_feat, use_converter=True) self.net_classifier = extend(self.net_classifier, use_converter=True) + + def hyper_update(self, epoch, fun_scheduler): # pylint: disable=unused-argument + """ + Method necessary to combine with hyperparameter scheduler + + :param epoch: + :param fun_scheduler: + """ + + def hyper_init(self, functor_scheduler, trainer=None): + """ + initiate a scheduler object via class name and things inside this model + + :param functor_scheduler: the class name of the scheduler + """ + return functor_scheduler( + trainer=trainer + ) + + @property + def list_str_multiplier_na(self): + """ + list of multipliers which match the order in cal_reg_loss + """ + return [] + return ModelERM diff --git a/examples/benchmark/pacs_fbopt_fishr_erm.yaml b/examples/benchmark/pacs_fbopt_fishr_erm.yaml new file mode 100644 index 000000000..781a2518e --- /dev/null +++ b/examples/benchmark/pacs_fbopt_fishr_erm.yaml @@ -0,0 +1,66 @@ +mode: grid + +output_dir: zoutput/benchmarks/benchmark_fbopt_fishr_erm_pacs + +sampling_seed: 0 +startseed: 0 +endseed: 0 + +test_domains: + - sketch + +domainlab_args: + tpath: examples/tasks/task_pacs_aug.py + dmem: False + lr: 5e-5 + epos: 10 + epos_min: 2 + es: 5 + bs: 32 + san_check: False + nname: alexnet + nname_dom: alexnet + zx_dim: 0 + zy_dim: 64 + zd_dim: 64 + + + + +Shared params: + ini_setpoint_ratio: + min: 0.5 + max: 0.99 + num: 2 + step: 0.05 + distribution: uniform + + k_i_gain: + min: 0.0001 + max: 0.01 + num: 2 + step: 0.0001 + distribution: uniform + + gamma_reg: + min: 0.01 + max: 1e4 + num: 3 + distribution: loguniform + + +# Test fbopt with different hyperparameter configurations + +fbopt_fishr_erm: + model: erm + trainer: fbopt_fishr + shared: + - ini_setpoint_ratio + - k_i_gain + - gamma_reg + +fishr_erm: + model: erm + trainer: fishr + shared: + - gamma_reg diff --git a/examples/tasks/task_pacs_aug.py b/examples/tasks/task_pacs_aug.py index e971bea8c..0d334a45a 100644 --- a/examples/tasks/task_pacs_aug.py +++ b/examples/tasks/task_pacs_aug.py @@ -11,9 +11,10 @@ from domainlab.tasks.utils_task import ImSize # change this to absolute directory where you have the raw images from PACS, -G_PACS_RAW_PATH = "domainlab/zdata/pacs/PACS" +G_PACS_RAW_PATH = "data/pacs/PACS" # domainlab repository contain already the file names in -# domainlab/zdata/pacs_split folder of domainlab +# domainlab/zdata/pacs_split folder of domainlab, +# but PACS dataset is too big to put into domainlab folder def get_task(na=None): diff --git a/tests/test_fbopt.py b/tests/test_fbopt.py index 1e2859291..c442bf090 100644 --- a/tests/test_fbopt.py +++ b/tests/test_fbopt.py @@ -27,6 +27,12 @@ def test_diva_fbopt(): args = "--te_d=caltech --task=mini_vlcs --debug --bs=2 --model=diva --gamma_y=1.0 --trainer=fbopt --nname=alexnet --epos=3" utils_test_algo(args) +def test_erm_fbopt(): + """ + erm + """ + args = "--te_d=caltech --task=mini_vlcs --debug --bs=2 --model=erm --trainer=fbopt --nname=alexnet --epos=3" # pylint: disable=line-too-long + utils_test_algo(args) def test_forcesetpoint_fbopt(): """