From 98e0af2bb4ddacde78bc4faa49b95598af81e31a Mon Sep 17 00:00:00 2001 From: Wenhao Chen Date: Thu, 27 Jul 2023 11:09:52 +0800 Subject: [PATCH] fix: fix compute_approx_kl --- applications/Chat/coati/models/utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/applications/Chat/coati/models/utils.py b/applications/Chat/coati/models/utils.py index 772bfc32982a..8769fb7a8c43 100644 --- a/applications/Chat/coati/models/utils.py +++ b/applications/Chat/coati/models/utils.py @@ -19,7 +19,7 @@ def compute_approx_kl(log_probs: torch.Tensor, action_mask: Mask for actions. """ - log_ratio = log_probs - log_probs_base + log_ratio = log_probs_base - log_probs approx_kl = (log_ratio.exp() - 1) - log_ratio if action_mask is not None: approx_kl = masked_mean(approx_kl, action_mask, dim=1)