Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
26 commits
Select commit Hold shift + click to select a range
e4d87d9
added hyper init and hyper update function and a new benchmark for fb…
MatteoWohlrapp May 13, 2024
5c44550
Fixed import of backpack
MatteoWohlrapp May 14, 2024
416389f
Fixed benchmark import not successful in erm
MatteoWohlrapp May 14, 2024
d59d37a
fixed indentation
MatteoWohlrapp May 14, 2024
a4fa31d
Added backpack check for fishr
MatteoWohlrapp May 14, 2024
a38c777
Debugging backpack in erm
MatteoWohlrapp May 14, 2024
870bc57
Adding more logging to erm
MatteoWohlrapp May 14, 2024
f7795a9
added list_str_multiplier
MatteoWohlrapp May 14, 2024
6c5ca27
Added directories to gitignore, adjusted fobt_fishr_erm benchmark
May 14, 2024
38fb5f2
Solved indexing issue in fbopt mu controller and added flag info to t…
May 22, 2024
caf32b7
removed prints
May 22, 2024
a098a68
Fixed indentation for convert4backpack
May 22, 2024
518d919
fixed codacity
May 28, 2024
e947f67
fixed codacity
May 28, 2024
e6a6f3e
Merge branch 'mhof_dev_merge' into erm_hyper_init
May 28, 2024
5288432
Merge branch 'mhof_dev_merge' into erm_hyper_init
MatteoWohlrapp Jun 11, 2024
4f1347a
fixed codacity
Jun 11, 2024
6a81b4c
Merge branch 'erm_hyper_init' of https://github.com/marrlab/DomainLab…
Jun 11, 2024
f2fa952
fixed codacity
Jun 11, 2024
279f510
Added test for erm functions
Jun 11, 2024
f8f4f0e
Disabling line too long for argument
Jun 11, 2024
c7caeb5
Merge branch 'mhof_dev_merge' into erm_hyper_init
MatteoWohlrapp Jul 2, 2024
d1ccf46
Update task_pacs_aug.py, update path of PACS
smilesun Jul 2, 2024
6acdfdd
Update pacs_fbopt_fishr_erm.yaml
smilesun Jul 2, 2024
d60a789
Update pacs_fbopt_fishr_erm.yaml
smilesun Jul 2, 2024
976c25a
Update task_pacs_aug.py, fix codacy
smilesun Jul 2, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
45 changes: 25 additions & 20 deletions domainlab/algos/trainers/fbopt_mu_controller.py
Original file line number Diff line number Diff line change
Expand Up @@ -206,27 +206,32 @@ def search_mu(
self.writer.add_scalar(f"controller_gain/{key}", dict_gain[key], miter)
ind = list_str_multiplier_na.index(key)
self.writer.add_scalar(f"delta/{key}", self.delta_epsilon_r[ind], miter)
for i, (reg_dyn, reg_set) in enumerate(
zip(epo_reg_loss, self.get_setpoint4r())
):
self.writer.add_scalar(
f"lossrd/dyn_{list_str_multiplier_na[i]}", reg_dyn, miter
)
self.writer.add_scalar(
f"lossrs/setpoint_{list_str_multiplier_na[i]}", reg_set, miter
)

self.writer.add_scalars(
f"loss_rds/loss_{list_str_multiplier_na[i]}_w_setpoint",
{
f"lossr/loss_{list_str_multiplier_na[i]}": reg_dyn,
f"lossr/setpoint_{list_str_multiplier_na[i]}": reg_set,
},
miter,
)
self.writer.add_scalar(
f"x_ell_y_r/loss_{list_str_multiplier_na[i]}", reg_dyn, epo_task_loss
)
if list_str_multiplier_na:
for i, (reg_dyn, reg_set) in enumerate(
zip(epo_reg_loss, self.get_setpoint4r())
):

self.writer.add_scalar(
f"lossrd/dyn_{list_str_multiplier_na[i]}", reg_dyn, miter
)
self.writer.add_scalar(
f"lossrs/setpoint_{list_str_multiplier_na[i]}", reg_set, miter
)

self.writer.add_scalars(
f"loss_rds/loss_{list_str_multiplier_na[i]}_w_setpoint",
{
f"lossr/loss_{list_str_multiplier_na[i]}": reg_dyn,
f"lossr/setpoint_{list_str_multiplier_na[i]}": reg_set,
},
miter,
)
self.writer.add_scalar(
f"x_ell_y_r/loss_{list_str_multiplier_na[i]}", reg_dyn, epo_task_loss
)
else:
logger.info("No multiplier provided")
self.writer.add_scalar("loss_task/penalized", epo_loss_tr, miter)
self.writer.add_scalar("loss_task/ell", epo_task_loss, miter)
acc_te = 0
Expand Down
8 changes: 2 additions & 6 deletions domainlab/algos/trainers/train_fishr.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ class TrainerFishr(TrainerBasic):
"Fishr: Invariant gradient variances for out-of-distribution generalization."
International Conference on Machine Learning. PMLR, 2022.
"""
def tr_epoch(self, epoch):
def tr_epoch(self, epoch, flag_info=False):
list_loaders = list(self.dict_loader_tr.values())
loaders_zip = zip(*list_loaders)
self.model.train()
Expand All @@ -46,7 +46,7 @@ def tr_epoch(self, epoch):
self.epo_loss_tr += loss.detach().item()
self.after_batch(epoch, ind_batch)

flag_stop = self.observer.update(epoch) # notify observer
flag_stop = self.observer.update(epoch, flag_info) # notify observer
return flag_stop

def var_grads_and_loss(self, tuple_data_domains_batch):
Expand Down Expand Up @@ -161,10 +161,6 @@ def cal_dict_variance_grads(self, tensor_x, vec_y):
inputs=list(self.model.parameters()), retain_graph=True, create_graph=True
)

for name, param in self.model.named_parameters():
print(name)
print(".grad.shape: ", param.variance.shape)

dict_variance = OrderedDict(
[(name, weights.variance.clone())
for name, weights in self.model.named_parameters()
Expand Down
2 changes: 1 addition & 1 deletion domainlab/algos/trainers/train_hyper_scheduler.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@ def before_tr(self):
flag_update_epoch=True,
)

def tr_epoch(self, epoch):
def tr_epoch(self, epoch, flag_info=False):
"""
update hyper-parameters only per epoch
"""
Expand Down
4 changes: 2 additions & 2 deletions domainlab/algos/trainers/train_irm.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ class TrainerIRM(TrainerBasic):
For more details, see section 3.2 and Appendix D of :
Arjovsky et al., “Invariant Risk Minimization.”
"""
def tr_epoch(self, epoch):
def tr_epoch(self, epoch, flag_info=False):
list_loaders = list(self.dict_loader_tr.values())
loaders_zip = zip(*list_loaders)
self.model.train()
Expand All @@ -46,7 +46,7 @@ def tr_epoch(self, epoch):
self.epo_loss_tr += loss.detach().item()
self.after_batch(epoch, ind_batch)

flag_stop = self.observer.update(epoch) # notify observer
flag_stop = self.observer.update(epoch, flag_info) # notify observer
return flag_stop

def _cal_phi(self, tensor_x):
Expand Down
2 changes: 1 addition & 1 deletion domainlab/algos/trainers/train_matchdg.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,7 +95,7 @@ def tr_epoch(self, epoch, flag_info=False):

logger.info("\n\nPhase erm+ctr \n\n")
self.flag_erm = True
flag_stop = self.observer.update(epoch) # notify observer
flag_stop = self.observer.update(epoch, flag_info) # notify observer
return flag_stop

def tr_batch(self, epoch, batch_idx, x_e, y_e, d_e, others=None):
Expand Down
4 changes: 2 additions & 2 deletions domainlab/algos/trainers/train_mldg.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ def prepare_ziped_loader(self):
ddset_mix = DsetZip(ddset_source, ddset_target)
self.loader_tr_source_target = mk_loader(ddset_mix, self.aconf.bs)

def tr_epoch(self, epoch):
def tr_epoch(self, epoch, flag_info=False):
self.model.train()
self.epo_loss_tr = 0
self.prepare_ziped_loader()
Expand Down Expand Up @@ -117,5 +117,5 @@ def tr_epoch(self, epoch):
self.optimizer.step()
self.epo_loss_tr += loss.detach().item()
self.after_batch(epoch, ind_batch)
flag_stop = self.observer.update(epoch) # notify observer
flag_stop = self.observer.update(epoch, flag_info) # notify observer
return flag_stop
1 change: 0 additions & 1 deletion domainlab/models/a_model_classif.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,6 @@

loss_cross_entropy_extended = extend(nn.CrossEntropyLoss(reduction="none"))


class AModelClassif(AModel, metaclass=abc.ABCMeta):
"""
operations that all classification model should have
Expand Down
27 changes: 26 additions & 1 deletion domainlab/models/model_erm.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,6 @@
except:
backpack = None


def mk_erm(parent_class=AModelClassif, **kwargs):
"""
Instantiate a Deepall (ERM) model
Expand Down Expand Up @@ -53,4 +52,30 @@ def convert4backpack(self):
"""
self._net_invar_feat = extend(self._net_invar_feat, use_converter=True)
self.net_classifier = extend(self.net_classifier, use_converter=True)

def hyper_update(self, epoch, fun_scheduler):  # pylint: disable=unused-argument
    """
    No-op hook so this model can be driven by a hyperparameter scheduler.

    ERM has no regularization multipliers to update, so both arguments
    are intentionally ignored.

    :param epoch: current training epoch (unused)
    :param fun_scheduler: scheduler callable/object (unused)
    """

def hyper_init(self, functor_scheduler, trainer=None):
    """
    Instantiate a hyperparameter scheduler for this model.

    :param functor_scheduler: the class name of the scheduler
    :param trainer: optional trainer instance forwarded to the scheduler
    :return: the constructed scheduler object
    """
    scheduler = functor_scheduler(trainer=trainer)
    return scheduler

@property
def list_str_multiplier_na(self):
    """
    Names of the loss multipliers, in the same order as the terms
    produced by cal_reg_loss.

    :return: empty list — this model carries no regularization multipliers
    """
    return []

return ModelERM
66 changes: 66 additions & 0 deletions examples/benchmark/pacs_fbopt_fishr_erm.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,66 @@
# Benchmark comparing the fbopt_fishr trainer against plain fishr,
# both using an ERM model, on the PACS dataset.
mode: grid

output_dir: zoutput/benchmarks/benchmark_fbopt_fishr_erm_pacs

# seed for hyperparameter sampling and the range of training seeds
sampling_seed: 0
startseed: 0
endseed: 0

test_domains:
  - sketch

# arguments forwarded to domainlab for every run
domainlab_args:
  tpath: examples/tasks/task_pacs_aug.py
  dmem: False
  lr: 5e-5
  epos: 10
  epos_min: 2
  es: 5
  bs: 32
  san_check: False
  nname: alexnet
  nname_dom: alexnet
  zx_dim: 0
  zy_dim: 64
  zd_dim: 64


# hyperparameter grids shared between the experiment definitions below
Shared params:
  ini_setpoint_ratio:
    min: 0.5
    max: 0.99
    num: 2
    step: 0.05
    distribution: uniform

  k_i_gain:
    min: 0.0001
    max: 0.01
    num: 2
    step: 0.0001
    distribution: uniform

  gamma_reg:
    min: 0.01
    max: 1e4
    num: 3
    distribution: loguniform


# Test fbopt with different hyperparameter configurations

fbopt_fishr_erm:
  model: erm
  trainer: fbopt_fishr
  shared:
    - ini_setpoint_ratio
    - k_i_gain
    - gamma_reg

fishr_erm:
  model: erm
  trainer: fishr
  shared:
    - gamma_reg
5 changes: 3 additions & 2 deletions examples/tasks/task_pacs_aug.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,9 +11,10 @@
from domainlab.tasks.utils_task import ImSize

# change this to absolute directory where you have the raw images from PACS,
G_PACS_RAW_PATH = "domainlab/zdata/pacs/PACS"
G_PACS_RAW_PATH = "data/pacs/PACS"
# domainlab repository already contains the file names in
# domainlab/zdata/pacs_split folder of domainlab
# domainlab/zdata/pacs_split folder of domainlab,
# but PACS dataset is too big to put into domainlab folder


def get_task(na=None):
Expand Down
6 changes: 6 additions & 0 deletions tests/test_fbopt.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,12 @@ def test_diva_fbopt():
args = "--te_d=caltech --task=mini_vlcs --debug --bs=2 --model=diva --gamma_y=1.0 --trainer=fbopt --nname=alexnet --epos=3"
utils_test_algo(args)

def test_erm_fbopt():
    """
    Smoke test: train an ERM model with the fbopt trainer on the
    mini_vlcs task (debug mode, batch size 2, 3 epochs, caltech held out).
    """
    args = "--te_d=caltech --task=mini_vlcs --debug --bs=2 --model=erm --trainer=fbopt --nname=alexnet --epos=3"  # pylint: disable=line-too-long
    utils_test_algo(args)

def test_forcesetpoint_fbopt():
"""
Expand Down