From c6f217eabb2b7d93241666e033fea4d2872a8394 Mon Sep 17 00:00:00 2001 From: Boris Fomitchev Date: Wed, 23 Nov 2022 14:04:38 -0800 Subject: [PATCH 1/2] Export fixes for Riva Signed-off-by: Boris Fomitchev --- .../asr/parts/submodules/multi_head_attention.py | 8 ++++---- nemo/utils/export_utils.py | 4 ++-- 2 files changed, 6 insertions(+), 6 deletions(-) diff --git a/nemo/collections/asr/parts/submodules/multi_head_attention.py b/nemo/collections/asr/parts/submodules/multi_head_attention.py index 8f774e172718..62206fb7d3da 100644 --- a/nemo/collections/asr/parts/submodules/multi_head_attention.py +++ b/nemo/collections/asr/parts/submodules/multi_head_attention.py @@ -37,7 +37,7 @@ import torch import torch.nn as nn -from nemo.collections.common.parts.training_utils import avoid_float16_autocast_context +from nemo.utils import avoid_float16_autocast_context __all__ = [ 'RelPositionMultiHeadAttention', @@ -337,7 +337,6 @@ def extend_pe(self, length, device): # positive positions would be used for left positions and negative for right positions positions = torch.arange(length - 1, -length, -1, dtype=torch.float32, device=device).unsqueeze(1) self.create_pe(positions=positions) - self.center_pos = torch.tensor(self.pe.size(1) // 2 + 1, dtype=torch.int32, device=device) def forward(self, x, cache_len=0): """Compute positional encoding. @@ -356,8 +355,9 @@ def forward(self, x, cache_len=0): # negative positions would be used for right and positive for left tokens # for input of length L, 2*L-1 positions are needed, positions from (L-1) to -(L-1) input_len = x.size(1) + cache_len - start_pos = self.center_pos - input_len - end_pos = self.center_pos + input_len - 1 + center_pos = self.pe.size(1) // 2 + 1 + start_pos = center_pos - input_len + end_pos = center_pos + input_len - 1 pos_emb = self.pe[:, start_pos:end_pos] if self.dropout_emb: pos_emb = self.dropout_emb(pos_emb) diff --git a/nemo/utils/export_utils.py b/nemo/utils/export_utils.py index e4fda73c181d..c89644fd1e68 100644 --- a/nemo/utils/export_utils.py +++ b/nemo/utils/export_utils.py @@ -169,7 +169,7 @@ def run_ts_and_compare(ts_model, ts_input_list, ts_input_dict, output_example, c if torch.is_tensor(expected): tout = out.to('cpu') - logging.debug(f"Checking output {i}, shape: {expected.shape}:\n{expected}\n{tout}") + logging.debug(f"Checking output {i}, shape: {expected.shape}:\n") this_good = True try: if not torch.allclose(tout, expected.cpu(), rtol=check_tolerance, atol=check_tolerance): @@ -191,7 +191,7 @@ def run_ort_and_compare(sess, ort_input, output_example, check_tolerance=0.01): if torch.is_tensor(expected): tout = torch.from_numpy(out) - logging.debug(f"Checking output {i}, shape: {expected.shape}:\n{expected}\n{tout}") + logging.debug(f"Checking output {i}, shape: {expected.shape}:\n") this_good = True try: if not torch.allclose(tout, expected.cpu(), rtol=check_tolerance, atol=100 * check_tolerance): From dffeb03611b4003e0525dae48f1930c2393260f6 Mon Sep 17 00:00:00 2001 From: Boris Fomitchev Date: Wed, 23 Nov 2022 15:12:10 -0800 Subject: [PATCH 2/2] Cleaning up training_utils Signed-off-by: Boris Fomitchev --- .../asr/parts/submodules/subsampling.py | 2 +- .../common/parts/training_utils.py | 49 ------------------- 2 files changed, 1 insertion(+), 50 deletions(-) delete mode 100644 nemo/collections/common/parts/training_utils.py diff --git a/nemo/collections/asr/parts/submodules/subsampling.py b/nemo/collections/asr/parts/submodules/subsampling.py index ff5b2b2e4686..06bce6bb3d0f 100644 --- a/nemo/collections/asr/parts/submodules/subsampling.py +++ b/nemo/collections/asr/parts/submodules/subsampling.py @@ -19,7 +19,7 @@ from torch.nn import LayerNorm from nemo.collections.asr.parts.submodules.causal_convs import CausalConv2D -from nemo.collections.common.parts.training_utils import avoid_bfloat16_autocast_context +from nemo.utils import avoid_bfloat16_autocast_context class StackingSubsampling(torch.nn.Module): diff --git a/nemo/collections/common/parts/training_utils.py b/nemo/collections/common/parts/training_utils.py deleted file mode 100644 index 4b4b1c3c910b..000000000000 --- a/nemo/collections/common/parts/training_utils.py +++ /dev/null @@ -1,49 +0,0 @@ -# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -from contextlib import nullcontext - -import torch - -__all__ = ['avoid_bfloat16_autocast_context', 'avoid_float16_autocast_context'] - - -def avoid_bfloat16_autocast_context(): - """ - If the current autocast context is bfloat16, - cast it to float32 - """ - - if torch.is_autocast_enabled() and torch.get_autocast_gpu_dtype() == torch.bfloat16: - return torch.cuda.amp.autocast(dtype=torch.float32) - else: - return nullcontext() - - -def avoid_float16_autocast_context(): - """ - If the current autocast context is float16, cast it to bfloat16 - if available (unless we're in jit) or float32 - """ - - if torch.is_autocast_enabled() and torch.get_autocast_gpu_dtype() == torch.float16: - if torch.jit.is_scripting() or torch.jit.is_tracing(): - return torch.cuda.amp.autocast(dtype=torch.float32) - - if torch.cuda.is_bf16_supported(): - return torch.cuda.amp.autocast(dtype=torch.bfloat16) - else: - return torch.cuda.amp.autocast(dtype=torch.float32) - else: - return nullcontext()