diff --git a/domainlab/algos/trainers/train_fbopt_b.py b/domainlab/algos/trainers/train_fbopt_b.py
index 7258453ba..1efe3ce58 100644
--- a/domainlab/algos/trainers/train_fbopt_b.py
+++ b/domainlab/algos/trainers/train_fbopt_b.py
@@ -166,7 +166,10 @@
     def tr_epoch(self, epoch, flag_info=False):
         logger = Logger.get_logger()
         logger.info(f"current multiplier: {self.model.dict_multiplier}")
-        flag = super().tr_epoch(epoch, self.flag_setpoint_updated)
+        if self._decoratee is not None:
+            flag = self._decoratee.tr_epoch(epoch, self.flag_setpoint_updated)
+        else:
+            flag = super().tr_epoch(epoch, self.flag_setpoint_updated)
         # is it good to update setpoint after we know the new value of each loss?
         self.flag_setpoint_updated = self.hyper_scheduler.update_setpoint(
             self.epo_reg_loss_tr, self.epo_task_loss_tr
diff --git a/domainlab/algos/trainers/train_matchdg.py b/domainlab/algos/trainers/train_matchdg.py
index 6a3edd996..cfe8e5a66 100644
--- a/domainlab/algos/trainers/train_matchdg.py
+++ b/domainlab/algos/trainers/train_matchdg.py
@@ -42,7 +42,7 @@ def init_business(
         self.tuple_tensor_ref_domain2each_y = None
         self.tuple_tensor_refdomain2each = None
 
-    def tr_epoch(self, epoch):
+    def tr_epoch(self, epoch, flag_info=False):
         """
         # data in one batch comes from two sources: one part from loader,
         # the other part from match tensor
diff --git a/run_fbopt_match_diva.sh b/run_fbopt_match_diva.sh
new file mode 100644
index 000000000..c1547567c
--- /dev/null
+++ b/run_fbopt_match_diva.sh
@@ -0,0 +1,6 @@
+#!/bin/bash
+export CUDA_VISIBLE_DEVICES=""
+# although garbage collector has been explicitly called, sometimes there is still CUDA out of memory error
+# so it is better not to use GPU to do the pytest to ensure every time there is no CUDA out of memory error occurring
+# pytest -s tests/test_fbopt.py
+python main_out.py --te_d=caltech --task=mini_vlcs --bs=8 --model=diva --trainer=fbopt_matchdg --nname=alexnet --nname_dom=alexnet --gamma_d=3 --gamma_y=3 --epos=200 --es=100