diff --git a/kvpress/presses/expected_attention_press.py b/kvpress/presses/expected_attention_press.py index febcd9dd..74b2a290 100644 --- a/kvpress/presses/expected_attention_press.py +++ b/kvpress/presses/expected_attention_press.py @@ -76,7 +76,7 @@ def get_query_statistics(self, module: nn.Module, hidden_states: torch.Tensor): R = cos.unsqueeze(1) * Id + sin.unsqueeze(1) * P # Apply average rotation to the mean and covariance - R = R.mean(dim=0) + R = R.mean(dim=0).to(mu.device) mu = torch.matmul(mu, R.T) if self.use_covariance: cov = torch.matmul(R, torch.matmul(cov, R.T))