From 3d9716658b1a5ec74c83cfd5e4ae73bd583afcd2 Mon Sep 17 00:00:00 2001
From: jlarson4
Date: Tue, 28 Apr 2026 22:41:49 -0500
Subject: [PATCH 1/2] Fix mid-token splits in tokenize_and_concatenate

---
 tests/unit/test_utils.py                     | 55 ++++++++++++++++++++
 transformer_lens/utilities/tokenize_utils.py | 28 ++++++----
 2 files changed, 72 insertions(+), 11 deletions(-)

diff --git a/tests/unit/test_utils.py b/tests/unit/test_utils.py
index 52934239f..d703dd997 100644
--- a/tests/unit/test_utils.py
+++ b/tests/unit/test_utils.py
@@ -483,6 +483,61 @@ def test_no_split_tokens_across_chunks(self):
             f"This indicates a word was split across chunk boundaries."
         )
 
+    def test_no_split_tokens_in_no_whitespace_text(self):
+        """No-whitespace multi-doc input: the prior whitespace-lookahead fix
+        fell back to character-offset cuts here. streaming=True keeps all docs
+        in one tokenize_function call, so EOS markers actually exist to split on.
+        """
+        from datasets import Dataset
+        from transformers import AutoTokenizer
+
+        tokenizer = AutoTokenizer.from_pretrained("gpt2")
+
+        docs = ["a" * 200 + "MilitaryVehicleEngine" * 100] * 10
+        dataset = Dataset.from_dict({"text": docs})
+
+        result = utils.tokenize_and_concatenate(
+            dataset,
+            tokenizer,
+            streaming=True,
+            max_length=128,
+            add_bos_token=False,
+        )
+
+        full_text = tokenizer.eos_token.join(docs)
+        clean_tokens = tokenizer(full_text, return_tensors="np")["input_ids"].flatten()
+        clean_pairs = set(zip(clean_tokens[:-1], clean_tokens[1:]))
+
+        output_tokens = np.concatenate([np.array(row["tokens"]) for row in result])
+        for i in range(len(output_tokens) - 1):
+            pair = (int(output_tokens[i]), int(output_tokens[i + 1]))
+            assert pair in clean_pairs, (
+                f"Token pair {pair} appears in tokenize_and_concatenate output "
+                f"but never occurs in the natural tokenization, so the chunker "
+                f"must have cut a token in half."
+            )
+
+    def test_single_document_batch_does_not_crash(self):
+        """A single-doc batch has no EOS to split on; the fallback to one chunk should be correct."""
+        from datasets import Dataset
+        from transformers import AutoTokenizer
+
+        tokenizer = AutoTokenizer.from_pretrained("gpt2")
+        dataset = Dataset.from_dict({"text": ["abcdefghij" * 200]})
+
+        result = utils.tokenize_and_concatenate(
+            dataset,
+            tokenizer,
+            streaming=True,
+            max_length=64,
+            add_bos_token=False,
+        )
+
+        clean = tokenizer("abcdefghij" * 200, return_tensors="np")["input_ids"].flatten()
+        output = np.concatenate([np.array(row["tokens"]) for row in result])
+        n = len(output)
+        assert (output == clean[:n]).all()
+
 
 def test_tokenize_and_concatenate_no_spurious_sequence_length_warning():
     """Test that tokenize_and_concatenate does not emit the HF 'sequence length longer than maximum' warning."""
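Note on the test strategy above: a cut that lands mid-token produces an adjacent token pair in the output that a clean tokenization of the same text never produces, so checking every output bigram against the clean bigram set catches any mid-token split. A minimal standalone sketch of that invariant, assuming the gpt2 tokenizer (the cut offset below is illustrative, not part of the patch):

    # Minimal sketch of the bigram-containment invariant, assuming the
    # gpt2 tokenizer. A cut at an arbitrary character offset typically
    # yields a token pair across the cut that never occurs in a clean
    # tokenization of the whole string.
    import numpy as np
    from transformers import AutoTokenizer

    tokenizer = AutoTokenizer.from_pretrained("gpt2")
    text = "MilitaryVehicleEngine" * 50

    clean = tokenizer(text, return_tensors="np")["input_ids"].flatten()
    clean_pairs = set(zip(clean[:-1].tolist(), clean[1:].tolist()))

    # Illustrative bad chunker: cut mid-word at character 101.
    halves = [text[:101], text[101:]]
    chunked = np.concatenate(
        [tokenizer(h, return_tensors="np")["input_ids"].flatten() for h in halves]
    )
    bad_pairs = [
        pair
        for pair in zip(chunked[:-1].tolist(), chunked[1:].tolist())
        if pair not in clean_pairs
    ]
    print(bad_pairs)  # typically non-empty: the pair spanning the cut is unnatural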
diff --git a/transformer_lens/utilities/tokenize_utils.py b/transformer_lens/utilities/tokenize_utils.py
index 16a3ec4f7..1a0ab1319 100644
--- a/transformer_lens/utilities/tokenize_utils.py
+++ b/transformer_lens/utilities/tokenize_utils.py
@@ -67,24 +67,30 @@ def tokenize_function(examples: Any) -> dict[str, np.ndarray]:
         text = examples[column_name]
         # Concatenate it all into an enormous string, separated by eos_tokens
         assert tokenizer.eos_token is not None, "Tokenizer must have an EOS token."
-        full_text = tokenizer.eos_token.join(text)
+        eos = tokenizer.eos_token
+        full_text = eos.join(text)
 
         # Handle the case when full_text is empty
         if not full_text.strip():
             return {"tokens": np.array([], dtype=np.int64)}
 
-        # Split at whitespace boundaries to avoid mid-word tokens (#1133)
+        # Split at EOS boundaries: BPE merges never cross an EOS marker, so
+        # per-chunk tokenization concatenates to the same tokens as the joined
+        # string (#1133). No-EOS inputs fall back to one chunk: slower but correct.
         num_chunks = 20
-        chunk_length = (len(full_text) - 1) // num_chunks + 1
-        chunks = []
+        target_chunk_size = (len(full_text) - 1) // num_chunks + 1
+        chunks: list[str] = []
         start = 0
-        lookahead = chunk_length // 10
-        for i in range(num_chunks):
-            end = min(start + chunk_length, len(full_text))
-            # Advance to whitespace; bounded lookahead for pathological inputs
-            boundary = min(end + lookahead, len(full_text))
-            while end < boundary and not full_text[end].isspace():
-                end += 1
+        while start < len(full_text):
+            target_end = start + target_chunk_size
+            if target_end >= len(full_text):
+                chunks.append(full_text[start:])
+                break
+            eos_pos = full_text.find(eos, target_end)
+            if eos_pos == -1:
+                chunks.append(full_text[start:])
+                break
+            end = eos_pos + len(eos)
             chunks.append(full_text[start:end])
             start = end
         # Tokenize in parallel with NumPy (HF map rejects tensors)
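The comment in the hunk above leans on one invariant: cuts that land on EOS boundaries leave the token stream unchanged, because the tokenizer isolates special tokens and BPE merges never span them. A minimal sanity-check sketch, assuming the gpt2 tokenizer (the docs are illustrative):

    # Sanity check, assuming the gpt2 tokenizer: cutting the joined string
    # just after each EOS marker and tokenizing the chunks separately
    # reproduces the token ids of tokenizing the whole string at once.
    import numpy as np
    from transformers import AutoTokenizer

    tokenizer = AutoTokenizer.from_pretrained("gpt2")
    eos = tokenizer.eos_token

    docs = ["a short document", "NoWhitespaceAnywhereHere" * 30, "one more"]
    full_text = eos.join(docs)
    whole = tokenizer(full_text, return_tensors="np")["input_ids"].flatten()

    # Cut on EOS boundaries, as the new while-loop in the patch does.
    chunks = [doc + eos for doc in docs[:-1]] + [docs[-1]]
    parts = np.concatenate(
        [tokenizer(c, return_tensors="np")["input_ids"].flatten() for c in chunks]
    )
    assert np.array_equal(whole, parts)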
From 0b2662529af5812220ac2cdd3a2f0b84a8816cc8 Mon Sep 17 00:00:00 2001
From: jlarson4
Date: Tue, 28 Apr 2026 23:23:48 -0500
Subject: [PATCH 2/2] Rework the fix to tokenize per-document

---
 transformer_lens/utilities/tokenize_utils.py | 78 +++++++-------------
 1 file changed, 28 insertions(+), 50 deletions(-)

diff --git a/transformer_lens/utilities/tokenize_utils.py b/transformer_lens/utilities/tokenize_utils.py
index 1a0ab1319..c9b6dfefa 100644
--- a/transformer_lens/utilities/tokenize_utils.py
+++ b/transformer_lens/utilities/tokenize_utils.py
@@ -28,33 +28,30 @@ def tokenize_and_concatenate(
     add_bos_token: bool = True,
     num_proc: int = 10,
 ) -> Dataset:
-    """Helper function to tokenizer and concatenate a dataset of text. This converts the text to tokens, concatenates them (separated by EOS tokens) and then reshapes them into a 2D array of shape (____, sequence_length), dropping the last batch. Tokenizers are much faster if parallelised, so we chop the string into 20, feed it into the tokenizer, in parallel with padding, then remove padding at the end.
+    """Tokenize each document, join with token-level EOS between docs, and reshape into ``(batch, sequence_length)`` rows.
 
-    This tokenization is useful for training language models, as it allows us to efficiently train on a large corpus of text of varying lengths (without, eg, a lot of truncation or padding). Further, for models with absolute positional encodings, this avoids privileging early tokens (eg, news articles often begin with CNN, and models may learn to use early positional encodings to predict these)
+    Useful for training language models on a large text corpus without per-doc
+    truncation or padding. Absolute-position-embedding models also benefit by
+    avoiding early-token bias (e.g. news articles starting with "CNN").
 
     Args:
         dataset (Dataset): The dataset to tokenize, assumed to be a HuggingFace text dataset.
-        tokenizer (PreTrainedTokenizerBase): The tokenizer. Assumed to have a bos_token_id and an eos_token_id.
-        streaming (bool, optional): Whether the dataset is being streamed. If True, avoids using parallelism. Defaults to False.
+        tokenizer (PreTrainedTokenizerBase): The tokenizer. Must have ``bos_token_id`` and ``eos_token_id``.
+        streaming (bool, optional): If True, avoids parallelism. Defaults to False.
         max_length (int, optional): The length of the context window of the sequence. Defaults to 1024.
         column_name (str, optional): The name of the text column in the dataset. Defaults to 'text'.
-        add_bos_token (bool, optional): . Defaults to True.
+        add_bos_token (bool, optional): Whether to prepend ``bos_token_id`` to each output row. Defaults to True.
 
     Returns:
-        Dataset: Returns the tokenized dataset, as a dataset of tensors, with a single column called "tokens"
+        Dataset: Tokenized dataset of tensors with a single column ``"tokens"``.
     """
     dataset = keep_single_column(dataset, column_name)
     has_pad_token = tokenizer.pad_token is not None
     if not has_pad_token:
-        # Add padding token for tokenizer (removed before model input)
         tokenizer.add_special_tokens({"pad_token": "<PAD>"})
-    # Define the length to chop things up into - leaving space for a bos_token if required
-    if add_bos_token:
-        seq_len = max_length - 1
-    else:
-        seq_len = max_length
+    seq_len = max_length - 1 if add_bos_token else max_length
 
-    # Suppress the "sequence length longer than maximum" warning during chunked tokenization.
+    # Long docs legitimately exceed model_max_length; we slice into rows after.
     _deprecation_warnings_saved = None
     if hasattr(tokenizer, "deprecation_warnings"):
         _deprecation_warnings_saved = tokenizer.deprecation_warnings.copy()
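The hunk below drops string chunking entirely: each document is tokenized on its own and eos_token_id is spliced between documents at the token level, so no character-offset arithmetic can split a token. A standalone sketch of that joining strategy, assuming the gpt2 tokenizer (the docs are illustrative):

    # Standalone sketch of the per-document joining strategy used in the
    # hunk below: tokenize each doc separately (no auto special tokens),
    # then interleave eos_token_id. Assumes the gpt2 tokenizer.
    import numpy as np
    from transformers import AutoTokenizer

    tokenizer = AutoTokenizer.from_pretrained("gpt2")
    docs = ["first doc", "secondDocWithNoWhitespace" * 10, "third doc"]

    encoded = tokenizer(docs, add_special_tokens=False)["input_ids"]
    eos_id = tokenizer.eos_token_id

    pieces = []
    for i, ids in enumerate(encoded):
        pieces.append(np.asarray(ids, dtype=np.int64))
        if i < len(encoded) - 1:
            pieces.append(np.array([eos_id], dtype=np.int64))
    tokens = np.concatenate(pieces)

    # For GPT-2 this matches tokenizing the EOS-joined string directly.
    ref = tokenizer(tokenizer.eos_token.join(docs), return_tensors="np")
    assert np.array_equal(tokens, ref["input_ids"].flatten())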
@@ -63,56 +60,37 @@ def tokenize_and_concatenate(
         ] = False
 
     def tokenize_function(examples: Any) -> dict[str, np.ndarray]:
-        # datasets.map() may pass a LazyBatch, not a plain dict; accept dict-like batches
         text = examples[column_name]
-        # Concatenate it all into an enormous string, separated by eos_tokens
         assert tokenizer.eos_token is not None, "Tokenizer must have an EOS token."
-        eos = tokenizer.eos_token
-        full_text = eos.join(text)
-
-        # Handle the case when full_text is empty
-        if not full_text.strip():
+        if not text:
             return {"tokens": np.array([], dtype=np.int64)}
 
-        # Split at EOS boundaries: BPE merges never cross an EOS marker, so
-        # per-chunk tokenization concatenates to the same tokens as the joined
-        # string (#1133). No-EOS inputs fall back to one chunk: slower but correct.
-        num_chunks = 20
-        target_chunk_size = (len(full_text) - 1) // num_chunks + 1
-        chunks: list[str] = []
-        start = 0
-        while start < len(full_text):
-            target_end = start + target_chunk_size
-            if target_end >= len(full_text):
-                chunks.append(full_text[start:])
-                break
-            eos_pos = full_text.find(eos, target_end)
-            if eos_pos == -1:
-                chunks.append(full_text[start:])
-                break
-            end = eos_pos + len(eos)
-            chunks.append(full_text[start:end])
-            start = end
-        # Tokenize in parallel with NumPy (HF map rejects tensors)
-        tokens = tokenizer(chunks, return_tensors="np", padding=True)["input_ids"].flatten()
-        # Drop padding tokens
-        tokens = tokens[tokens != tokenizer.pad_token_id]
+        # Per-doc tokenization with explicit token-level EOS: string chunking
+        # could cut tokens mid-doc (#1133); add_special_tokens=False stops
+        # SentencePiece-style tokenizers from inserting auto-BOS/EOS per call.
+        encoded = tokenizer(text, add_special_tokens=False)["input_ids"]
+        eos_id = tokenizer.eos_token_id
+        pieces: list[np.ndarray] = []
+        for i, row in enumerate(encoded):
+            pieces.append(np.asarray(row, dtype=np.int64))
+            if i < len(encoded) - 1:
+                pieces.append(np.array([eos_id], dtype=np.int64))
+        if not pieces:
+            return {"tokens": np.array([], dtype=np.int64)}
+        tokens = np.concatenate(pieces)
         num_tokens = len(tokens)
-        # Handle cases where num_tokens is less than seq_len
         if num_tokens < seq_len:
             num_batches = 1
-            # Pad tokens if necessary
             tokens = tokens[:seq_len]
             if len(tokens) < seq_len:
-                padding_length = seq_len - len(tokens)
-                # Use EOS as pad to avoid out-of-vocabulary IDs
+                # Pad with EOS when there is no native pad token, to avoid OOV IDs.
                 padding_id = tokenizer.eos_token_id if not has_pad_token else tokenizer.pad_token_id
-                padding = np.full(padding_length, padding_id)
-                tokens = np.concatenate([tokens, padding], axis=0)
+                tokens = np.concatenate(
+                    [tokens, np.full(seq_len - len(tokens), padding_id)], axis=0
+                )
         else:
             num_batches = num_tokens // seq_len
-            # Drop the final tokens if not enough to make a full sequence
             tokens = tokens[: seq_len * num_batches]
         tokens = einops.rearrange(