-
Notifications
You must be signed in to change notification settings - Fork 3
Mhof dev vector ki gain: copy changes from https://github.com/marrlab/DomainLab/pull/772/files manually #880
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
b2b204d
ca9766e
731a144
9e09396
b4ddafe
d53c863
f15478c
6626475
d5d3a0b
6701cbf
888d714
64bcc9c
22f343c
23607fb
28bdfa9
63ad47b
8ba7e11
6073403
8bd1163
7667cbc
795478f
b64e5d0
6433f58
6f5bac1
ba68e27
c09879d
401e978
5585b4a
d176e1a
3a4226e
011cccb
ae03219
3166eb2
d8cee2d
9884640
cd15915
273422b
dbbda55
1cb2428
155f0df
c41e92b
b8731cb
f4e4773
e3ed8c6
e38fa75
044e3cc
aea6e90
e5197f3
dc58a50
9ddd037
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1 @@ | ||
| python main_out.py --te_d=0 --tr_d 1 2 --task=mnistcolor10 --model=erm --nname=conv_bn_pool_2 --trainer=hyperscheduler_irm_dial --k_i_gain_ratio=0.5 --force_setpoint_change_once --epos=10 --epos_min=4 --exp_shoulder_clip=1 --mu_clip=100 --ini_setpoint_ratio=0.99999999 |
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -1 +1 @@ | ||
| python main_out.py --te_d=0 --task=mnistcolor10 --model=erm --trainer=fbopt_irm --nname=conv_bn_pool_2 --k_i_gain_ratio=0.5 | ||
| python main_out.py --te_d=0 --tr_d 1 2 --task=mnistcolor10 --model=erm --nname=conv_bn_pool_2 --trainer=fbopt_irm_dial --k_i_gain_ratio=0.5 --force_setpoint_change_once --epos=500 --epos_min=400 --exp_shoulder_clip=1 --mu_clip=100 --ini_setpoint_ratio=0.9 --nb4reg_over_task_ratio=0 --tr_with_init_mu --coeff_ma_setpoint=0.0 --str_setpoint_ada="SliderAnyComponent()" |
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -53,6 +53,7 @@ def __init__(self, successor_node=None, extend=None): | |
| """ | ||
| super().__init__(successor_node) | ||
| self._model = None | ||
| # decoratee can be both model or trainer | ||
| self._decoratee = extend | ||
| self.task = None | ||
| self.observer = None | ||
|
|
@@ -96,6 +97,8 @@ def __init__(self, successor_node=None, extend=None): | |
| self._ma_iter = 0 | ||
| # | ||
| self.list_reg_over_task_ratio = None | ||
| # mhof | ||
| self.dict_multiplier = {} | ||
|
|
||
|
|
||
| @property | ||
|
|
@@ -199,7 +202,13 @@ def before_tr(self): | |
| """ | ||
| before training, probe model performance | ||
| """ | ||
| self.cal_reg_loss_over_task_loss_ratio() | ||
| list_mu = self.cal_reg_loss_over_task_loss_ratio() | ||
| self.dict_multiplier = {'mu4regloss'+str(i): value for i, value in enumerate(list_mu)} | ||
|
|
||
| @property | ||
| def list_str_multiplier_na(self): | ||
| list_str = list(self.dict_multiplier.keys()) | ||
| return list_str | ||
|
|
||
| def cal_reg_loss_over_task_loss_ratio(self): | ||
| """ | ||
|
|
@@ -208,18 +217,21 @@ def cal_reg_loss_over_task_loss_ratio(self): | |
| """ | ||
| list_accum_reg_loss = [] | ||
| loss_task_agg = 0 | ||
| list_mu = None | ||
| for ind_batch, (tensor_x, tensor_y, tensor_d, *others) in enumerate( | ||
| self.loader_tr | ||
| ): | ||
| if ind_batch >= self.aconf.nb4reg_over_task_ratio: | ||
| return | ||
| tensor_x, tensor_y, tensor_d = ( | ||
| tensor_x.to(self.device), | ||
| tensor_y.to(self.device), | ||
| tensor_d.to(self.device), | ||
| ) | ||
| list_reg_loss_tensor, _ = \ | ||
| list_reg_loss_tensor, list_mu = \ | ||
| self.cal_reg_loss(tensor_x, tensor_y, tensor_d, others) | ||
|
|
||
| if ind_batch >= self.aconf.nb4reg_over_task_ratio: | ||
| return list_mu | ||
|
|
||
| list_reg_loss_tensor = [torch.sum(tensor).detach().item() | ||
| for tensor in list_reg_loss_tensor] | ||
| if ind_batch == 0: | ||
|
|
@@ -235,6 +247,7 @@ def cal_reg_loss_over_task_loss_ratio(self): | |
| loss_task_agg += tensor_loss_task | ||
| self.list_reg_over_task_ratio = [reg_loss / loss_task_agg | ||
| for reg_loss in list_accum_reg_loss] | ||
| return list_mu | ||
|
|
||
| def post_tr(self): | ||
| """ | ||
|
|
@@ -301,7 +314,25 @@ def cal_reg_loss(self, tensor_x, tensor_y, tensor_d, others=None): | |
| list_loss_tensor = list_reg_loss_model_tensor + \ | ||
| list_reg_loss_trainer_tensor | ||
| list_mu = list_mu_model + list_mu_trainer | ||
| return list_loss_tensor, list_mu | ||
| # ERM return a tensor of all zeros, delete here | ||
| if len(list_mu) > 1: | ||
| list_boolean_zero = [torch.all(torch.eq(list_loss_tensor[i], 0)).item() | ||
| for i in range(len(list_mu))] | ||
| list_loss_tensor = [list_loss_tensor[i] for (i, flag) in | ||
| enumerate(list_boolean_zero) if not flag] | ||
| list_mu = [list_mu[i] for (i, flag) in enumerate(list_boolean_zero) if not flag] | ||
| if self.dict_multiplier: | ||
| list_mu = list(self.dict_multiplier.values()) | ||
|
Collaborator
Author
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. This is a bad implementation: will list(self.dict_multiplier.values()) return the same list order each time? Since mhof is updating self.dict_multiplier, this is acceptable as long as the order is the same each time (even if that order does not agree with the order that list_mu returns). Alexej suggests using OrderedDict.
Collaborator
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. OrderedDict preserves the order in which items are inserte, so the same as dicitonary here
Collaborator
Author
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. list(dict.keys()) will return the keys in the same order they were inserted each time it is called, as long as the dictionary hasn't been modified. |
||
|
|
||
| list_loss_tensor_normalized = list_loss_tensor | ||
| if self.list_reg_over_task_ratio: | ||
| assert len(list_mu) == len(self.list_reg_over_task_ratio) | ||
| list_loss_tensor_normalized = \ | ||
| [reg_loss / reg_over_task_ratio if reg_over_task_ratio != 0 | ||
| else reg_loss for (reg_loss, reg_over_task_ratio) | ||
| in zip(list_loss_tensor, self.list_reg_over_task_ratio)] | ||
|
|
||
| return list_loss_tensor_normalized, list_mu | ||
|
|
||
| def _cal_reg_loss(self, tensor_x, tensor_y, tensor_d, others=None): | ||
| """ | ||
|
|
@@ -326,3 +357,23 @@ def print_parameters(self): | |
| """ | ||
| params = vars(self) | ||
| print(f"Parameters of {type(self).__name__}: {params}") | ||
|
|
||
| def hyper_init(self, functor_scheduler, trainer): | ||
| """ | ||
| initialize both trainer's multiplier and model's multiplier | ||
| """ | ||
| if not self.dict_multiplier: | ||
| raise RuntimeError("self.dict_multiplier empty!") | ||
| return functor_scheduler( | ||
| trainer=trainer, **self.dict_multiplier | ||
| ) | ||
|
|
||
| def hyper_update(self, epoch, fun_scheduler): | ||
| """hyper_update. | ||
|
|
||
| :param epoch: | ||
| :param fun_scheduler: | ||
| """ | ||
| dict_rst = fun_scheduler(epoch) | ||
| for key in self.dict_multiplier: | ||
| self.dict_multiplier[key] = dict_rst[key] | ||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Let each trainer return the string name of its regularization loss in _cal_reg_loss.