diff --git a/domainlab/algos/trainers/fbopt_mu_controller.py b/domainlab/algos/trainers/fbopt_mu_controller.py index 9f8e02971..9ae93a9bc 100644 --- a/domainlab/algos/trainers/fbopt_mu_controller.py +++ b/domainlab/algos/trainers/fbopt_mu_controller.py @@ -50,7 +50,7 @@ def __init__(self, trainer, **kwargs): self.mmu = {key: self.init_mu for key, val in self.mmu.items()} self.set_point_controller = FbOptSetpointController(args=self.trainer.aconf) - self.k_i_control = trainer.aconf.k_i_gain + self.k_i_control = [trainer.aconf.k_i_gain for i in range(len(self.mmu))] self.k_i_gain_ratio = None self.overshoot_rewind = trainer.aconf.overshoot_rewind == "yes" self.delta_epsilon_r = None @@ -84,7 +84,7 @@ def set_k_i_gain(self, epo_reg_loss): k_i_gain_saturate_min = min(k_i_gain_saturate) # NOTE: here we override the commandline arguments specification # for k_i_control, so k_i_control is not a hyperparameter anymore - self.k_i_control = self.k_i_gain_ratio * k_i_gain_saturate_min + self.k_i_control = [self.k_i_gain_ratio * ele for ele in k_i_gain_saturate] warnings.warn( f"hyperparameter k_i_gain disabled! \ replace with {self.k_i_control}" @@ -162,7 +162,7 @@ def cal_activation(self): """ setpoint = self.get_setpoint4r() activation = [ - self.k_i_control * val if setpoint[i] > 0 else self.k_i_control * (-val) + self.k_i_control[i] * val if setpoint[i] > 0 else self.k_i_control[i] * (-val) for i, val in enumerate(self.delta_epsilon_r) ] if self.activation_clip is not None: