From 3b4d4633cd085fb5b0e7cc0eea3415dfe279c0de Mon Sep 17 00:00:00 2001
From: Somshubra Majumdar
Date: Thu, 10 Nov 2022 15:51:55 -0800
Subject: [PATCH] Force MHA QKV onto fp32 (#5391)

Signed-off-by: smajumdar

Signed-off-by: smajumdar
---
 .../collections/asr/parts/submodules/multi_head_attention.py | 5 +++++
 1 file changed, 5 insertions(+)

diff --git a/nemo/collections/asr/parts/submodules/multi_head_attention.py b/nemo/collections/asr/parts/submodules/multi_head_attention.py
index acbee5bf7df5..62206fb7d3da 100644
--- a/nemo/collections/asr/parts/submodules/multi_head_attention.py
+++ b/nemo/collections/asr/parts/submodules/multi_head_attention.py
@@ -142,6 +142,8 @@ def forward(self, query, key, value, mask, pos_emb=None, cache=None, cache_next=
 
         if torch.is_autocast_enabled():
             query, key, value = query.to(torch.float32), key.to(torch.float32), value.to(torch.float32)
+
+        # temporary until we solve this more gracefully
         with avoid_float16_autocast_context():
             q, k, v = self.forward_qkv(query, key, value)
             scores = torch.matmul(q, k.transpose(-2, -1)) / self.s_d_k
@@ -218,6 +220,9 @@ def forward(self, query, key, value, mask, pos_emb, cache=None, cache_next=None)
         """
         key, value, query = self.update_cache(key=key, value=value, query=query, cache=cache, cache_next=cache_next)
 
+        if torch.is_autocast_enabled():
+            query, key, value = query.to(torch.float32), key.to(torch.float32), value.to(torch.float32)
+
         # temporary until we solve this more gracefully
         with avoid_float16_autocast_context():
             q, k, v = self.forward_qkv(query, key, value)
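
Note on the pattern above: under torch autocast, the QKV projections and the
QK^T matmul would otherwise run in float16, which is numerically fragile; the
patch casts query/key/value back to float32 when autocast is enabled and keeps
the score computation inside NeMo's avoid_float16_autocast_context(). Below is
a minimal, self-contained sketch of the same pattern in plain PyTorch. The
function name attention_in_fp32 and the use of torch.cuda.amp.autocast(enabled=False)
as a stand-in for avoid_float16_autocast_context() are illustrative assumptions,
not NeMo's actual implementation.

    # Illustrative sketch only: plain-PyTorch stand-in for the NeMo change above.
    import math
    import torch

    def attention_in_fp32(query, key, value, d_k):
        # query/key/value: (batch, heads, time, d_k), possibly float16 under autocast.
        if torch.is_autocast_enabled():
            # Force the attention inputs back to float32, mirroring the patch.
            query, key, value = (
                query.to(torch.float32),
                key.to(torch.float32),
                value.to(torch.float32),
            )
        # Stand-in for avoid_float16_autocast_context(): disable autocast so the
        # matmul/softmax below run in float32 rather than float16.
        with torch.cuda.amp.autocast(enabled=False):
            scores = torch.matmul(query, key.transpose(-2, -1)) / math.sqrt(d_k)
            attn = torch.softmax(scores, dim=-1)
            return torch.matmul(attn, value)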