From e4d87d930e4034d50f7dea7414082bb63c4da99b Mon Sep 17 00:00:00 2001
From: matteowohlrapp
Date: Mon, 13 May 2024 17:18:48 +0200
Subject: [PATCH 01/22] added hyper init and hyper update functions and a new
 benchmark for fbopt, fishr and erm

---
 domainlab/models/model_erm.py                | 19 ++++++
 examples/benchmark/pacs_fbopt_fishr_erm.yaml | 68 ++++++++++++++++++++
 2 files changed, 87 insertions(+)
 create mode 100644 examples/benchmark/pacs_fbopt_fishr_erm.yaml

diff --git a/domainlab/models/model_erm.py b/domainlab/models/model_erm.py
index 6ee9c23f9..778001578 100644
--- a/domainlab/models/model_erm.py
+++ b/domainlab/models/model_erm.py
@@ -53,4 +53,23 @@ def convert4backpack(self):
             """
             self._net_invar_feat = extend(self._net_invar_feat, use_converter=True)
             self.net_classifier = extend(self.net_classifier, use_converter=True)
+
+        def hyper_update(self, epoch, fun_scheduler):
+            """hyper_update.
+
+            :param epoch:
+            :param fun_scheduler:
+            """
+            pass
+
+        def hyper_init(self, functor_scheduler, trainer=None):
+            """
+            initiate a scheduler object via class name and things inside this model
+
+            :param functor_scheduler: the class name of the scheduler
+            """
+            return functor_scheduler(
+                trainer=trainer
+            )
+
     return ModelERM
diff --git a/examples/benchmark/pacs_fbopt_fishr_erm.yaml b/examples/benchmark/pacs_fbopt_fishr_erm.yaml
new file mode 100644
index 000000000..6c26c516f
--- /dev/null
+++ b/examples/benchmark/pacs_fbopt_fishr_erm.yaml
@@ -0,0 +1,68 @@
+mode: grid
+
+output_dir: zoutput/benchmarks/benchmark_fbopt_fishr_erm_pacs
+
+sampling_seed: 0
+startseed: 0
+endseed: 2
+
+test_domains:
+  - sketch
+
+domainlab_args:
+  tpath: examples/tasks/task_pacs_aug.py
+  dmem: False
+  lr: 5e-5
+  epos: 500
+  epos_min: 200
+  es: 5
+  bs: 32
+  san_check: False
+  npath: examples/nets/resnet50domainbed.py
+  npath_dom: examples/nets/resnet50domainbed.py
+  zx_dim: 0
+  zy_dim: 64
+  zd_dim: 64
+
+
+
+Shared params:
+  ini_setpoint_ratio:
+    min: 0.5
+    max: 0.99
+    num: 2
+    step: 0.05
+    distribution: uniform
+
+  k_i_gain:
+    min: 0.0001
+    max: 0.01
+    num: 2
+    step: 0.0001
+    distribution: uniform
+
+  gamma_reg:
+    min: 0.01
+    max: 1e4
+    num: 3
+    distribution: loguniform
+
+
+# Test fbopt with different hyperparameter configurations
+
+fbopt_fishr_erm:
+  model: erm
+  trainer: fbopt_fishr
+  shared:
+    - ini_setpoint_ratio
+    - k_i_gain
+    - gamma_reg
+
+fishr_erm:
+  model: erm
+  trainer: fishr
+  shared:
+    - ini_setpoint_ratio
+    - k_i_gain
+    - gamma_reg
\ No newline at end of file
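[Illustration, not part of the patch series] How the two hooks added above are meant to be driven: hyper_init just instantiates a scheduler functor with the trainer, and hyper_update stays a no-op because plain ERM has no multipliers to anneal. A minimal sketch, assuming an invented DummyScheduler and any torch module for the net; only the two hook signatures come from the diff:

    # DummyScheduler is a hypothetical stand-in for a real domainlab scheduler functor
    class DummyScheduler:
        """records the trainer; a real scheduler would compute multiplier values"""
        def __init__(self, trainer=None):
            self.trainer = trainer

    model = mk_erm()(net=some_torch_module)  # some_torch_module: any nn.Module (assumed)
    scheduler = model.hyper_init(DummyScheduler, trainer=None)
    model.hyper_update(epoch=0, fun_scheduler=scheduler)  # no-op for ERM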
From 5c44550bab7ab0d90a719a1c5a4c3722632117ea Mon Sep 17 00:00:00 2001
From: matteowohlrapp
Date: Tue, 14 May 2024 08:42:42 +0200
Subject: [PATCH 02/22] Fixed import of backpack

---
 domainlab/models/a_model_classif.py | 10 +++++-----
 1 file changed, 5 insertions(+), 5 deletions(-)

diff --git a/domainlab/models/a_model_classif.py b/domainlab/models/a_model_classif.py
index 1917f752e..3dd9831d0 100644
--- a/domainlab/models/a_model_classif.py
+++ b/domainlab/models/a_model_classif.py
@@ -20,11 +20,11 @@
 
 try:
     from backpack import extend
-except:
-    backpack = None
-
-loss_cross_entropy_extended = extend(nn.CrossEntropyLoss(reduction="none"))
-
+    loss_cross_entropy_extended = extend(nn.CrossEntropyLoss(reduction="none"))
+except ImportError:
+    # Handle the case where backpack is not available
+    loss_cross_entropy_extended = nn.CrossEntropyLoss(reduction="none")
+    print("Backpack could not be imported. Using standard nn.CrossEntropyLoss.")
 
 class AModelClassif(AModel, metaclass=abc.ABCMeta):
     """

From 416389fe38bc6915fba1c97964a30ad3267572c5 Mon Sep 17 00:00:00 2001
From: matteowohlrapp
Date: Tue, 14 May 2024 08:55:56 +0200
Subject: [PATCH 03/22] Fixed benchmark import not successful in erm

---
 domainlab/models/model_erm.py | 20 +++++++++++++-------
 1 file changed, 13 insertions(+), 7 deletions(-)

diff --git a/domainlab/models/model_erm.py b/domainlab/models/model_erm.py
index 778001578..5b8bb5cfb 100644
--- a/domainlab/models/model_erm.py
+++ b/domainlab/models/model_erm.py
@@ -7,8 +7,10 @@
 
 try:
     from backpack import extend
-except:
-    backpack = None
+except ImportError as e:
+    print(f"Failed to import 'extend' from backpack: {e}")
+    extend = None  # Ensure extend is defined to avoid NameError later in the code
+
 
 
 def mk_erm(parent_class=AModelClassif, **kwargs):
@@ -47,12 +49,16 @@ def __init__(self, net=None, net_feat=None):
             super().__init__(**kwargs)
             self._net_invar_feat = net_feat
 
-        def convert4backpack(self):
-            """
-            convert the module to backpack for 2nd order gradients
-            """
+    def convert4backpack(self):
+        """
+        Convert the module to backpack for 2nd order gradients
+        """
+        if extend is not None:
             self._net_invar_feat = extend(self._net_invar_feat, use_converter=True)
-            self.net_classifier = extend(self.net_classifier, use_converter=True)
+            self.net_classifier = extend(self.net_classifier, use_converter=True)
+        else:
+            print("Backpack's extend function is not available.")
+
 
         def hyper_update(self, epoch, fun_scheduler):
             """hyper_update.

From d59d37ad1b634120553b9c721252fe67908f8da9 Mon Sep 17 00:00:00 2001
From: matteowohlrapp
Date: Tue, 14 May 2024 08:59:21 +0200
Subject: [PATCH 04/22] fixed indentation

---
 domainlab/models/model_erm.py | 18 +++++++++---------
 1 file changed, 9 insertions(+), 9 deletions(-)

diff --git a/domainlab/models/model_erm.py b/domainlab/models/model_erm.py
index 5b8bb5cfb..b4ec477a6 100644
--- a/domainlab/models/model_erm.py
+++ b/domainlab/models/model_erm.py
@@ -49,15 +49,15 @@ def __init__(self, net=None, net_feat=None):
             super().__init__(**kwargs)
             self._net_invar_feat = net_feat
 
-    def convert4backpack(self):
-        """
-        Convert the module to backpack for 2nd order gradients
-        """
-        if extend is not None:
-            self._net_invar_feat = extend(self._net_invar_feat, use_converter=True)
-            self.net_classifier = extend(self.net_classifier, use_converter=True)
-        else:
-            print("Backpack's extend function is not available.")
+        def convert4backpack(self):
+            """
+            Convert the module to backpack for 2nd order gradients
+            """
+            if extend is not None:
+                self._net_invar_feat = extend(self._net_invar_feat, use_converter=True)
+                self.net_classifier = extend(self.net_classifier, use_converter=True)
+            else:
+                print("Backpack's extend function is not available.")
 
         def hyper_update(self, epoch, fun_scheduler):
             """hyper_update.
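[Illustration, not part of the patch series] Patches 02-04 converge on the optional-dependency pattern for backpack, shown here in isolation. backpack is the real PyPI package used by the repo; the helper name maybe_extend is invented, and use_converter=True is taken from the diffs:

    try:
        from backpack import extend
    except ImportError as err:
        print(f"backpack unavailable, second-order features disabled: {err}")
        extend = None  # keep the name defined so call sites can test it

    def maybe_extend(module):
        """wrap a torch module for backpack only if the import succeeded"""
        if extend is None:
            return module
        return extend(module, use_converter=True)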
From a4fa31dd9d00697c1a8406a39d7b9996a52bc741 Mon Sep 17 00:00:00 2001
From: matteowohlrapp
Date: Tue, 14 May 2024 09:04:21 +0200
Subject: [PATCH 05/22] Added backpack check for fishr

---
 domainlab/algos/trainers/train_fishr.py | 7 ++++++-
 1 file changed, 6 insertions(+), 1 deletion(-)

diff --git a/domainlab/algos/trainers/train_fishr.py b/domainlab/algos/trainers/train_fishr.py
index 3580a0721..40a7bc50e 100644
--- a/domainlab/algos/trainers/train_fishr.py
+++ b/domainlab/algos/trainers/train_fishr.py
@@ -156,10 +156,15 @@ def cal_dict_variance_grads(self, tensor_x, vec_y):
         loss = self.model.cal_task_loss(tensor_x.clone(), vec_y)
         loss = loss.sum()
 
-        with backpack(Variance()):
+        if backpack is None:
             loss.backward(
                 inputs=list(self.model.parameters()), retain_graph=True, create_graph=True
             )
+        else:
+            with backpack(Variance()):
+                loss.backward(
+                    inputs=list(self.model.parameters()), retain_graph=True, create_graph=True
+                )
 
         for name, param in self.model.named_parameters():
             print(name)

From a38c777522e66aa81b218e397ccaa2d4d9070e3b Mon Sep 17 00:00:00 2001
From: matteowohlrapp
Date: Tue, 14 May 2024 09:24:39 +0200
Subject: [PATCH 06/22] Debugging backpack in erm

---
 domainlab/models/model_erm.py | 10 +++++++---
 1 file changed, 7 insertions(+), 3 deletions(-)

diff --git a/domainlab/models/model_erm.py b/domainlab/models/model_erm.py
index b4ec477a6..b4ac550b6 100644
--- a/domainlab/models/model_erm.py
+++ b/domainlab/models/model_erm.py
@@ -50,15 +50,19 @@ def __init__(self, net=None, net_feat=None):
             self._net_invar_feat = net_feat
 
         def convert4backpack(self):
-            """
-            Convert the module to backpack for 2nd order gradients
-            """
+            print("Extending model components...")
             if extend is not None:
+                if hasattr(self._net_invar_feat, 'parameters'):
+                    print("Net features before extend:", self._net_invar_feat)
                 self._net_invar_feat = extend(self._net_invar_feat, use_converter=True)
+
+                if hasattr(self.net_classifier, 'parameters'):
+                    print("Net classifier before extend:", self.net_classifier)
                 self.net_classifier = extend(self.net_classifier, use_converter=True)
             else:
                 print("Backpack's extend function is not available.")
+

From 870bc57a2064d99f8c5a88a95112450537535e7b Mon Sep 17 00:00:00 2001
From: matteowohlrapp
Date: Tue, 14 May 2024 10:16:55 +0200
Subject: [PATCH 07/22] Adding more logging to erm

---
 domainlab/models/model_erm.py | 11 ++++++++---
 1 file changed, 8 insertions(+), 3 deletions(-)

diff --git a/domainlab/models/model_erm.py b/domainlab/models/model_erm.py
index b4ac550b6..a253eaac1 100644
--- a/domainlab/models/model_erm.py
+++ b/domainlab/models/model_erm.py
@@ -4,6 +4,7 @@
 from domainlab.compos.nn_zoo.nn import LayerId
 from domainlab.models.a_model_classif import AModelClassif
 from domainlab.utils.override_interface import override_interface
+import traceback
 
 try:
     from backpack import extend
@@ -52,9 +53,13 @@ def __init__(self, net=None, net_feat=None):
         def convert4backpack(self):
             print("Extending model components...")
             if extend is not None:
-                if hasattr(self._net_invar_feat, 'parameters'):
-                    print("Net features before extend:", self._net_invar_feat)
-                self._net_invar_feat = extend(self._net_invar_feat, use_converter=True)
+                try:
+                    if hasattr(self._net_invar_feat, 'parameters'):
+                        print("Net features before extend:", self._net_invar_feat)
+                    self._net_invar_feat = extend(self._net_invar_feat, use_converter=True)
+                except Exception as e:
+                    print("An error occurred:", e)
+                    traceback.print_exc()
 
                 if hasattr(self.net_classifier, 'parameters'):
                     print("Net classifier before extend:", self.net_classifier)
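[Illustration, not part of the patch series] The backpack machinery that the fishr trainer and the debugging in patches 05-07 revolve around, reduced to a self-contained toy example. The extend/backpack/Variance calls are the library's actual API as used in train_fishr.py; the tiny net and random data are made up:

    import torch
    from torch import nn
    from backpack import backpack, extend
    from backpack.extensions import Variance

    net = extend(nn.Linear(4, 3))
    loss_fn = extend(nn.CrossEntropyLoss(reduction="none"))
    x, y = torch.randn(8, 4), torch.randint(0, 3, (8,))
    loss = loss_fn(net(x), y).sum()
    with backpack(Variance()):
        loss.backward()
    for name, param in net.named_parameters():
        # .variance holds the within-batch variance of the per-sample gradients
        print(name, param.variance.shape)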
From f7795a927c082e3a10ba8cb6aa79998298fe86fc Mon Sep 17 00:00:00 2001
From: matteowohlrapp
Date: Tue, 14 May 2024 11:49:51 +0200
Subject: [PATCH 08/22] added list_str_multiplier

---
 domainlab/models/model_erm.py | 7 +++++++
 1 file changed, 7 insertions(+)

diff --git a/domainlab/models/model_erm.py b/domainlab/models/model_erm.py
index a253eaac1..d6ec3b35c 100644
--- a/domainlab/models/model_erm.py
+++ b/domainlab/models/model_erm.py
@@ -86,5 +86,12 @@ def hyper_init(self, functor_scheduler, trainer=None):
             return functor_scheduler(
                 trainer=trainer
             )
+
+        @property
+        def list_str_multiplier_na(self):
+            """
+            list of multipliers which match the order in cal_reg_loss
+            """
+            return []
 
     return ModelERM

From 6c5ca27e91f7ca27db97aaed2802952c191789a3 Mon Sep 17 00:00:00 2001
From: Matteo Wohlrapp
Date: Tue, 14 May 2024 16:03:08 +0200
Subject: [PATCH 09/22] Added directories to gitignore, adjusted
 fbopt_fishr_erm benchmark

---
 .gitignore                                   | 10 ++++++++--
 examples/benchmark/pacs_fbopt_fishr_erm.yaml |  8 +++-----
 2 files changed, 11 insertions(+), 7 deletions(-)

diff --git a/.gitignore b/.gitignore
index 2e1aa7a74..a351b31fe 100644
--- a/.gitignore
+++ b/.gitignore
@@ -1,7 +1,13 @@
 .ropeproject
-./zdpath
-./zoutput
+/zdpath
+/zoutput
 tests/__pycache__/
 *.pyc
 .vscode/
 domainlab/zdata/pacs
+/data/
+/.snakemake/
+/dist
+/domainlab.egg-info
+/runs
+/slurm_errors.txt
\ No newline at end of file
diff --git a/examples/benchmark/pacs_fbopt_fishr_erm.yaml b/examples/benchmark/pacs_fbopt_fishr_erm.yaml
index 6c26c516f..721833e7b 100644
--- a/examples/benchmark/pacs_fbopt_fishr_erm.yaml
+++ b/examples/benchmark/pacs_fbopt_fishr_erm.yaml
@@ -18,8 +18,8 @@ domainlab_args:
   es: 5
   bs: 32
   san_check: False
-  npath: examples/nets/resnet50domainbed.py
-  npath_dom: examples/nets/resnet50domainbed.py
+  nname: alexnet
+  nname_dom: alexnet
   zx_dim: 0
   zy_dim: 64
   zd_dim: 64
@@ -63,6 +63,4 @@ fbopt_fishr_erm:
 fishr_erm:
   model: erm
   trainer: fishr
   shared:
-    - ini_setpoint_ratio
-    - k_i_gain
-    - gamma_reg
\ No newline at end of file
+    - gamma_reg

From 38fb5f20ac72a9b213ea42f80e2d5d713ac49c7b Mon Sep 17 00:00:00 2001
From: Matteo Wohlrapp
Date: Wed, 22 May 2024 09:39:53 +0200
Subject: [PATCH 10/22] Solved indexing issue in fbopt mu controller and added
 flag info to train epoch. Reverted prior backpack changes

---
 .../algos/trainers/fbopt_mu_controller.py    | 49 +++++++++++--------
 domainlab/algos/trainers/train_fishr.py      | 17 ++-----
 .../algos/trainers/train_hyper_scheduler.py  |  2 +-
 domainlab/algos/trainers/train_irm.py        |  4 ++--
 domainlab/algos/trainers/train_matchdg.py    |  2 +-
 domainlab/algos/trainers/train_mldg.py       |  4 ++--
 domainlab/models/a_model_classif.py          |  9 ++++----
 domainlab/models/model_erm.py                | 30 +++++--------------
 8 files changed, 51 insertions(+), 66 deletions(-)

diff --git a/domainlab/algos/trainers/fbopt_mu_controller.py b/domainlab/algos/trainers/fbopt_mu_controller.py
index 9f8e02971..793395f03 100644
--- a/domainlab/algos/trainers/fbopt_mu_controller.py
+++ b/domainlab/algos/trainers/fbopt_mu_controller.py
@@ -206,27 +206,36 @@ def search_mu(
             self.writer.add_scalar(f"controller_gain/{key}", dict_gain[key], miter)
             ind = list_str_multiplier_na.index(key)
             self.writer.add_scalar(f"delta/{key}", self.delta_epsilon_r[ind], miter)
-        for i, (reg_dyn, reg_set) in enumerate(
-            zip(epo_reg_loss, self.get_setpoint4r())
-        ):
-            self.writer.add_scalar(
-                f"lossrd/dyn_{list_str_multiplier_na[i]}", reg_dyn, miter
-            )
-            self.writer.add_scalar(
-                f"lossrs/setpoint_{list_str_multiplier_na[i]}", reg_set, miter
-            )
+
+        print(f"Epo Reg Loss: {epo_reg_loss}")
+        print(f"Setpoint: {self.get_setpoint4r()}")
+        print(f"List: {list_str_multiplier_na}")
+
+        if list_str_multiplier_na:
+            for i, (reg_dyn, reg_set) in enumerate(
+                zip(epo_reg_loss, self.get_setpoint4r())
+            ):
+
+                self.writer.add_scalar(
+                    f"lossrd/dyn_{list_str_multiplier_na[i]}", reg_dyn, miter
+                )
+                self.writer.add_scalar(
+                    f"lossrs/setpoint_{list_str_multiplier_na[i]}", reg_set, miter
+                )
 
-            self.writer.add_scalars(
-                f"loss_rds/loss_{list_str_multiplier_na[i]}_w_setpoint",
-                {
-                    f"lossr/loss_{list_str_multiplier_na[i]}": reg_dyn,
-                    f"lossr/setpoint_{list_str_multiplier_na[i]}": reg_set,
-                },
-                miter,
-            )
-            self.writer.add_scalar(
-                f"x_ell_y_r/loss_{list_str_multiplier_na[i]}", reg_dyn, epo_task_loss
-            )
+
+                self.writer.add_scalars(
+                    f"loss_rds/loss_{list_str_multiplier_na[i]}_w_setpoint",
+                    {
+                        f"lossr/loss_{list_str_multiplier_na[i]}": reg_dyn,
+                        f"lossr/setpoint_{list_str_multiplier_na[i]}": reg_set,
+                    },
+                    miter,
+                )
+                self.writer.add_scalar(
+                    f"x_ell_y_r/loss_{list_str_multiplier_na[i]}", reg_dyn, epo_task_loss
+                )
+        else:
+            logger.info("No multiplier provided")
         self.writer.add_scalar("loss_task/penalized", epo_loss_tr, miter)
         self.writer.add_scalar("loss_task/ell", epo_task_loss, miter)
         acc_te = 0
diff --git a/domainlab/algos/trainers/train_fishr.py b/domainlab/algos/trainers/train_fishr.py
index 40a7bc50e..bd67aa001 100644
--- a/domainlab/algos/trainers/train_fishr.py
+++ b/domainlab/algos/trainers/train_fishr.py
@@ -26,7 +26,7 @@ class TrainerFishr(TrainerBasic):
     "Fishr: Invariant gradient variances for out-of-distribution generalization."
     International Conference on Machine Learning. PMLR, 2022.
     """
-    def tr_epoch(self, epoch):
+    def tr_epoch(self, epoch, flag_info=False):
         list_loaders = list(self.dict_loader_tr.values())
         loaders_zip = zip(*list_loaders)
         self.model.train()
@@ -46,7 +46,7 @@ def tr_epoch(self, epoch):
             self.epo_loss_tr += loss.detach().item()
             self.after_batch(epoch, ind_batch)
 
-        flag_stop = self.observer.update(epoch)  # notify observer
+        flag_stop = self.observer.update(epoch, flag_info)  # notify observer
         return flag_stop
 
     def var_grads_and_loss(self, tuple_data_domains_batch):
@@ -156,20 +156,11 @@ def cal_dict_variance_grads(self, tensor_x, vec_y):
         loss = self.model.cal_task_loss(tensor_x.clone(), vec_y)
         loss = loss.sum()
 
-        if backpack is None:
+        with backpack(Variance()):
             loss.backward(
                 inputs=list(self.model.parameters()), retain_graph=True, create_graph=True
             )
-        else:
-            with backpack(Variance()):
-                loss.backward(
-                    inputs=list(self.model.parameters()), retain_graph=True, create_graph=True
-                )
-
-        for name, param in self.model.named_parameters():
-            print(name)
-            print(".grad.shape: ", param.variance.shape)
-
+
         dict_variance = OrderedDict(
             [(name, weights.variance.clone())
              for name, weights in self.model.named_parameters()
diff --git a/domainlab/algos/trainers/train_hyper_scheduler.py b/domainlab/algos/trainers/train_hyper_scheduler.py
index 2e60bf5e8..9eda2e494 100644
--- a/domainlab/algos/trainers/train_hyper_scheduler.py
+++ b/domainlab/algos/trainers/train_hyper_scheduler.py
@@ -55,7 +55,7 @@ def before_tr(self):
             flag_update_epoch=True,
         )
 
-    def tr_epoch(self, epoch):
+    def tr_epoch(self, epoch, flag_info=False):
         """
         update hyper-parameters only per epoch
         """
diff --git a/domainlab/algos/trainers/train_irm.py b/domainlab/algos/trainers/train_irm.py
index 0da2d0bde..9e91f1edc 100644
--- a/domainlab/algos/trainers/train_irm.py
+++ b/domainlab/algos/trainers/train_irm.py
@@ -19,7 +19,7 @@ class TrainerIRM(TrainerBasic):
     For more details, see section 3.2 and Appendix D of :
     Arjovsky et al., “Invariant Risk Minimization.”
     """
-    def tr_epoch(self, epoch):
+    def tr_epoch(self, epoch, flag_info=False):
         list_loaders = list(self.dict_loader_tr.values())
         loaders_zip = zip(*list_loaders)
         self.model.train()
@@ -46,7 +46,7 @@ def tr_epoch(self, epoch):
             self.epo_loss_tr += loss.detach().item()
             self.after_batch(epoch, ind_batch)
 
-        flag_stop = self.observer.update(epoch)  # notify observer
+        flag_stop = self.observer.update(epoch, flag_info)  # notify observer
         return flag_stop
 
     def _cal_phi(self, tensor_x):
diff --git a/domainlab/algos/trainers/train_matchdg.py b/domainlab/algos/trainers/train_matchdg.py
index e127baa12..ef4b1c862 100644
--- a/domainlab/algos/trainers/train_matchdg.py
+++ b/domainlab/algos/trainers/train_matchdg.py
@@ -95,7 +95,7 @@ def tr_epoch(self, epoch, flag_info=False):
             logger.info("\n\nPhase erm+ctr \n\n")
             self.flag_erm = True
 
-        flag_stop = self.observer.update(epoch)  # notify observer
+        flag_stop = self.observer.update(epoch, flag_info)  # notify observer
         return flag_stop
 
     def tr_batch(self, epoch, batch_idx, x_e, y_e, d_e, others=None):
diff --git a/domainlab/algos/trainers/train_mldg.py b/domainlab/algos/trainers/train_mldg.py
index e91310adf..9e13ff58b 100644
--- a/domainlab/algos/trainers/train_mldg.py
+++ b/domainlab/algos/trainers/train_mldg.py
@@ -51,7 +51,7 @@ def prepare_ziped_loader(self):
         ddset_mix = DsetZip(ddset_source, ddset_target)
         self.loader_tr_source_target = mk_loader(ddset_mix, self.aconf.bs)
 
-    def tr_epoch(self, epoch):
+    def tr_epoch(self, epoch, flag_info=False):
         self.model.train()
         self.epo_loss_tr = 0
         self.prepare_ziped_loader()
@@ -117,5 +117,5 @@ def tr_epoch(self, epoch):
             self.optimizer.step()
             self.epo_loss_tr += loss.detach().item()
             self.after_batch(epoch, ind_batch)
-        flag_stop = self.observer.update(epoch)  # notify observer
+        flag_stop = self.observer.update(epoch, flag_info)  # notify observer
         return flag_stop
diff --git a/domainlab/models/a_model_classif.py b/domainlab/models/a_model_classif.py
index 3dd9831d0..1f72eec0a 100644
--- a/domainlab/models/a_model_classif.py
+++ b/domainlab/models/a_model_classif.py
@@ -20,11 +20,10 @@
 
 try:
     from backpack import extend
-    loss_cross_entropy_extended = extend(nn.CrossEntropyLoss(reduction="none"))
-except ImportError:
-    # Handle the case where backpack is not available
-    loss_cross_entropy_extended = nn.CrossEntropyLoss(reduction="none")
-    print("Backpack could not be imported. Using standard nn.CrossEntropyLoss.")
+except:
+    backpack = None
+
+loss_cross_entropy_extended = extend(nn.CrossEntropyLoss(reduction="none"))
 
 class AModelClassif(AModel, metaclass=abc.ABCMeta):
     """
diff --git a/domainlab/models/model_erm.py b/domainlab/models/model_erm.py
index d6ec3b35c..a47cb9844 100644
--- a/domainlab/models/model_erm.py
+++ b/domainlab/models/model_erm.py
@@ -4,13 +4,11 @@
 from domainlab.compos.nn_zoo.nn import LayerId
 from domainlab.models.a_model_classif import AModelClassif
 from domainlab.utils.override_interface import override_interface
-import traceback
 
 try:
     from backpack import extend
-except ImportError as e:
-    print(f"Failed to import 'extend' from backpack: {e}")
-    extend = None  # Ensure extend is defined to avoid NameError later in the code
-
+except:
+    backpack = None
 
 
 def mk_erm(parent_class=AModelClassif, **kwargs):
@@ -51,24 +49,12 @@ def __init__(self, net=None, net_feat=None):
         self._net_invar_feat = net_feat
 
         def convert4backpack(self):
-            print("Extending model components...")
-            if extend is not None:
-                try:
-                    if hasattr(self._net_invar_feat, 'parameters'):
-                        print("Net features before extend:", self._net_invar_feat)
-                    self._net_invar_feat = extend(self._net_invar_feat, use_converter=True)
-                except Exception as e:
-                    print("An error occurred:", e)
-                    traceback.print_exc()
-
-                if hasattr(self.net_classifier, 'parameters'):
-                    print("Net classifier before extend:", self.net_classifier)
-                self.net_classifier = extend(self.net_classifier, use_converter=True)
-            else:
-                print("Backpack's extend function is not available.")
-
-
+        """
+        convert the module to backpack for 2nd order gradients
+        """
+        self._net_invar_feat = extend(self._net_invar_feat, use_converter=True)
+        self.net_classifier = extend(self.net_classifier, use_converter=True)
+
         def hyper_update(self, epoch, fun_scheduler):
             """hyper_update.
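[Illustration, not part of the patch series] The indexing fix above boils down to guarding the per-multiplier logging loop so that models without multipliers (plain ERM returns an empty list_str_multiplier_na) skip it instead of indexing into an empty list. The function name below is invented; writer follows the torch.utils.tensorboard SummaryWriter API used in the controller:

    def log_multiplier_losses(writer, names, reg_losses, setpoints, step):
        """sketch of the guarded loop; a real call site passes the epoch as step"""
        if not names:  # e.g. ERM: no multipliers to log
            print("No multiplier provided")
            return
        for name, reg_dyn, reg_set in zip(names, reg_losses, setpoints):
            writer.add_scalar(f"lossrd/dyn_{name}", reg_dyn, step)
            writer.add_scalar(f"lossrs/setpoint_{name}", reg_set, step)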
From caf32b7cb674d7d54268ab08f15a84cff9caf644 Mon Sep 17 00:00:00 2001
From: Matteo Wohlrapp
Date: Wed, 22 May 2024 09:42:56 +0200
Subject: [PATCH 11/22] removed prints

---
 domainlab/algos/trainers/fbopt_mu_controller.py | 4 ----
 1 file changed, 4 deletions(-)

diff --git a/domainlab/algos/trainers/fbopt_mu_controller.py b/domainlab/algos/trainers/fbopt_mu_controller.py
index 793395f03..25432bb45 100644
--- a/domainlab/algos/trainers/fbopt_mu_controller.py
+++ b/domainlab/algos/trainers/fbopt_mu_controller.py
@@ -206,10 +206,6 @@ def search_mu(
             self.writer.add_scalar(f"controller_gain/{key}", dict_gain[key], miter)
             ind = list_str_multiplier_na.index(key)
             self.writer.add_scalar(f"delta/{key}", self.delta_epsilon_r[ind], miter)
-
-        print(f"Epo Reg Loss: {epo_reg_loss}")
-        print(f"Setpoint: {self.get_setpoint4r()}")
-        print(f"List: {list_str_multiplier_na}")
 
         if list_str_multiplier_na:
             for i, (reg_dyn, reg_set) in enumerate(

From a098a681bbd72c13fffe5f0da1aa18858b7d4b58 Mon Sep 17 00:00:00 2001
From: Matteo Wohlrapp
Date: Wed, 22 May 2024 09:50:28 +0200
Subject: [PATCH 12/22] Fixed indentation for convert4backpack

---
 domainlab/models/model_erm.py | 12 +++++-------
 1 file changed, 5 insertions(+), 7 deletions(-)

diff --git a/domainlab/models/model_erm.py b/domainlab/models/model_erm.py
index a47cb9844..48330f6fa 100644
--- a/domainlab/models/model_erm.py
+++ b/domainlab/models/model_erm.py
@@ -10,8 +10,6 @@
 except:
     backpack = None
 
-
-
 def mk_erm(parent_class=AModelClassif, **kwargs):
     """
     Instantiate a Deepall (ERM) model
@@ -49,11 +47,11 @@ def __init__(self, net=None, net_feat=None):
             super().__init__(**kwargs)
             self._net_invar_feat = net_feat
 
         def convert4backpack(self):
-        """
-        convert the module to backpack for 2nd order gradients
-        """
-        self._net_invar_feat = extend(self._net_invar_feat, use_converter=True)
-        self.net_classifier = extend(self.net_classifier, use_converter=True)
+            """
+            convert the module to backpack for 2nd order gradients
+            """
+            self._net_invar_feat = extend(self._net_invar_feat, use_converter=True)
+            self.net_classifier = extend(self.net_classifier, use_converter=True)
 
         def hyper_update(self, epoch, fun_scheduler):
             """hyper_update.
From 518d919bb3459b80fc492879219b8bd8188a3b81 Mon Sep 17 00:00:00 2001
From: Matteo Wohlrapp
Date: Tue, 28 May 2024 11:39:18 +0200
Subject: [PATCH 13/22] fixed codacity

---
 domainlab/algos/trainers/fbopt_mu_controller.py | 2 +-
 domainlab/algos/trainers/train_fishr.py         | 2 +-
 domainlab/models/model_erm.py                   | 4 ++--
 3 files changed, 4 insertions(+), 4 deletions(-)

diff --git a/domainlab/algos/trainers/fbopt_mu_controller.py b/domainlab/algos/trainers/fbopt_mu_controller.py
index 25432bb45..824638461 100644
--- a/domainlab/algos/trainers/fbopt_mu_controller.py
+++ b/domainlab/algos/trainers/fbopt_mu_controller.py
@@ -230,7 +230,7 @@ def search_mu(
                 self.writer.add_scalar(
                     f"x_ell_y_r/loss_{list_str_multiplier_na[i]}", reg_dyn, epo_task_loss
                 )
-        else: 
+        else:
             logger.info("No multiplier provided")
         self.writer.add_scalar("loss_task/penalized", epo_loss_tr, miter)
         self.writer.add_scalar("loss_task/ell", epo_task_loss, miter)
diff --git a/domainlab/algos/trainers/train_fishr.py b/domainlab/algos/trainers/train_fishr.py
index bd67aa001..e8d0c71b8 100644
--- a/domainlab/algos/trainers/train_fishr.py
+++ b/domainlab/algos/trainers/train_fishr.py
@@ -160,7 +160,7 @@ def cal_dict_variance_grads(self, tensor_x, vec_y):
             loss.backward(
                 inputs=list(self.model.parameters()), retain_graph=True, create_graph=True
             )
-        
+
         dict_variance = OrderedDict(
             [(name, weights.variance.clone())
              for name, weights in self.model.named_parameters()
diff --git a/domainlab/models/model_erm.py b/domainlab/models/model_erm.py
index 48330f6fa..f863da655 100644
--- a/domainlab/models/model_erm.py
+++ b/domainlab/models/model_erm.py
@@ -59,7 +59,7 @@ def hyper_update(self, epoch, fun_scheduler):
             :param epoch:
             :param fun_scheduler:
             """
-            pass
+            ...
 
         def hyper_init(self, functor_scheduler, trainer=None):
             """
@@ -70,7 +70,7 @@ def hyper_init(self, functor_scheduler, trainer=None):
             return functor_scheduler(
                 trainer=trainer
             )
-        
+
         @property
         def list_str_multiplier_na(self):
             """

From e947f67176c1be3ea9cae338c6f0c2c1531f7228 Mon Sep 17 00:00:00 2001
From: Matteo Wohlrapp
Date: Tue, 28 May 2024 12:15:52 +0200
Subject: [PATCH 14/22] fixed codacity

---
 domainlab/algos/trainers/train_fishr.py | 2 +-
 domainlab/models/model_erm.py           | 7 ++++---
 2 files changed, 5 insertions(+), 4 deletions(-)

diff --git a/domainlab/algos/trainers/train_fishr.py b/domainlab/algos/trainers/train_fishr.py
index e8d0c71b8..250b4109d 100644
--- a/domainlab/algos/trainers/train_fishr.py
+++ b/domainlab/algos/trainers/train_fishr.py
@@ -160,7 +160,7 @@ def cal_dict_variance_grads(self, tensor_x, vec_y):
             loss.backward(
                 inputs=list(self.model.parameters()), retain_graph=True, create_graph=True
             )
-        
+
         dict_variance = OrderedDict(
             [(name, weights.variance.clone())
              for name, weights in self.model.named_parameters()
diff --git a/domainlab/models/model_erm.py b/domainlab/models/model_erm.py
index f863da655..5e491a5b7 100644
--- a/domainlab/models/model_erm.py
+++ b/domainlab/models/model_erm.py
@@ -53,13 +53,14 @@ def convert4backpack(self):
             self._net_invar_feat = extend(self._net_invar_feat, use_converter=True)
             self.net_classifier = extend(self.net_classifier, use_converter=True)
 
-        def hyper_update(self, epoch, fun_scheduler):
+        def hyper_update(self, _epoch, _fun_scheduler):
             """hyper_update.
 
             :param epoch:
             :param fun_scheduler:
             """
-            ...
+            """Method necessary to combine with hyperparameter scheduler"""
+            return
 
         def hyper_init(self, functor_scheduler, trainer=None):
             """

From 4f1347a9c850f70adcb9276e6aad59530f0440de Mon Sep 17 00:00:00 2001
From: Matteo Wohlrapp
Date: Tue, 11 Jun 2024 15:38:50 +0200
Subject: [PATCH 15/22] fixed codacity

---
 domainlab/models/model_erm.py | 5 ++---
 1 file changed, 2 insertions(+), 3 deletions(-)

diff --git a/domainlab/models/model_erm.py b/domainlab/models/model_erm.py
index 5e491a5b7..bc4f6e45f 100644
--- a/domainlab/models/model_erm.py
+++ b/domainlab/models/model_erm.py
@@ -54,13 +54,12 @@ def convert4backpack(self):
             self.net_classifier = extend(self.net_classifier, use_converter=True)
 
         def hyper_update(self, _epoch, _fun_scheduler):
-            """hyper_update.
+            """
+            Method necessary to combine with hyperparameter scheduler
 
             :param epoch:
             :param fun_scheduler:
             """
-            """Method necessary to combine with hyperparameter scheduler"""
-            return
 
         def hyper_init(self, functor_scheduler, trainer=None):

From f2fa952ba2f020adeac830abddd4774e3a5d9054 Mon Sep 17 00:00:00 2001
From: Matteo Wohlrapp
Date: Tue, 11 Jun 2024 16:07:53 +0200
Subject: [PATCH 16/22] fixed codacity

---
 domainlab/models/model_erm.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/domainlab/models/model_erm.py b/domainlab/models/model_erm.py
index bc4f6e45f..4dc4c03f3 100644
--- a/domainlab/models/model_erm.py
+++ b/domainlab/models/model_erm.py
@@ -53,7 +53,7 @@ def convert4backpack(self):
             self._net_invar_feat = extend(self._net_invar_feat, use_converter=True)
             self.net_classifier = extend(self.net_classifier, use_converter=True)
 
-        def hyper_update(self, _epoch, _fun_scheduler):
+        def hyper_update(self, epoch, fun_scheduler):  # noqa: F841  # pylint: disable=unused-argument
             """
             Method necessary to combine with hyperparameter scheduler

From 279f5102ea62d508eb9fe5b4a11ddf714c8170e2 Mon Sep 17 00:00:00 2001
From: Matteo Wohlrapp
Date: Tue, 11 Jun 2024 16:16:05 +0200
Subject: [PATCH 17/22] Added test for erm functions

---
 tests/test_fbopt.py | 6 ++++++
 1 file changed, 6 insertions(+)

diff --git a/tests/test_fbopt.py b/tests/test_fbopt.py
index 1e2859291..cbf9b8d98 100644
--- a/tests/test_fbopt.py
+++ b/tests/test_fbopt.py
@@ -27,6 +27,12 @@ def test_diva_fbopt():
     args = "--te_d=caltech --task=mini_vlcs --debug --bs=2 --model=diva --gamma_y=1.0 --trainer=fbopt --nname=alexnet --epos=3"
     utils_test_algo(args)
 
+def test_erm_fbopt():
+    """
+    erm
+    """
+    args = "--te_d=caltech --task=mini_vlcs --debug --bs=2 --model=erm --trainer=fbopt --nname=alexnet --epos=3"
+    utils_test_algo(args)
 
 def test_forcesetpoint_fbopt():
     """

From f8f4f0edaecabe37349ae52eab5881b330e8d30d Mon Sep 17 00:00:00 2001
From: Matteo Wohlrapp
Date: Tue, 11 Jun 2024 16:31:30 +0200
Subject: [PATCH 18/22] Disabling line too long for argument

---
 domainlab/models/model_erm.py | 2 +-
 tests/test_fbopt.py           | 2 +-
 2 files changed, 2 insertions(+), 2 deletions(-)

diff --git a/domainlab/models/model_erm.py b/domainlab/models/model_erm.py
index 4dc4c03f3..4ccec7a50 100644
--- a/domainlab/models/model_erm.py
+++ b/domainlab/models/model_erm.py
@@ -53,7 +53,7 @@ def convert4backpack(self):
             self._net_invar_feat = extend(self._net_invar_feat, use_converter=True)
             self.net_classifier = extend(self.net_classifier, use_converter=True)
 
-        def hyper_update(self, epoch, fun_scheduler):  # noqa: F841  # pylint: disable=unused-argument
+        def hyper_update(self, epoch, fun_scheduler):  # pylint: disable=unused-argument
             """
             Method necessary to combine with hyperparameter scheduler
 
diff --git a/tests/test_fbopt.py b/tests/test_fbopt.py
index cbf9b8d98..c442bf090 100644
--- a/tests/test_fbopt.py
+++ b/tests/test_fbopt.py
@@ -31,7 +31,7 @@ def test_erm_fbopt():
     """
     erm
     """
-    args = "--te_d=caltech --task=mini_vlcs --debug --bs=2 --model=erm --trainer=fbopt --nname=alexnet --epos=3"
+    args = "--te_d=caltech --task=mini_vlcs --debug --bs=2 --model=erm --trainer=fbopt --nname=alexnet --epos=3"  # pylint: disable=line-too-long
     utils_test_algo(args)
 
 def test_forcesetpoint_fbopt():

From d1ccf46511628a8a440328851acbd6a2a0bb50ec Mon Sep 17 00:00:00 2001
From: Xudong Sun
Date: Tue, 2 Jul 2024 12:37:58 +0200
Subject: [PATCH 19/22] Update task_pacs_aug.py, update path of PACS

---
 examples/tasks/task_pacs_aug.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/examples/tasks/task_pacs_aug.py b/examples/tasks/task_pacs_aug.py
index e971bea8c..f57b03fc6 100644
--- a/examples/tasks/task_pacs_aug.py
+++ b/examples/tasks/task_pacs_aug.py
@@ -11,9 +11,9 @@
 from domainlab.tasks.utils_task import ImSize
 
 # change this to absolute directory where you have the raw images from PACS,
-G_PACS_RAW_PATH = "domainlab/zdata/pacs/PACS"
+G_PACS_RAW_PATH = "data/pacs/PACS"
 # domainlab repository contain already the file names in
-# domainlab/zdata/pacs_split folder of domainlab
+# domainlab/zdata/pacs_split folder of domainlab, but PACS dataset is too big to put into domainlab folder
 
 
 def get_task(na=None):

From 6acdfdd23429f581f09a70182c94287ab1bfd00b Mon Sep 17 00:00:00 2001
From: Xudong Sun
Date: Tue, 2 Jul 2024 12:39:18 +0200
Subject: [PATCH 20/22] Update pacs_fbopt_fishr_erm.yaml

---
 examples/benchmark/pacs_fbopt_fishr_erm.yaml | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/examples/benchmark/pacs_fbopt_fishr_erm.yaml b/examples/benchmark/pacs_fbopt_fishr_erm.yaml
index 721833e7b..f3c10eb2b 100644
--- a/examples/benchmark/pacs_fbopt_fishr_erm.yaml
+++ b/examples/benchmark/pacs_fbopt_fishr_erm.yaml
@@ -4,7 +4,7 @@ output_dir: zoutput/benchmarks/benchmark_fbopt_fishr_erm_pacs
 
 sampling_seed: 0
 startseed: 0
-endseed: 2
+endseed: 0
 
 test_domains:
   - sketch

From d60a789d1257b8449c34ae8bce0ef978c2e9e82d Mon Sep 17 00:00:00 2001
From: Xudong Sun
Date: Tue, 2 Jul 2024 12:40:20 +0200
Subject: [PATCH 21/22] Update pacs_fbopt_fishr_erm.yaml

---
 examples/benchmark/pacs_fbopt_fishr_erm.yaml | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/examples/benchmark/pacs_fbopt_fishr_erm.yaml b/examples/benchmark/pacs_fbopt_fishr_erm.yaml
index f3c10eb2b..781a2518e 100644
--- a/examples/benchmark/pacs_fbopt_fishr_erm.yaml
+++ b/examples/benchmark/pacs_fbopt_fishr_erm.yaml
@@ -13,8 +13,8 @@ domainlab_args:
   tpath: examples/tasks/task_pacs_aug.py
   dmem: False
   lr: 5e-5
-  epos: 500
-  epos_min: 200
+  epos: 10
+  epos_min: 2
   es: 5
   bs: 32
   san_check: False

From 976c25a5571d7ae9ba3dac9619bba96b186ffc1d Mon Sep 17 00:00:00 2001
From: Xudong Sun
Date: Tue, 2 Jul 2024 12:59:08 +0200
Subject: [PATCH 22/22] Update task_pacs_aug.py, fix codacy

---
 examples/tasks/task_pacs_aug.py | 3 ++-
 1 file changed, 2 insertions(+), 1 deletion(-)

diff --git a/examples/tasks/task_pacs_aug.py b/examples/tasks/task_pacs_aug.py
index f57b03fc6..0d334a45a 100644
--- a/examples/tasks/task_pacs_aug.py
+++ b/examples/tasks/task_pacs_aug.py
@@ -13,7 +13,8 @@
 # change this to absolute directory where you have the raw images from PACS,
 G_PACS_RAW_PATH = "data/pacs/PACS"
 # domainlab repository contain already the file names in
-# domainlab/zdata/pacs_split folder of domainlab, but PACS dataset is too big to put into domainlab folder
+# domainlab/zdata/pacs_split folder of domainlab,
+# but PACS dataset is too big to put into domainlab folder
 
 
 def get_task(na=None):
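[Illustration, not part of the patch series] Patches 19 and 22 leave G_PACS_RAW_PATH pointing outside the repository, so a task file could fail early with a clear message when the images are missing. This guard is an invented addition, not code from the series:

    import os

    G_PACS_RAW_PATH = "data/pacs/PACS"  # value from patch 19
    if not os.path.isdir(G_PACS_RAW_PATH):
        raise FileNotFoundError(
            f"PACS raw images not found under '{G_PACS_RAW_PATH}'; "
            "download PACS and point G_PACS_RAW_PATH at it in task_pacs_aug.py"
        )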