Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -24,9 +24,9 @@
import subprocess
import tarfile

from examples.nlp.token_classification.data.get_tatoeba_data import create_text_and_labels
from tqdm import tqdm

from nemo.collections.nlp.data.token_classification.token_classification_utils import create_text_and_labels
from nemo.utils import logging

URL = {
Expand Down
64 changes: 1 addition & 63 deletions examples/nlp/token_classification/data/get_tatoeba_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,9 +17,9 @@
import os
import random
import re
import string
import subprocess

from nemo.collections.nlp.data.token_classification.token_classification_utils import create_text_and_labels
from nemo.utils import logging

URL = {'tatoeba': 'https://downloads.tatoeba.org/exports/sentences.csv'}
Expand Down Expand Up @@ -120,68 +120,6 @@ def __split_into_train_dev(in_file: str, train_file: str, dev_file: str, percent
dev_file.write(' '.join(lines[-dev_size:]))


def remove_punctuation(word: str) -> str:
    """
    Remove all punctuation marks from ``word`` except for the apostrophe,
    which is often part of a word: don't, it's, and so on.

    Args:
        word: token to strip

    Returns:
        ``word`` with every character of ``string.punctuation`` (minus ``'``) removed
    """
    # str.translate does a single C-level pass and, unlike building a regex
    # character class from string.punctuation, cannot misread punctuation
    # characters (], \, ^, -) as regex metacharacters.
    all_punct_marks = string.punctuation.replace("'", '')
    return word.translate(str.maketrans('', '', all_punct_marks))


def create_text_and_labels(output_dir: str, file_path: str, punct_marks: str = ',.?'):
    """
    Create text/label files for punctuation and capitalization training.

    Args:
        output_dir: path to the output data directory (created if missing)
        file_path: path to the raw input file, one sentence per line
        punct_marks: punctuation marks to predict; all other marks are stripped

    Raises:
        ValueError: if ``file_path`` does not exist

    The output is written to two files named after the input file's base name:
    ``text_<base>`` and ``labels_<base>``. Each line of the text file contains
    the lower-cased words of the corresponding input line separated by spaces;
    the labels file contains one two-character label per word, also separated
    by spaces:
      - first char:  the word's trailing punctuation mark, or 'O' if none
      - second char: 'U' if the word was capitalized, else 'O'
    """
    if not os.path.exists(file_path):
        raise ValueError(f'{file_path} not found')

    os.makedirs(output_dir, exist_ok=True)

    base_name = os.path.basename(file_path)
    labels_file = os.path.join(output_dir, 'labels_' + base_name)
    text_file = os.path.join(output_dir, 'text_' + base_name)

    # Translation table that strips every punctuation mark except the
    # apostrophe (often part of a word: don't, it's, ...). Built once,
    # outside the loop.
    strip_punct = str.maketrans('', '', string.punctuation.replace("'", ''))

    with open(file_path, 'r') as f, open(text_file, 'w') as text_f, open(labels_file, 'w') as labels_f:
        for line in f:
            words = []
            labels = []
            for word in line.split():
                # First label char: trailing punctuation mark, or 'O'.
                label = word[-1] if word[-1] in punct_marks else 'O'
                word = word.translate(strip_punct)
                if word:  # tokens that were pure punctuation are dropped
                    # Second label char: capitalization flag.
                    label += 'U' if word[0].isupper() else 'O'
                    words.append(word.lower())
                    labels.append(label)

            # ' '.join avoids the quadratic += string building of the
            # original and needs no trailing-space strip.
            text_f.write(' '.join(words) + '\n')
            labels_f.write(' '.join(labels) + '\n')

    print(f'{text_file} and {labels_file} created from {file_path}.')


def __delete_file(file_to_del: str):
"""
Deletes the file
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,8 @@

import os
import pickle
import re
import string
from typing import Dict

from nemo.collections.nlp.data.data_utils.data_preprocessing import (
Expand All @@ -23,7 +25,69 @@
)
from nemo.utils import logging

__all__ = ['get_label_ids']
__all__ = ['get_label_ids', 'create_text_and_labels']


def remove_punctuation(word: str) -> str:
    """
    Remove all punctuation marks from ``word`` except for the apostrophe,
    which is often part of a word: don't, it's, and so on.

    Args:
        word: token to strip

    Returns:
        ``word`` with every character of ``string.punctuation`` (minus ``'``) removed
    """
    # str.translate does a single C-level pass and, unlike building a regex
    # character class from string.punctuation, cannot misread punctuation
    # characters (], \, ^, -) as regex metacharacters.
    all_punct_marks = string.punctuation.replace("'", '')
    return word.translate(str.maketrans('', '', all_punct_marks))


def create_text_and_labels(output_dir: str, file_path: str, punct_marks: str = ',.?'):
    """
    Create text/label files for punctuation and capitalization training.

    Args:
        output_dir: path to the output data directory (created if missing)
        file_path: path to the raw input file, one sentence per line
        punct_marks: punctuation marks to predict; all other marks are stripped

    Raises:
        ValueError: if ``file_path`` does not exist

    The output is written to two files named after the input file's base name:
    ``text_<base>`` and ``labels_<base>``. Each line of the text file contains
    the lower-cased words of the corresponding input line separated by spaces;
    the labels file contains one two-character label per word, also separated
    by spaces:
      - first char:  the word's trailing punctuation mark, or 'O' if none
      - second char: 'U' if the word was capitalized, else 'O'
    """
    if not os.path.exists(file_path):
        raise ValueError(f'{file_path} not found')

    os.makedirs(output_dir, exist_ok=True)

    base_name = os.path.basename(file_path)
    labels_file = os.path.join(output_dir, 'labels_' + base_name)
    text_file = os.path.join(output_dir, 'text_' + base_name)

    # Translation table that strips every punctuation mark except the
    # apostrophe (often part of a word: don't, it's, ...). Built once,
    # outside the loop.
    strip_punct = str.maketrans('', '', string.punctuation.replace("'", ''))

    with open(file_path, 'r') as f, open(text_file, 'w') as text_f, open(labels_file, 'w') as labels_f:
        for line in f:
            words = []
            labels = []
            for word in line.split():
                # First label char: trailing punctuation mark, or 'O'.
                label = word[-1] if word[-1] in punct_marks else 'O'
                word = word.translate(strip_punct)
                if word:  # tokens that were pure punctuation are dropped
                    # Second label char: capitalization flag.
                    label += 'U' if word[0].isupper() else 'O'
                    words.append(word.lower())
                    labels.append(label)

            # ' '.join avoids the quadratic += string building of the
            # original and needs no trailing-space strip.
            text_f.write(' '.join(words) + '\n')
            labels_f.write(' '.join(labels) + '\n')

    print(f'{text_file} and {labels_file} created from {file_path}.')


def get_label_ids(
Expand Down
23 changes: 12 additions & 11 deletions tutorials/nlp/Punctuation_and_Capitalization_Lexical_Audio.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -279,22 +279,23 @@
]
},
{
"cell_type": "markdown",
"cell_type": "code",
"execution_count": null,
"metadata": {
"pycharm": {
"name": "#%% md\n"
}
},
"outputs": [],
"source": [
"## download get_libritts_data.py script to download and preprocess the LibriTTS data\n",
"# download get_libritts_data.py script to download and preprocess the LibriTTS data\n",
"os.makedirs(WORK_DIR, exist_ok=True)\n",
"if not os.path.exists(WORK_DIR + '/get_libritts_data.py'):\n",
" print('Downloading get_libritts_data.py...')\n",
" wget.download(f'https://raw.githubusercontent.com/NVIDIA/NeMo/{BRANCH}/examples/nlp/token_classification/data/get_libritts_data.py', WORK_DIR)\n",
"else:\n",
" print ('get_libritts_data.py already exists')"
],
"metadata": {
"collapsed": false,
"pycharm": {
"name": "#%% md\n"
}
}
]
},
{
"cell_type": "code",
Expand Down Expand Up @@ -996,9 +997,9 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.8.0"
"version": "3.9.13"
}
},
"nbformat": 4,
"nbformat_minor": 1
}
}