diff --git a/nemo/collections/nlp/data/token_classification/punctuation_capitalization_tarred_dataset.py b/nemo/collections/nlp/data/token_classification/punctuation_capitalization_tarred_dataset.py
index 363c1a0a1e9d..2bfcb7969b6e 100644
--- a/nemo/collections/nlp/data/token_classification/punctuation_capitalization_tarred_dataset.py
+++ b/nemo/collections/nlp/data/token_classification/punctuation_capitalization_tarred_dataset.py
@@ -19,6 +19,7 @@
 import pickle
 import re
 import shutil
+import tempfile
 from collections import deque
 from pathlib import Path
 from typing import Any, Callable, Dict, Iterator, List, Optional, Set, Tuple, Type, Union
@@ -160,36 +161,44 @@ def process_fragment(
         special_tokens=special_tokens,
         use_fast=use_fast_tokenizer,
     )
-    tmp_text = output_dir / f'tmp_text_{fragment_idx}.txt'
-    tmp_labels = output_dir / f'tmp_labels_{fragment_idx}.txt'
-    with text_file.open() as tf, labels_file.open() as lf, tmp_text.open('w') as otf, tmp_labels.open('w') as olf:
-        tf.seek(text_start_pos)
-        lf.seek(label_start_pos)
-        for _ in range(lines_per_dataset_fragment):
-            text_line = tf.readline()
-            if not text_line:
-                break
-            otf.write(text_line)
-            olf.write(lf.readline())
-    dataset = BertPunctuationCapitalizationDataset(
-        tmp_text,
-        tmp_labels,
-        max_seq_length,
-        tokenizer,
-        tokens_in_batch=tokens_in_batch,
-        pad_label=pad_label,
-        punct_label_ids=punct_label_ids,
-        capit_label_ids=capit_label_ids,
-        n_jobs=0,
-        use_cache=False,
-        add_masks_and_segment_ids_to_batch=False,
-        verbose=False,
-        tokenization_progress_queue=tokenization_progress_queue,
-        batch_mark_up_progress_queue=batch_mark_up_progress_queue,
-        batch_building_progress_queue=batch_building_progress_queue,
-    )
-    tmp_text.unlink()
-    tmp_labels.unlink()
+    tmp_text: Optional[str] = None
+    tmp_labels: Optional[str] = None
+    try:
+        otfd, tmp_text = tempfile.mkstemp(suffix='.txt', prefix=f'text_{fragment_idx}_', dir=output_dir, text=True)
+        olfd, tmp_labels = tempfile.mkstemp(suffix='.txt', prefix=f'labels_{fragment_idx}_', dir=output_dir, text=True)
+        with text_file.open() as tf, labels_file.open() as lf, os.fdopen(otfd, 'w') as otf, os.fdopen(
+            olfd, 'w'
+        ) as olf:
+            tf.seek(text_start_pos)
+            lf.seek(label_start_pos)
+            for _ in range(lines_per_dataset_fragment):
+                text_line = tf.readline()
+                if not text_line:
+                    break
+                otf.write(text_line)
+                olf.write(lf.readline())
+        dataset = BertPunctuationCapitalizationDataset(
+            tmp_text,
+            tmp_labels,
+            max_seq_length,
+            tokenizer,
+            tokens_in_batch=tokens_in_batch,
+            pad_label=pad_label,
+            punct_label_ids=punct_label_ids,
+            capit_label_ids=capit_label_ids,
+            n_jobs=0,
+            use_cache=False,
+            add_masks_and_segment_ids_to_batch=False,
+            verbose=False,
+            tokenization_progress_queue=tokenization_progress_queue,
+            batch_mark_up_progress_queue=batch_mark_up_progress_queue,
+            batch_building_progress_queue=batch_building_progress_queue,
+        )
+    finally:
+        if tmp_text is not None and os.path.exists(tmp_text):
+            os.remove(tmp_text)
+        if tmp_labels is not None and os.path.exists(tmp_labels):
+            os.remove(tmp_labels)
     dataset.features_pkl.unlink()
     tar_ctr = 0
     current_file_name = output_dir / TAR_FRAGMENT_TMPL_IN_PROGRESS.format(fragment_idx=fragment_idx, file_idx=tar_ctr)
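The diff swaps predictable temporary file names for tempfile.mkstemp() handles and guarantees cleanup in a finally block. Below is a minimal standalone sketch of that pattern, separate from the diff itself; the names write_fragment, fragment_idx, and lines are illustrative and not part of the NeMo API.

    import os
    import tempfile
    from pathlib import Path
    from typing import List, Optional


    def write_fragment(output_dir: Path, fragment_idx: int, lines: List[str]) -> None:
        tmp_path: Optional[str] = None
        try:
            # mkstemp returns an already-open OS-level descriptor plus a unique
            # path, so concurrent workers cannot race on a predictable file name.
            fd, tmp_path = tempfile.mkstemp(
                suffix='.txt', prefix=f'text_{fragment_idx}_', dir=output_dir, text=True
            )
            # os.fdopen adopts the descriptor; the `with` block closes it.
            with os.fdopen(fd, 'w') as f:
                f.writelines(lines)
            # ... consume tmp_path here (e.g. build a dataset from the file) ...
        finally:
            # Remove the temporary file even if writing or consumption raised.
            if tmp_path is not None and os.path.exists(tmp_path):
                os.remove(tmp_path)


    if __name__ == '__main__':
        write_fragment(Path('.'), 0, ['hello\n', 'world\n'])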