diff --git a/nemo/collections/asr/parts/submodules/multi_head_attention.py b/nemo/collections/asr/parts/submodules/multi_head_attention.py
index 78cf1ce37212..8f774e172718 100644
--- a/nemo/collections/asr/parts/submodules/multi_head_attention.py
+++ b/nemo/collections/asr/parts/submodules/multi_head_attention.py
@@ -140,6 +140,9 @@ def forward(self, query, key, value, mask, pos_emb=None, cache=None, cache_next=
         """
         key, value, query = self.update_cache(key=key, value=value, query=query, cache=cache, cache_next=cache_next)
 
+        if torch.is_autocast_enabled():
+            query, key, value = query.to(torch.float32), key.to(torch.float32), value.to(torch.float32)
+
         # temporary until we solve this more gracefully
         with avoid_float16_autocast_context():
             q, k, v = self.forward_qkv(query, key, value)
@@ -217,6 +220,9 @@ def forward(self, query, key, value, mask, pos_emb, cache=None, cache_next=None)
         """
         key, value, query = self.update_cache(key=key, value=value, query=query, cache=cache, cache_next=cache_next)
 
+        if torch.is_autocast_enabled():
+            query, key, value = query.to(torch.float32), key.to(torch.float32), value.to(torch.float32)
+
         # temporary until we solve this more gracefully
         with avoid_float16_autocast_context():
             q, k, v = self.forward_qkv(query, key, value)
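
For context on what the added lines do, here is a minimal, self-contained sketch of the same pattern (this is not NeMo code; `qkv_to_fp32_under_autocast` is a hypothetical helper). Under an active autocast region, the upstream projections emit float16/bfloat16 activations, so query/key/value are promoted back to float32 before the attention math runs in full precision inside `avoid_float16_autocast_context()`.

```python
import torch


def qkv_to_fp32_under_autocast(query, key, value):
    # Mirrors the added lines in the diff: if an autocast region is active,
    # the incoming activations may be float16/bfloat16, so promote them to
    # float32 before the attention math (softmax + matmuls) runs.
    if torch.is_autocast_enabled():
        query, key, value = (
            query.to(torch.float32),
            key.to(torch.float32),
            value.to(torch.float32),
        )
    return query, key, value


if __name__ == "__main__" and torch.cuda.is_available():
    proj = torch.nn.Linear(64, 64).cuda()
    x = torch.randn(2, 8, 64, device="cuda")
    with torch.autocast(device_type="cuda", dtype=torch.float16):
        # Autocast runs the linear layers in float16, so q/k/v arrive as fp16.
        q, k, v = proj(x), proj(x), proj(x)
        q32, k32, v32 = qkv_to_fp32_under_autocast(q, k, v)
        print(q.dtype, "->", q32.dtype)  # torch.float16 -> torch.float32
```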