diff --git a/examples/nlp/token_classification/data/get_libritts_data.py b/examples/nlp/token_classification/data/get_libritts_data.py index 6b83b1fbaea2..86a5d01eb9dc 100644 --- a/examples/nlp/token_classification/data/get_libritts_data.py +++ b/examples/nlp/token_classification/data/get_libritts_data.py @@ -24,9 +24,9 @@ import subprocess import tarfile -from examples.nlp.token_classification.data.get_tatoeba_data import create_text_and_labels from tqdm import tqdm +from nemo.collections.nlp.data.token_classification.token_classification_utils import create_text_and_labels from nemo.utils import logging URL = { diff --git a/examples/nlp/token_classification/data/get_tatoeba_data.py b/examples/nlp/token_classification/data/get_tatoeba_data.py index 727848b550ab..6a4cd23b249d 100644 --- a/examples/nlp/token_classification/data/get_tatoeba_data.py +++ b/examples/nlp/token_classification/data/get_tatoeba_data.py @@ -17,9 +17,9 @@ import os import random import re -import string import subprocess +from nemo.collections.nlp.data.token_classification.token_classification_utils import create_text_and_labels from nemo.utils import logging URL = {'tatoeba': 'https://downloads.tatoeba.org/exports/sentences.csv'} @@ -120,68 +120,6 @@ def __split_into_train_dev(in_file: str, train_file: str, dev_file: str, percent dev_file.write(' '.join(lines[-dev_size:])) -def remove_punctuation(word: str): - """ - Removes all punctuation marks from a word except for ' - that is often a part of word: don't, it's, and so on - """ - all_punct_marks = string.punctuation.replace("'", '') - return re.sub('[' + all_punct_marks + ']', '', word) - - -def create_text_and_labels(output_dir: str, file_path: str, punct_marks: str = ',.?'): - """ - Create datasets for training and evaluation. - - Args: - output_dir: path to the output data directory - file_path: path to file name - punct_marks: supported punctuation marks - - The data will be split into 2 files: text.txt and labels.txt. \ - Each line of the text.txt file contains text sequences, where words\ - are separated with spaces. The labels.txt file contains \ - corresponding labels for each word in text.txt, the labels are \ - separated with spaces. Each line of the files should follow the \ - format: \ - [WORD] [SPACE] [WORD] [SPACE] [WORD] (for text.txt) and \ - [LABEL] [SPACE] [LABEL] [SPACE] [LABEL] (for labels.txt).' - """ - if not os.path.exists(file_path): - raise ValueError(f'{file_path} not found') - - os.makedirs(output_dir, exist_ok=True) - - base_name = os.path.basename(file_path) - labels_file = os.path.join(output_dir, 'labels_' + base_name) - text_file = os.path.join(output_dir, 'text_' + base_name) - - with open(file_path, 'r') as f: - with open(text_file, 'w') as text_f: - with open(labels_file, 'w') as labels_f: - for line in f: - line = line.split() - text = '' - labels = '' - for word in line: - label = word[-1] if word[-1] in punct_marks else 'O' - word = remove_punctuation(word) - if len(word) > 0: - if word[0].isupper(): - label += 'U' - else: - label += 'O' - - word = word.lower() - text += word + ' ' - labels += label + ' ' - - text_f.write(text.strip() + '\n') - labels_f.write(labels.strip() + '\n') - - print(f'{text_file} and {labels_file} created from {file_path}.') - - def __delete_file(file_to_del: str): """ Deletes the file diff --git a/nemo/collections/nlp/data/token_classification/token_classification_utils.py b/nemo/collections/nlp/data/token_classification/token_classification_utils.py index 828ef1180e0b..94acd69d3b11 100644 --- a/nemo/collections/nlp/data/token_classification/token_classification_utils.py +++ b/nemo/collections/nlp/data/token_classification/token_classification_utils.py @@ -14,6 +14,8 @@ import os import pickle +import re +import string from typing import Dict from nemo.collections.nlp.data.data_utils.data_preprocessing import ( @@ -23,7 +25,69 @@ ) from nemo.utils import logging -__all__ = ['get_label_ids'] +__all__ = ['get_label_ids', 'create_text_and_labels'] + + +def remove_punctuation(word: str): + """ + Removes all punctuation marks from a word except for ' + that is often a part of word: don't, it's, and so on + """ + all_punct_marks = string.punctuation.replace("'", '') + return re.sub('[' + all_punct_marks + ']', '', word) + + +def create_text_and_labels(output_dir: str, file_path: str, punct_marks: str = ',.?'): + """ + Create datasets for training and evaluation. + + Args: + output_dir: path to the output data directory + file_path: path to file name + punct_marks: supported punctuation marks + + The data will be split into 2 files: text.txt and labels.txt. \ + Each line of the text.txt file contains text sequences, where words\ + are separated with spaces. The labels.txt file contains \ + corresponding labels for each word in text.txt, the labels are \ + separated with spaces. Each line of the files should follow the \ + format: \ + [WORD] [SPACE] [WORD] [SPACE] [WORD] (for text.txt) and \ + [LABEL] [SPACE] [LABEL] [SPACE] [LABEL] (for labels.txt).' + """ + if not os.path.exists(file_path): + raise ValueError(f'{file_path} not found') + + os.makedirs(output_dir, exist_ok=True) + + base_name = os.path.basename(file_path) + labels_file = os.path.join(output_dir, 'labels_' + base_name) + text_file = os.path.join(output_dir, 'text_' + base_name) + + with open(file_path, 'r') as f: + with open(text_file, 'w') as text_f: + with open(labels_file, 'w') as labels_f: + for line in f: + line = line.split() + text = '' + labels = '' + for word in line: + label = word[-1] if word[-1] in punct_marks else 'O' + word = remove_punctuation(word) + if len(word) > 0: + if word[0].isupper(): + label += 'U' + else: + label += 'O' + + word = word.lower() + text += word + ' ' + labels += label + ' ' + + text_f.write(text.strip() + '\n') + labels_f.write(labels.strip() + '\n') + + print(f'{text_file} and {labels_file} created from {file_path}.') def get_label_ids( diff --git a/tutorials/nlp/Punctuation_and_Capitalization_Lexical_Audio.ipynb b/tutorials/nlp/Punctuation_and_Capitalization_Lexical_Audio.ipynb index fb544c24e0a2..4c20cae8af19 100644 --- a/tutorials/nlp/Punctuation_and_Capitalization_Lexical_Audio.ipynb +++ b/tutorials/nlp/Punctuation_and_Capitalization_Lexical_Audio.ipynb @@ -279,22 +279,23 @@ ] }, { - "cell_type": "markdown", + "cell_type": "code", + "execution_count": null, + "metadata": { + "pycharm": { + "name": "#%% md\n" + } + }, + "outputs": [], "source": [ - "## download get_libritts_data.py script to download and preprocess the LibriTTS data\n", + "# download get_libritts_data.py script to download and preprocess the LibriTTS data\n", "os.makedirs(WORK_DIR, exist_ok=True)\n", "if not os.path.exists(WORK_DIR + '/get_libritts_data.py'):\n", " print('Downloading get_libritts_data.py...')\n", " wget.download(f'https://raw.githubusercontent.com/NVIDIA/NeMo/{BRANCH}/examples/nlp/token_classification/data/get_libritts_data.py', WORK_DIR)\n", "else:\n", " print ('get_libritts_data.py already exists')" - ], - "metadata": { - "collapsed": false, - "pycharm": { - "name": "#%% md\n" - } - } + ] }, { "cell_type": "code", @@ -996,9 +997,9 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.8.0" + "version": "3.9.13" } }, "nbformat": 4, "nbformat_minor": 1 -} \ No newline at end of file +}