diff --git a/generative/networks/nets/diffusion_model_unet.py b/generative/networks/nets/diffusion_model_unet.py index 7f652961..7aabfbba 100644 --- a/generative/networks/nets/diffusion_model_unet.py +++ b/generative/networks/nets/diffusion_model_unet.py @@ -138,7 +138,7 @@ def reshape_batch_dim_to_heads(self, x: torch.Tensor) -> torch.Tensor: return x def _attention(self, query: torch.Tensor, key: torch.Tensor, value: torch.Tensor) -> torch.Tensor: - attention_scores = torch.matmul(query, key.transpose(-1, -2)) + attention_scores = torch.matmul(query, key.transpose(-1, -2)) * self.scale attention_probs = attention_scores.softmax(dim=-1) # compute attention output hidden_states = torch.matmul(attention_probs, value)