Skip to content
7 changes: 5 additions & 2 deletions Jenkinsfile
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,6 @@ pipeline {
}



stage('L0: Create EN TN/ITN Grammars') {
when {
anyOf {
Expand All @@ -67,7 +66,11 @@ pipeline {
}
failFast true
parallel {

stage('L0: Test utils') {
steps {
sh 'CUDA_VISIBLE_DEVICES="" pytest tests/nemo_text_processing/audio_based_utils/ --cpu'
}
}
stage('L0: En TN grammars') {
steps {
sh 'CUDA_VISIBLE_DEVICES="" python nemo_text_processing/text_normalization/normalize.py --text="1" --cache_dir ${EN_TN_CACHE}'
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -141,7 +141,7 @@ def normalize(
Returns:
normalized text options (usually there are multiple ways of normalizing a given semiotic class)
"""
if pred_text is None or self.tagger is None:
if pred_text is None or pred_text == "" or self.tagger is None:
return self.normalize_non_deterministic(
text=text, n_tagged=n_tagged, punct_post_process=punct_post_process, verbose=verbose
)
Expand All @@ -156,6 +156,7 @@ def normalize(
semiotic_spans, pred_text_spans, norm_spans, text_with_span_tags_list, masked_idx_list = get_alignment(
text, det_norm, pred_text, verbose=False
)

sem_tag_idx = 0
for cur_semiotic_span, cur_pred_text, cur_deter_norm in zip(semiotic_spans, pred_text_spans, norm_spans):
if len(cur_semiotic_span) == 0:
Expand Down
84 changes: 48 additions & 36 deletions nemo_text_processing/text_normalization/utils_audio_based.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,8 +24,7 @@

def _get_alignment(a: str, b: str) -> Dict:
"""

Construscts alignment between a and b
Constructs alignment between a and b

Returns:
a dictionary, where keys are a's word indices and each value is a Tuple that contains the span from b, and whether it
Expand Down Expand Up @@ -62,7 +61,7 @@ def _get_alignment(a: str, b: str) -> Dict:

def adjust_boundaries(norm_raw_diffs: Dict, norm_pred_diffs: Dict, raw: str, norm: str, pred_text: str, verbose=False):
"""
Adjust alignement boundaries by taking norm--raw texts and norm--pred_text alignements, and creating raw-pred_text
Adjust alignment boundaries by taking norm--raw texts and norm--pred_text alignments, and creating raw-pred_text
alignment.

norm_raw_diffs: output of _get_alignment(norm, raw)
Expand Down Expand Up @@ -92,10 +91,12 @@ def adjust_boundaries(norm_raw_diffs: Dict, norm_pred_diffs: Dict, raw: str, nor
raw_text_mask_idx: [1, 4]
"""

adjusted = []
raw_pred_spans = []
word_id = 0
while word_id < len(norm.split()):
norm_raw, norm_pred = norm_raw_diffs[word_id], norm_pred_diffs[word_id]
# if there is a mismatch in norm_raw and norm_pred, expand the boundaries of the shortest mismatch to align with the longest one
# e.g., norm_raw = (1, 2, 'match') norm_pred = (1, 5, 'non-match') => expand norm_raw until the next matching sequence or the end of string to align with norm_pred
if (norm_raw[2] == MATCH and norm_pred[2] == NONMATCH) or (norm_raw[2] == NONMATCH and norm_pred[2] == MATCH):
mismatched_id = word_id
non_match_raw_start = norm_raw[0]
Expand All @@ -114,20 +115,21 @@ def adjust_boundaries(norm_raw_diffs: Dict, norm_pred_diffs: Dict, raw: str, nor
if not done:
non_match_raw_end = len(raw.split())
non_match_pred_end = len(pred_text.split())
adjusted.append(
raw_pred_spans.append(
(
mismatched_id,
(non_match_raw_start, non_match_raw_end, NONMATCH),
(non_match_pred_start, non_match_pred_end, NONMATCH),
)
)
else:
adjusted.append((word_id, norm_raw, norm_pred))
raw_pred_spans.append((word_id, norm_raw, norm_pred))
word_id += 1

adjusted2 = []
# aggregate neighboring spans with the same status
spans_merged_neighbors = []
last_status = None
for idx, item in enumerate(adjusted):
for idx, item in enumerate(raw_pred_spans):
if last_status is None:
last_status = item[1][2]
raw_start = item[1][0]
Expand All @@ -139,7 +141,7 @@ def adjust_boundaries(norm_raw_diffs: Dict, norm_pred_diffs: Dict, raw: str, nor
raw_end = item[1][1]
pred_text_end = item[2][1]
else:
adjusted2.append(
spans_merged_neighbors.append(
[[norm_span_start, item[0]], [raw_start, raw_end], [pred_text_start, pred_text_end], last_status]
)
last_status = item[1][2]
Expand All @@ -152,13 +154,13 @@ def adjust_boundaries(norm_raw_diffs: Dict, norm_pred_diffs: Dict, raw: str, nor
if last_status == item[1][2]:
raw_end = item[1][1]
pred_text_end = item[2][1]
adjusted2.append(
spans_merged_neighbors.append(
[[norm_span_start, item[0]], [raw_start, raw_end], [pred_text_start, pred_text_end], last_status]
)
else:
adjusted2.append(
spans_merged_neighbors.append(
[
[adjusted[idx - 1][0], len(norm.split())],
[raw_pred_spans[idx - 1][0], len(norm.split())],
[item[1][0], len(raw.split())],
[item[2][0], len(pred_text.split())],
item[1][2],
Expand All @@ -171,10 +173,10 @@ def adjust_boundaries(norm_raw_diffs: Dict, norm_pred_diffs: Dict, raw: str, nor

# increase boundaries between raw and pred_text if some spans contain empty pred_text
extended_spans = []
adjusted3 = []
raw_norm_spans_corrected_for_pred_text = []
idx = 0
while idx < len(adjusted2):
item = adjusted2[idx]
while idx < len(spans_merged_neighbors):
item = spans_merged_neighbors[idx]

cur_semiotic = " ".join(raw_list[item[1][0] : item[1][1]])
cur_pred_text = " ".join(pred_text_list[item[2][0] : item[2][1]])
Expand All @@ -186,8 +188,8 @@ def adjust_boundaries(norm_raw_diffs: Dict, norm_pred_diffs: Dict, raw: str, nor
# if cur_pred_text is an empty string
if item[2][0] == item[2][1]:
# for the last item
if idx == len(adjusted2) - 1 and len(adjusted3) > 0:
last_item = adjusted3[-1]
if idx == len(spans_merged_neighbors) - 1 and len(raw_norm_spans_corrected_for_pred_text) > 0:
last_item = raw_norm_spans_corrected_for_pred_text[-1]
last_item[0][1] = item[0][1]
last_item[1][1] = item[1][1]
last_item[2][1] = item[2][1]
Expand All @@ -196,29 +198,31 @@ def adjust_boundaries(norm_raw_diffs: Dict, norm_pred_diffs: Dict, raw: str, nor
raw_start, raw_end = item[0]
norm_start, norm_end = item[1]
pred_start, pred_end = item[2]
while idx < len(adjusted2) - 1 and not ((pred_end - pred_start) > 2 and adjusted2[idx][-1] == MATCH):
while idx < len(spans_merged_neighbors) - 1 and not (
(pred_end - pred_start) > 2 and spans_merged_neighbors[idx][-1] == MATCH
):
idx += 1
raw_end = adjusted2[idx][0][1]
norm_end = adjusted2[idx][1][1]
pred_end = adjusted2[idx][2][1]
raw_end = spans_merged_neighbors[idx][0][1]
norm_end = spans_merged_neighbors[idx][1][1]
pred_end = spans_merged_neighbors[idx][2][1]
cur_item = [[raw_start, raw_end], [norm_start, norm_end], [pred_start, pred_end], NONMATCH]
adjusted3.append(cur_item)
extended_spans.append(len(adjusted3) - 1)
raw_norm_spans_corrected_for_pred_text.append(cur_item)
extended_spans.append(len(raw_norm_spans_corrected_for_pred_text) - 1)
idx += 1
else:
adjusted3.append(item)
raw_norm_spans_corrected_for_pred_text.append(item)
idx += 1

semiotic_spans = []
norm_spans = []
pred_texts = []
raw_text_masked = ""
for idx, item in enumerate(adjusted3):
for idx, item in enumerate(raw_norm_spans_corrected_for_pred_text):
cur_semiotic = " ".join(raw_list[item[1][0] : item[1][1]])
cur_pred_text = " ".join(pred_text_list[item[2][0] : item[2][1]])
cur_norm_span = " ".join(norm_list[item[0][0] : item[0][1]])

if idx == len(adjusted3) - 1:
if idx == len(raw_norm_spans_corrected_for_pred_text) - 1:
cur_norm_span = " ".join(norm_list[item[0][0] : len(norm_list)])
if (item[-1] == NONMATCH and cur_semiotic != cur_norm_span) or (idx in extended_spans):
raw_text_masked += " " + SEMIOTIC_TAG
Expand All @@ -233,24 +237,31 @@ def adjust_boundaries(norm_raw_diffs: Dict, norm_pred_diffs: Dict, raw: str, nor

if verbose:
print("+" * 50)
print("adjusted:")
for item in adjusted2:
print("raw_pred_spans:")
for item in spans_merged_neighbors:
print(f"{raw.split()[item[1][0]: item[1][1]]} -- {pred_text.split()[item[2][0]: item[2][1]]}")

print("+" * 50)
print("adjusted2:")
for item in adjusted2:
print("spans_merged_neighbors:")
for item in spans_merged_neighbors:
print(f"{raw.split()[item[1][0]: item[1][1]]} -- {pred_text.split()[item[2][0]: item[2][1]]}")
print("+" * 50)
print("adjusted3:")
for item in adjusted3:
print("raw_norm_spans_corrected_for_pred_text:")
for item in raw_norm_spans_corrected_for_pred_text:
print(f"{raw.split()[item[1][0]: item[1][1]]} -- {pred_text.split()[item[2][0]: item[2][1]]}")
print("+" * 50)

return semiotic_spans, pred_texts, norm_spans, raw_text_masked_list, raw_text_mask_idx


def get_alignment(raw, norm, pred_text, verbose: bool = False):
def get_alignment(raw: str, norm: str, pred_text: str, verbose: bool = False):
"""
Aligns raw text with deterministically normalized text and ASR output, finds semiotic spans
"""
for value in [raw, norm, pred_text]:
if value is None or value == "":
return [], [], [], [], []

norm_pred_diffs = _get_alignment(norm, pred_text)
norm_raw_diffs = _get_alignment(norm, raw)

Expand All @@ -271,8 +282,9 @@ def get_alignment(raw, norm, pred_text, verbose: bool = False):


if __name__ == "__main__":
raw = 'This is #4 ranking on G.S.K.T.'
pred_text = 'this iss for ranking on g k p'
raw = 'This is a #4 ranking on G.S.K.T.'
pred_text = 'this iss p k for ranking on g k p'
norm = 'This is nubmer four ranking on GSKT'

get_alignment(raw, norm, pred_text, True)
output = get_alignment(raw, norm, pred_text, True)
print(output)
13 changes: 13 additions & 0 deletions tests/nemo_text_processing/audio_based_utils/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
# Copyright (c) 2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import pytest
from nemo_text_processing.text_normalization.utils_audio_based import get_alignment


class TestAudioBasedTNUtils:
    """Unit tests for the audio-based text normalization alignment utilities."""

    @pytest.mark.run_only_on('CPU')
    @pytest.mark.unit
    def test_default(self):
        """Compare the output of ``get_alignment`` for one fixed example against a reference."""
        written_text = 'This is #4 ranking on G.S.K.T.'
        asr_text = 'this iss for ranking on g k p'
        # The misspelling 'nubmer' is mirrored in the expected norm span below.
        normalized_text = 'This is nubmer four ranking on GSKT'

        # Expected pieces, listed in the order they are compared against the output.
        semiotic_spans = ['is #4', 'G.S.K.T.']
        pred_text_spans = ['iss for', 'g k p']
        norm_spans = ['is nubmer four', 'GSKT']
        masked_tokens = ['This', '[SEMIOTIC_SPAN]', 'ranking', 'on', '[SEMIOTIC_SPAN]']
        mask_indices = [1, 4]
        expected = (semiotic_spans, pred_text_spans, norm_spans, masked_tokens, mask_indices)

        # The final positional argument switches on verbose output.
        actual = get_alignment(written_text, normalized_text, asr_text, True)
        assert actual == expected