55 changes: 55 additions & 0 deletions tests/unit/test_utils.py
@@ -483,6 +483,61 @@ def test_no_split_tokens_across_chunks(self):
f"This indicates a word was split across chunk boundaries."
)

def test_no_split_tokens_in_no_whitespace_text(self):
"""No-whitespace multi-doc input — the prior whitespace-lookahead fix
fell through to character cuts here. streaming=True keeps all docs in
one tokenize_function call so EOS markers actually exist to split on.
"""
from datasets import Dataset
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("gpt2")

docs = ["a" * 200 + "MilitaryVehicleEngine" * 100] * 10
dataset = Dataset.from_dict({"text": docs})

result = utils.tokenize_and_concatenate(
dataset,
tokenizer,
streaming=True,
max_length=128,
add_bos_token=False,
)

full_text = tokenizer.eos_token.join(docs)
clean_tokens = tokenizer(full_text, return_tensors="np")["input_ids"].flatten()
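# Bigram check: a token split at a chunk boundary shows up as an adjacent
# pair that never occurs when the clean text is tokenized in one pass.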
clean_pairs = set(zip(clean_tokens[:-1], clean_tokens[1:]))

output_tokens = np.concatenate([np.array(row["tokens"]) for row in result])
for i in range(len(output_tokens) - 1):
pair = (int(output_tokens[i]), int(output_tokens[i + 1]))
assert pair in clean_pairs, (
f"Token pair {pair} appears in tokenize_and_concatenate output "
f"but never occurs in natural tokenization. The chunker must "
f"have cut a token in half."
)

def test_single_document_batch_does_not_crash(self):
"""Single-doc batch has no EOS to split on — fallback to one chunk should be correct."""
from datasets import Dataset
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("gpt2")
dataset = Dataset.from_dict({"text": ["abcdefghij" * 200]})

result = utils.tokenize_and_concatenate(
dataset,
tokenizer,
streaming=True,
max_length=64,
add_bos_token=False,
)

clean = tokenizer("abcdefghij" * 200, return_tensors="np")["input_ids"].flatten()
output = np.concatenate([np.array(row["tokens"]) for row in result])
n = len(output)
assert (output == clean[:n]).all()


def test_tokenize_and_concatenate_no_spurious_sequence_length_warning():
"""Test that tokenize_and_concatenate does not emit the HF 'sequence length longer than maximum' warning."""
72 changes: 28 additions & 44 deletions transformer_lens/utilities/tokenize_utils.py
@@ -28,33 +28,30 @@ def tokenize_and_concatenate(
add_bos_token: bool = True,
num_proc: int = 10,
) -> Dataset:
"""Helper function to tokenizer and concatenate a dataset of text. This converts the text to tokens, concatenates them (separated by EOS tokens) and then reshapes them into a 2D array of shape (____, sequence_length), dropping the last batch. Tokenizers are much faster if parallelised, so we chop the string into 20, feed it into the tokenizer, in parallel with padding, then remove padding at the end.
"""Tokenize each document, join with token-level EOS between docs, and reshape into ``(batch, sequence_length)`` rows.

This tokenization is useful for training language models, as it allows us to efficiently train on a large corpus of text of varying lengths (without, eg, a lot of truncation or padding). Further, for models with absolute positional encodings, this avoids privileging early tokens (eg, news articles often begin with CNN, and models may learn to use early positional encodings to predict these)
Useful for training language models on a large text corpus without per-doc
truncation or padding. Absolute-position-embedding models also benefit by
avoiding early-token bias (e.g. news articles starting with "CNN").

Args:
dataset (Dataset): The dataset to tokenize, assumed to be a HuggingFace text dataset.
tokenizer (PreTrainedTokenizerBase): The tokenizer. Assumed to have a bos_token_id and an eos_token_id.
streaming (bool, optional): Whether the dataset is being streamed. If True, avoids using parallelism. Defaults to False.
tokenizer (PreTrainedTokenizerBase): The tokenizer. Must have ``bos_token_id`` and ``eos_token_id``.
streaming (bool, optional): If True, avoids parallelism. Defaults to False.
max_length (int, optional): The length of the context window of the sequence. Defaults to 1024.
column_name (str, optional): The name of the text column in the dataset. Defaults to 'text'.
add_bos_token (bool, optional): . Defaults to True.
add_bos_token (bool, optional): Whether to prepend ``bos_token_id`` to each output row. Defaults to True.

Returns:
Dataset: Returns the tokenized dataset, as a dataset of tensors, with a single column called "tokens"
Dataset: Tokenized dataset of tensors with a single column ``"tokens"``.
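
Example:
A minimal usage sketch (assumes ``datasets`` and ``transformers`` are
installed and the GPT-2 tokenizer is available):

>>> from datasets import Dataset
>>> from transformers import AutoTokenizer
>>> ds = Dataset.from_dict({"text": ["hello world", "second doc"]})
>>> tok = AutoTokenizer.from_pretrained("gpt2")
>>> out = tokenize_and_concatenate(ds, tok, max_length=16)
>>> # out is a Dataset with a single "tokens" column; each row holds 16 ids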
"""
dataset = keep_single_column(dataset, column_name)
has_pad_token = tokenizer.pad_token is not None
if not has_pad_token:
# Add padding token for tokenizer (removed before model input)
tokenizer.add_special_tokens({"pad_token": "<PAD>"})
# Define the length to chop things up into - leaving space for a bos_token if required
if add_bos_token:
seq_len = max_length - 1
else:
seq_len = max_length
seq_len = max_length - 1 if add_bos_token else max_length

# Suppress the "sequence length longer than maximum" warning during chunked tokenization.
# Long docs legitimately exceed model_max_length; we slice into rows after.
_deprecation_warnings_saved = None
if hasattr(tokenizer, "deprecation_warnings"):
_deprecation_warnings_saved = tokenizer.deprecation_warnings.copy()
@@ -63,50 +60,37 @@ def tokenize_and_concatenate(
] = False

def tokenize_function(examples: Any) -> dict[str, np.ndarray]:
# datasets.map() may pass a LazyBatch, not a plain dict; accept dict-like batches
text = examples[column_name]
# Concatenate it all into an enormous string, separated by eos_tokens
assert tokenizer.eos_token is not None, "Tokenizer must have an EOS token."
full_text = tokenizer.eos_token.join(text)

# Handle the case when full_text is empty
if not full_text.strip():
if not text:
return {"tokens": np.array([], dtype=np.int64)}

# Split at whitespace boundaries to avoid mid-word tokens (#1133)
num_chunks = 20
chunk_length = (len(full_text) - 1) // num_chunks + 1
chunks = []
start = 0
lookahead = chunk_length // 10
for i in range(num_chunks):
end = min(start + chunk_length, len(full_text))
# Advance to whitespace; bounded lookahead for pathological inputs
boundary = min(end + lookahead, len(full_text))
while end < boundary and not full_text[end].isspace():
end += 1
chunks.append(full_text[start:end])
start = end
# Tokenize in parallel with NumPy (HF map rejects tensors)
tokens = tokenizer(chunks, return_tensors="np", padding=True)["input_ids"].flatten()
# Drop padding tokens
tokens = tokens[tokens != tokenizer.pad_token_id]
# Per-doc tokenization with explicit token-level EOS — string chunking
# could cut tokens mid-doc (#1133); add_special_tokens=False prevents
# SentencePiece tokenizers from scattering auto-BOS/EOS per call.
encoded = tokenizer(text, add_special_tokens=False)["input_ids"]
eos_id = tokenizer.eos_token_id
pieces: list[np.ndarray] = []
for i, row in enumerate(encoded):
pieces.append(np.asarray(row, dtype=np.int64))
if i < len(encoded) - 1:
pieces.append(np.array([eos_id], dtype=np.int64))
if not pieces:
return {"tokens": np.array([], dtype=np.int64)}
tokens = np.concatenate(pieces)
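# e.g. docs tokenized as [[5, 6], [7]] with eos_id 50256 -> [5, 6, 50256, 7];
# EOS separates docs but is not appended after the last one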
num_tokens = len(tokens)

# Handle cases where num_tokens is less than seq_len
if num_tokens < seq_len:
num_batches = 1
# Pad tokens if necessary
tokens = tokens[:seq_len]
if len(tokens) < seq_len:
padding_length = seq_len - len(tokens)
# Use EOS as pad to avoid out-of-vocabulary IDs
# Pad with EOS when no native pad token to avoid OOV IDs.
padding_id = tokenizer.eos_token_id if not has_pad_token else tokenizer.pad_token_id
padding = np.full(padding_length, padding_id)
tokens = np.concatenate([tokens, padding], axis=0)
tokens = np.concatenate(
[tokens, np.full(seq_len - len(tokens), padding_id)], axis=0
)
else:
num_batches = num_tokens // seq_len
# Drop the final tokens if not enough to make a full sequence
tokens = tokens[: seq_len * num_batches]

tokens = einops.rearrange(