From 7ceb45d73a17ee6c596fea03bdc16d5b6a55fe9f Mon Sep 17 00:00:00 2001
From: Somshubra Majumdar
Date: Thu, 10 Nov 2022 09:52:33 -0800
Subject: [PATCH] Force MHA QKV onto fp32

Signed-off-by: smajumdar
---
 .../asr/parts/submodules/multi_head_attention.py | 6 ++++++
 1 file changed, 6 insertions(+)

diff --git a/nemo/collections/asr/parts/submodules/multi_head_attention.py b/nemo/collections/asr/parts/submodules/multi_head_attention.py
index 78cf1ce37212..8f774e172718 100644
--- a/nemo/collections/asr/parts/submodules/multi_head_attention.py
+++ b/nemo/collections/asr/parts/submodules/multi_head_attention.py
@@ -140,6 +140,9 @@ def forward(self, query, key, value, mask, pos_emb=None, cache=None, cache_next=
         """
         key, value, query = self.update_cache(key=key, value=value, query=query, cache=cache, cache_next=cache_next)
 
+        if torch.is_autocast_enabled():
+            query, key, value = query.to(torch.float32), key.to(torch.float32), value.to(torch.float32)
+
         # temporary until we solve this more gracefully
         with avoid_float16_autocast_context():
             q, k, v = self.forward_qkv(query, key, value)
@@ -217,6 +220,9 @@ def forward(self, query, key, value, mask, pos_emb, cache=None, cache_next=None)
         """
         key, value, query = self.update_cache(key=key, value=value, query=query, cache=cache, cache_next=cache_next)
 
+        if torch.is_autocast_enabled():
+            query, key, value = query.to(torch.float32), key.to(torch.float32), value.to(torch.float32)
+
         # temporary until we solve this more gracefully
         with avoid_float16_autocast_context():
             q, k, v = self.forward_qkv(query, key, value)
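
For context, below is a minimal, self-contained sketch (not NeMo's actual module) of the pattern this patch applies: when running inside a torch.autocast region, cast Q/K/V back to float32 before the attention math so the matmul and softmax do not run in reduced precision. The function name toy_attention is hypothetical, and the plain torch.autocast(..., enabled=False) guard stands in for NeMo's avoid_float16_autocast_context() helper used in the patched file.

import torch


def toy_attention(query, key, value):
    # Under autocast, upstream layers typically emit fp16/bf16 activations;
    # force the attention computation itself onto fp32 for numerical stability.
    if torch.is_autocast_enabled():
        query, key, value = query.to(torch.float32), key.to(torch.float32), value.to(torch.float32)

    # Disable autocast locally so the ops below are not re-cast to half precision
    # (stand-in for NeMo's avoid_float16_autocast_context()).
    with torch.autocast(device_type=query.device.type, enabled=False):
        scores = torch.matmul(query, key.transpose(-2, -1)) / query.size(-1) ** 0.5
        attn = torch.softmax(scores, dim=-1)
        return torch.matmul(attn, value)


if __name__ == "__main__":
    q = torch.randn(2, 4, 8, 16)
    k = torch.randn(2, 4, 8, 16)
    v = torch.randn(2, 4, 8, 16)
    # Behaves the same with or without an enclosing autocast region.
    out = toy_attention(q, k, v)
    print(out.dtype, out.shape)

The explicit .to(torch.float32) casts mirror the patch: disabling autocast alone does not convert tensors that are already half precision, so the inputs are upcast first and only then is autocast suppressed for the attention ops.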