From 703e31952ec7af3ac23ff6b296d51e0c591dc04d Mon Sep 17 00:00:00 2001 From: SimJeg Date: Thu, 13 Feb 2025 12:41:14 +0000 Subject: [PATCH] Fix distributed inference --- kvpress/presses/expected_attention_press.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/kvpress/presses/expected_attention_press.py b/kvpress/presses/expected_attention_press.py index febcd9dd..74b2a290 100644 --- a/kvpress/presses/expected_attention_press.py +++ b/kvpress/presses/expected_attention_press.py @@ -76,7 +76,7 @@ def get_query_statistics(self, module: nn.Module, hidden_states: torch.Tensor): R = cos.unsqueeze(1) * Id + sin.unsqueeze(1) * P # Apply average rotation to the mean and covariance - R = R.mean(dim=0) + R = R.mean(dim=0).to(mu.device) mu = torch.matmul(mu, R.T) if self.use_covariance: cov = torch.matmul(R, torch.matmul(cov, R.T))