8 changes: 4 additions & 4 deletions nemo/collections/asr/parts/submodules/multi_head_attention.py
@@ -37,7 +37,7 @@
 import torch
 import torch.nn as nn

-from nemo.collections.common.parts.training_utils import avoid_float16_autocast_context
+from nemo.utils import avoid_float16_autocast_context

 __all__ = [
     'RelPositionMultiHeadAttention',
@@ -337,7 +337,6 @@ def extend_pe(self, length, device):
         # positive positions would be used for left positions and negative for right positions
         positions = torch.arange(length - 1, -length, -1, dtype=torch.float32, device=device).unsqueeze(1)
         self.create_pe(positions=positions)
-        self.center_pos = torch.tensor(self.pe.size(1) // 2 + 1, dtype=torch.int32, device=device)

     def forward(self, x, cache_len=0):
         """Compute positional encoding.
@@ -356,8 +355,9 @@
         # negative positions would be used for right and positive for left tokens
         # for input of length L, 2*L-1 positions are needed, positions from (L-1) to -(L-1)
         input_len = x.size(1) + cache_len
-        start_pos = self.center_pos - input_len
-        end_pos = self.center_pos + input_len - 1
+        center_pos = self.pe.size(1) // 2 + 1
+        start_pos = center_pos - input_len
+        end_pos = center_pos + input_len - 1
         pos_emb = self.pe[:, start_pos:end_pos]
         if self.dropout_emb:
             pos_emb = self.dropout_emb(pos_emb)
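The hunk above drops the cached `self.center_pos` device tensor and recomputes the center index from `self.pe.size(1)` on each forward pass, which keeps the index in sync with whatever size the positional-encoding buffer currently has. A minimal standalone sketch of the slicing arithmetic (the helper name `slice_rel_pos_emb` and the example sizes are illustrative, not part of the PR):

```python
# Minimal sketch (not the NeMo class itself): selecting the relative-position
# slice once center_pos is derived from the PE buffer on the fly.
import torch

def slice_rel_pos_emb(pe: torch.Tensor, seq_len: int, cache_len: int = 0) -> torch.Tensor:
    """Select the 2*L-1 relative-position vectors for an input of length L.

    Assumes `pe` holds embeddings for positions (max_len-1) .. -(max_len-1),
    i.e. shape (1, 2*max_len - 1, d_model), as built by extend_pe above.
    """
    input_len = seq_len + cache_len
    center_pos = pe.size(1) // 2 + 1      # index just past the position-0 vector
    start_pos = center_pos - input_len    # index of position (input_len - 1)
    end_pos = center_pos + input_len - 1  # exclusive end, past position -(input_len - 1)
    return pe[:, start_pos:end_pos]       # (1, 2*input_len - 1, d_model)

# Example: a buffer built for max_len=16 holds 31 position vectors; an 8-frame
# input with no cache takes the middle 15 of them.
pe = torch.zeros(1, 31, 256)
assert slice_rel_pos_emb(pe, seq_len=8).shape == (1, 15, 256)
```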
2 changes: 1 addition & 1 deletion nemo/collections/asr/parts/submodules/subsampling.py
@@ -19,7 +19,7 @@
 from torch.nn import LayerNorm

 from nemo.collections.asr.parts.submodules.causal_convs import CausalConv2D
-from nemo.collections.common.parts.training_utils import avoid_bfloat16_autocast_context
+from nemo.utils import avoid_bfloat16_autocast_context


 class StackingSubsampling(torch.nn.Module):
49 changes: 0 additions & 49 deletions nemo/collections/common/parts/training_utils.py

This file was deleted.
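With this module removed, `avoid_float16_autocast_context` and `avoid_bfloat16_autocast_context` are imported from `nemo.utils` instead, as the two hunks above show. A rough sketch of what such a helper generally does, assuming it simply re-enters a float16 autocast region in float32 (the function name here is hypothetical and the real code now living in `nemo.utils` may differ):

```python
# Hypothetical sketch of an autocast-avoiding helper; written only to
# illustrate the idea, not copied from the relocated module.
from contextlib import nullcontext

import torch

def avoid_float16_autocast_context_sketch():
    """Return a context manager that upgrades a float16 autocast region to float32."""
    if torch.is_autocast_enabled() and torch.get_autocast_gpu_dtype() == torch.float16:
        # Re-enter autocast in float32 so ops inside the block avoid fp16.
        return torch.cuda.amp.autocast(dtype=torch.float32)
    # Outside autocast (or already not fp16) nothing needs to change.
    return nullcontext()

# Typical usage around a numerically sensitive block:
# with avoid_float16_autocast_context_sketch():
#     att_scores = torch.matmul(q, k.transpose(-2, -1))
```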

4 changes: 2 additions & 2 deletions nemo/utils/export_utils.py
@@ -169,7 +169,7 @@ def run_ts_and_compare(ts_model, ts_input_list, ts_input_dict, output_example, c

         if torch.is_tensor(expected):
             tout = out.to('cpu')
-            logging.debug(f"Checking output {i}, shape: {expected.shape}:\n{expected}\n{tout}")
+            logging.debug(f"Checking output {i}, shape: {expected.shape}:\n")
             this_good = True
             try:
                 if not torch.allclose(tout, expected.cpu(), rtol=check_tolerance, atol=check_tolerance):
@@ -191,7 +191,7 @@ def run_ort_and_compare(sess, ort_input, output_example, check_tolerance=0.01):

         if torch.is_tensor(expected):
             tout = torch.from_numpy(out)
-            logging.debug(f"Checking output {i}, shape: {expected.shape}:\n{expected}\n{tout}")
+            logging.debug(f"Checking output {i}, shape: {expected.shape}:\n")
             this_good = True
             try:
                 if not torch.allclose(tout, expected.cpu(), rtol=check_tolerance, atol=100 * check_tolerance):
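Both hunks trim the debug message so only the output index and shape are logged, rather than dumping the full expected and actual tensors. A condensed sketch of the surrounding comparison pattern, assuming the helper name `outputs_close` (it is not part of the PR) and the same rtol/atol convention:

```python
# Condensed sketch of the per-output comparison above; `outputs_close` is an
# illustrative name, not a function added by this PR.
import logging

import torch

def outputs_close(tout: torch.Tensor, expected: torch.Tensor, check_tolerance: float = 0.01) -> bool:
    # Log only metadata; the full tensors are no longer dumped at debug level.
    logging.debug(f"Checking output, shape: {expected.shape}")
    try:
        if not torch.allclose(tout, expected.cpu(), rtol=check_tolerance, atol=check_tolerance):
            logging.info(f"Max absolute difference: {(tout - expected.cpu()).abs().max()}")
            return False
    except Exception:  # e.g. a shape mismatch counts as a failed check
        return False
    return True
```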