Skip to content
Closed
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions domainlab/algos/trainers/fbopt_mu_controller.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ def __init__(self, trainer, **kwargs):
self.mmu = {key: self.init_mu for key, val in self.mmu.items()}
self.set_point_controller = FbOptSetpointController(args=self.trainer.aconf)

self.k_i_control = trainer.aconf.k_i_gain
self.k_i_control = [trainer.aconf.k_i_gain for i in range(len(self.mmu))]
self.k_i_gain_ratio = None
self.overshoot_rewind = trainer.aconf.overshoot_rewind == "yes"
self.delta_epsilon_r = None
Expand Down Expand Up @@ -84,7 +84,7 @@ def set_k_i_gain(self, epo_reg_loss):
k_i_gain_saturate_min = min(k_i_gain_saturate)
# NOTE: here we override the commandline arguments specification
# for k_i_control, so k_i_control is not a hyperparameter anymore
self.k_i_control = self.k_i_gain_ratio * k_i_gain_saturate_min
self.k_i_control = [self.k_i_gain_ratio * ele for ele in k_i_gain_saturate]
warnings.warn(
f"hyperparameter k_i_gain disabled! \
replace with {self.k_i_control}"
Expand Down Expand Up @@ -162,7 +162,7 @@ def cal_activation(self):
"""
setpoint = self.get_setpoint4r()
activation = [
self.k_i_control * val if setpoint[i] > 0 else self.k_i_control * (-val)
self.k_i_control[i] * val if setpoint[i] > 0 else self.k_i_control[i] * (-val)
for i, val in enumerate(self.delta_epsilon_r)
]
if self.activation_clip is not None:
Expand Down