From a7143301fb510b1a08f2603f10d5109c627a6ecf Mon Sep 17 00:00:00 2001
From: dssrgu
Date: Fri, 26 Feb 2021 14:06:46 +0900
Subject: [PATCH] goa: torch 1.7 compatible version

---
 d4rl/rlkit/torch/sac/cql.py | 14 +++++++-------
 1 file changed, 7 insertions(+), 7 deletions(-)

diff --git a/d4rl/rlkit/torch/sac/cql.py b/d4rl/rlkit/torch/sac/cql.py
index 98a533e..773262d 100644
--- a/d4rl/rlkit/torch/sac/cql.py
+++ b/d4rl/rlkit/torch/sac/cql.py
@@ -142,9 +142,9 @@ def _get_policy_actions(self, obs, num_actions, network=None):
             obs_temp, reparameterize=True, return_log_prob=True,
         )
         if not self.discrete:
-            return new_obs_actions, new_obs_log_pi.view(obs.shape[0], num_actions, 1)
+            return new_obs_actions.detach(), new_obs_log_pi.view(obs.shape[0], num_actions, 1).detach()
         else:
-            return new_obs_actions
+            return new_obs_actions.detach()
 
     def train_from_torch(self, batch):
         self._current_epoch += 1
@@ -283,6 +283,11 @@ def train_from_torch(self, batch):
         """
         Update networks
         """
+        self._num_policy_update_steps += 1
+        self.policy_optimizer.zero_grad()
+        policy_loss.backward(retain_graph=False)
+        self.policy_optimizer.step()
+
         # Update the Q-functions iff
         self._num_q_update_steps += 1
         self.qf1_optimizer.zero_grad()
@@ -294,11 +299,6 @@ def train_from_torch(self, batch):
         qf2_loss.backward(retain_graph=True)
         self.qf2_optimizer.step()
 
-        self._num_policy_update_steps += 1
-        self.policy_optimizer.zero_grad()
-        policy_loss.backward(retain_graph=False)
-        self.policy_optimizer.step()
-
         """
         Soft Updates
         """
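
Note, not part of the patch: the sketch below is a minimal, self-contained illustration of the update ordering this diff introduces. All names in it (policy, qf1, qf2, the toy targets and optimizers) are hypothetical stand-ins, not the CQL trainer's actual attributes; it only shows why the policy step is moved ahead of the Q-function steps and why the sampled actions are detached, which is presumably the torch 1.7 incompatibility the subject line refers to.

# Minimal sketch (assumed setup, not the CQL trainer): toy networks and losses.
import torch
import torch.nn as nn
import torch.optim as optim

policy = nn.Linear(4, 2)
qf1 = nn.Linear(6, 1)
qf2 = nn.Linear(6, 1)
policy_opt = optim.Adam(policy.parameters(), lr=3e-4)
qf1_opt = optim.Adam(qf1.parameters(), lr=3e-4)
qf2_opt = optim.Adam(qf2.parameters(), lr=3e-4)

obs = torch.randn(8, 4)

# The policy loss backpropagates through qf1, so its graph saves qf1's current weights.
actions = policy(obs)
policy_loss = -qf1(torch.cat([obs, actions], dim=1)).mean()

# Q losses use detached actions (mirroring the .detach() added in _get_policy_actions),
# so their backward pass never re-enters the policy's graph.
actions_d = actions.detach()
q_in = torch.cat([obs, actions_d], dim=1)
qf1_loss = ((qf1(q_in) - 1.0) ** 2).mean()
qf2_loss = ((qf2(q_in) - 1.0) ** 2).mean()

# Policy step first: if qf1_opt.step() ran before policy_loss.backward(), the in-place
# weight update would invalidate tensors saved for the policy loss's backward pass, and
# newer torch releases raise "one of the variables needed for gradient computation has
# been modified by an inplace operation".
policy_opt.zero_grad()
policy_loss.backward()
policy_opt.step()

qf1_opt.zero_grad()
qf1_loss.backward()
qf1_opt.step()

qf2_opt.zero_grad()
qf2_loss.backward()
qf2_opt.step()

The trainer itself passes retain_graph=True on the Q-function backwards; the toy losses above do not share a graph, so it is omitted here.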