9 changes: 9 additions & 0 deletions .vscode/launch.json
@@ -85,6 +85,15 @@
"console": "integratedTerminal",
"justMyCode": true
},
{
"name": "build tiny stories",
"type": "python",
"request": "launch",
"program": "${workspaceFolder}/rib_scripts/rib_build/run_rib_build.py",
"args": "${workspaceFolder}/rib_scripts/rib_build/tinystories.yaml",
"console": "integratedTerminal",
"justMyCode": true
},
{
"name": "build edges_pythia-14m",
"type": "python",
62 changes: 39 additions & 23 deletions rib/data.py
@@ -1,4 +1,4 @@
"""Define custom datasets."""
"""Defines the dataset configs and datasets used in RIB."""

from typing import Literal, Optional

@@ -22,39 +22,33 @@ class DatasetConfig(BaseModel):
"""Base class for dataset configs."""

model_config = ConfigDict(extra="forbid", frozen=True)
return_set: Literal["train", "test", "all"] = Field(
return_set: Literal["train", "test", "validation", "all"] = Field(
"train",
description="The dataset split to return. If 'all', returns the combined train and test "
"datasets.",
)
return_set_frac: Optional[float] = Field(
None,
description="The fraction of the returned dataset (train/test/all) to use. Cannot be"
"used with return_set_n_samples.",
description="The fraction of the returned dataset (train/test/validation/all) to load. "
"This will be sampled from using n_samples if n_samples is not None.",
)
return_set_n_samples: Optional[int] = Field(
n_samples: Optional[int] = Field(
None,
description="The number of raw samples to return from the dataset (train/test/all). "
"Cannot be used with return_set_frac.",
description="The number of n_ctx length tokenized samples to load from the dataset. This "
"will be sampled from either return_set_frac or n_documents if they are not None, or the "
"entire dataset if they are None. If n_samples is None, will load all samples in "
"return_set_frac (or n_documents if provided in a child class).",
)

@model_validator(mode="after")
def verify_return_set_frac_and_n_samples(self) -> "DatasetConfig":
"""Verify not both return_set_frac and return_set_n_samples are set and check values."""
def verify_return_set_options(self) -> "DatasetConfig":
"""Can't have both return_set_frac and n_samples be non-None for dataset with n_documents."""
frac = self.return_set_frac
if not hasattr(self, "n_documents") and (frac is not None and self.n_samples is not None):
raise ValueError(
"Cannot have both return_set_frac and n_samples be non-None for this dataset."
)

if frac is not None:
if self.return_set_n_samples is not None:
raise ValueError(
"Cannot have both return_set_frac and return_set_n_samples be non-None."
)
if isinstance(self, HFDatasetConfig) and (frac < 0.01 or frac > 1):
raise ValueError(
f"return_set_frac must be > 0.01 and < 1 since huggingface dataset `split` "
f"method does not correctly convert other values to perecentages."
)
if frac <= 0 or frac > 1:
raise ValueError(f"return_set_frac must be > 0 and <= 1.")
return self


@@ -70,17 +64,39 @@ class HFDatasetConfig(DatasetConfig):
description="The HuggingFace name for the tokenizer. Please check whether the tokenizer is "
"compatible with the model you are using.",
)
return_set: Literal["train", "test"] = Field(
return_set: Literal["train", "test", "validation"] = Field(
..., description="The dataset split to return from HuggingFace."
)
return_set_portion: Literal["first", "last"] = Field(
"first", description="Whether to load the first or last portion of the return_set."
)
n_documents: Optional[int] = Field(
None,
description="The number of documents to load from the dataset before (optional) sampling "
"with n_samples. If None, will load all documents in return_set_frac (or all documents if "
"return_set_frac is None).",
)
n_ctx: Optional[int] = Field(
None,
description="Dataset will be packed to sequences of this length. Should be <1024 for gpt2."
"<2048 for most other models.",
)
seed: Optional[int] = Field(0, description="The random seed value for reproducibility.")

@model_validator(mode="after")
def verify_return_set_options(self) -> "HFDatasetConfig":
frac = self.return_set_frac
# Can't have both return_set_frac and n_documents be non-None
if frac is not None and self.n_documents is not None:
raise ValueError(
"Cannot have both return_set_frac and n_documents be non-None for HF datasets."
)
if frac is not None and (frac < 0.01 or frac > 1):
raise ValueError(
f"return_set_frac must be > 0.01 and < 1 since huggingface dataset `split` "
f"method does not correctly convert other values to perecentages."
)
return self


class ModularArithmeticDatasetConfig(DatasetConfig):
@@ -136,7 +152,7 @@ class VisionDatasetConfig(DatasetConfig):
seed: Optional[int] = 0
return_set: Literal["train", "test"] = "train"
return_set_frac: Optional[float] = None # Needed for some reason to avoid mypy errors
return_set_n_samples: Optional[int] = None # Needed for some reason to avoid mypy errors
n_samples: Optional[int] = None # Needed for some reason to avoid mypy errors


class BlockVectorDatasetConfig(DatasetConfig):
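To illustrate how the reworked options fit together, here is a minimal sketch of constructing the new-style config (hypothetical values; the `name` field and the exact set of required fields are assumed from the YAML configs further down, and the import path is assumed from the file being edited):

from rib.data import HFDatasetConfig

# Load the first 5000 documents of the train split, pack them into sequences of
# 256 tokens, then randomly sample 3000 of those sequences.
cfg = HFDatasetConfig(
    name="roneneldan/TinyStories",  # hypothetical dataset name
    tokenizer_name="EleutherAI/gpt-neo-125M",
    return_set="train",
    return_set_frac=None,
    n_documents=5000,
    n_samples=3000,
    return_set_portion="first",
    n_ctx=256,
    seed=0,
)

# The new validator forbids combining return_set_frac with n_documents, so this would raise:
# HFDatasetConfig(..., return_set_frac=0.1, n_documents=5000)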
95 changes: 69 additions & 26 deletions rib/loader.py
@@ -208,7 +208,7 @@ def create_modular_arithmetic_dataset(
dataset_subset = get_data_subset(
dataset,
frac=dataset_config.return_set_frac,
n_samples=dataset_config.return_set_n_samples,
n_samples=dataset_config.n_samples,
seed=seed,
)
return dataset_subset
@@ -218,46 +218,73 @@ def tokenize_dataset(
dataset: Dataset,
tokenizer: AutoTokenizer,
n_ctx: int,
n_samples: Optional[int] = None,
seed: Optional[int] = None,
) -> TensorDataset:
"""Tokenize a dataset using the provided tokenizer.

Tokenizes the dataset and splits it into chunks that fit the context length. The labels are
the input_ids shifted by one position.

The final chunk is not included in the dataset as it does not have a label for its final token.
Excluding it also means that we don't have to worry about padding.

Args:
dataset (Dataset): The dataset to tokenize. Created from `hf_load_dataset`.
tokenizer (AutoTokenizer): The tokenizer to use.
n_ctx (int): The context length to use.
n_samples (Optional[int]): The number of samples to use. If None, uses all samples.
seed (Optional[int]): The seed to use for sampling.

Returns:
TensorDataset: The tokenized dataset.
"""
# Tokenize all samples and merge them together
assert tokenizer.eos_token_id is not None, "Tokenizer must have an eos token id"
# Tokenize all samples and merge them into one long list of tokens
all_tokens = []
for example in dataset: # type: ignore
tokens = tokenizer(example["text"])["input_ids"]
all_tokens.extend(tokens)

# Add the eos token to the end of each sample as was done in the original training
# https://github.com/EleutherAI/pythia/issues/123#issuecomment-1791136253
all_tokens.extend(tokens + [tokenizer.eos_token_id])

# There shouldn't be any padding tokens, so ensure that there are len(dataset) eos tokens
len_dataset = len(dataset) # type: ignore
assert all_tokens.count(tokenizer.eos_token_id) == len_dataset, (
f"Number of eos tokens ({all_tokens.count(tokenizer.eos_token_id)}) does not match "
f"number of samples ({len_dataset})."
)
# Split the merged tokens into chunks that fit the context length
chunks = [all_tokens[i : i + n_ctx] for i in range(0, len(all_tokens), n_ctx)]

# Convert chunks to input_ids and labels
# we ignore the final chunk, as it contains a token we don't have a label for
# and is also probably too short and we don't want to pad.
input_ids_list = []
labels_list = []
for i, chunk in enumerate(chunks[:-1]):
input_id = chunk
label = input_id[1:] + [chunks[i + 1][0]] # with first token from next chunk
raw_chunks = [all_tokens[i : i + n_ctx] for i in range(0, len(all_tokens), n_ctx)]

# Note that we ignore the final raw_chunk, as we get the label for the final token in a chunk
# from the subsequent chunk.
n_raw_chunks = len(raw_chunks) - 1
if n_samples is not None:
# Randomly select n_samples chunks
generator = torch.Generator() if seed is None else torch.Generator().manual_seed(seed)
raw_chunk_idxs = torch.randperm(n_raw_chunks, generator=generator)
assert len(raw_chunk_idxs) >= n_samples, (
f"Cannot sample {n_samples} chunks from dataset with {len(raw_chunks)} chunks of "
f"length {n_ctx}."
)
chunk_idxs = raw_chunk_idxs[:n_samples].tolist()
else:
chunk_idxs = list(range(n_raw_chunks))

input_ids_list.append(input_id)
labels_list.append(label)
chunks = [raw_chunks[i] for i in chunk_idxs]

input_ids = torch.tensor(input_ids_list, dtype=torch.long)
labels = torch.tensor(labels_list, dtype=torch.long)
all_labels: list[list[int]] = []
for i, chunk in enumerate(chunks):
# Get the label for the last token using the next chunk in raw_chunks
final_token_label = raw_chunks[chunk_idxs[i] + 1][0]
labels = chunk[1:] + [final_token_label]
all_labels.append(labels)

return TensorDataset(input_ids, labels)
return TensorDataset(
torch.tensor(chunks, dtype=torch.long), torch.tensor(all_labels, dtype=torch.long)
)


def create_hf_dataset(
@@ -297,25 +324,41 @@ def create_hf_dataset(
f"({model_n_ctx})."
)

assert dataset_config.return_set in ["train", "test"], "Only train and test sets are supported"
assert dataset_config.return_set in [
"train",
"test",
"validation",
], f"Invalid return_set: {dataset_config.return_set}. Must be one of train, test, validation."

if dataset_config.return_set_frac:
# Load the first/last return_set_frac of return_set; n_samples (if set) is drawn from it later
percent = int(dataset_config.return_set_frac * 100)
if dataset_config.return_set_portion == "first":
data_split = f"{dataset_config.return_set}[:{percent}%]"
elif dataset_config.return_set_portion == "last":
data_split = f"{dataset_config.return_set}[-{percent}%:]"
elif dataset_config.return_set_n_samples:
elif dataset_config.n_documents:
# Only load the first/last n documents from return_set and sample n_samples.
if dataset_config.return_set_portion == "first":
data_split = f"{dataset_config.return_set}[:{dataset_config.return_set_n_samples}]"
data_split = f"{dataset_config.return_set}[:{dataset_config.n_documents}]"
elif dataset_config.return_set_portion == "last":
data_split = f"{dataset_config.return_set}[-{dataset_config.return_set_n_samples}:]"
data_split = f"{dataset_config.return_set}[-{dataset_config.n_documents}:]"
else:
# Sample n_samples from all documents in return_set
data_split = dataset_config.return_set

raw_dataset = hf_load_dataset(dataset_config.name, split=data_split)

tokenizer = AutoTokenizer.from_pretrained(dataset_config.tokenizer_name)
tokenizer.pad_token = tokenizer.eos_token
return tokenize_dataset(dataset=raw_dataset, tokenizer=tokenizer, n_ctx=n_ctx)
tokenized_dataset = tokenize_dataset(
dataset=raw_dataset,
tokenizer=tokenizer,
n_ctx=n_ctx,
n_samples=dataset_config.n_samples,
seed=dataset_config.seed,
)
return tokenized_dataset


def create_vision_dataset(dataset_config: VisionDatasetConfig) -> Dataset:
@@ -331,7 +374,7 @@ def create_vision_dataset(dataset_config: VisionDatasetConfig) -> Dataset:
dataset = get_data_subset(
raw_dataset,
frac=dataset_config.return_set_frac,
n_samples=dataset_config.return_set_n_samples,
n_samples=dataset_config.n_samples,
seed=dataset_config.seed,
)
return dataset
@@ -343,7 +386,7 @@ def create_block_vector_dataset(dataset_config: BlockVectorDatasetConfig) -> Dataset:
dataset = get_data_subset(
raw_dataset,
frac=dataset_config.return_set_frac,
n_samples=dataset_config.return_set_n_samples,
n_samples=dataset_config.n_samples,
seed=dataset_config.seed,
)
return dataset
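For reference, a self-contained sketch of the chunk-and-sample logic that tokenize_dataset now performs (toy token IDs instead of a real tokenizer; this mirrors the diff above rather than calling the function):

import torch

n_ctx, n_samples, seed = 4, 2, 0
all_tokens = list(range(13))  # stand-in for all documents tokenized and merged with EOS tokens

# Split the merged tokens into n_ctx-length chunks; the final (short) chunk is never used
# as an input, it only supplies the label for the last token of the chunk before it.
raw_chunks = [all_tokens[i : i + n_ctx] for i in range(0, len(all_tokens), n_ctx)]
n_raw_chunks = len(raw_chunks) - 1

# Randomly select n_samples of the usable chunks, reproducibly via the seed.
generator = torch.Generator().manual_seed(seed)
chunk_idxs = torch.randperm(n_raw_chunks, generator=generator)[:n_samples].tolist()

inputs, labels = [], []
for idx in chunk_idxs:
    chunk = raw_chunks[idx]
    inputs.append(chunk)
    # Labels are the inputs shifted by one, borrowing the first token of the next raw chunk.
    labels.append(chunk[1:] + [raw_chunks[idx + 1][0]])

input_ids = torch.tensor(inputs, dtype=torch.long)  # shape (n_samples, n_ctx)
label_ids = torch.tensor(labels, dtype=torch.long)  # shape (n_samples, n_ctx)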
8 changes: 4 additions & 4 deletions rib/rib_builder.py
@@ -299,11 +299,11 @@ def _verify_compatible_configs(config: RibBuildConfig, loaded_config: RibBuildConfig
assert (
config.dataset.return_set_frac <= loaded_config.dataset.return_set_frac
), "Cannot use a larger return_set_frac for edges than to calculate the Cs"
elif config.dataset.return_set_n_samples is not None:
assert loaded_config.dataset.return_set_n_samples is not None
elif config.dataset.n_samples is not None:
assert loaded_config.dataset.n_samples is not None
assert (
config.dataset.return_set_n_samples <= loaded_config.dataset.return_set_n_samples
), "Cannot use a larger return_set_n_samples for edges than to calculate the Cs"
config.dataset.n_samples <= loaded_config.dataset.n_samples
), "Cannot use a larger n_samples for edges than to calculate the Cs"


def load_interaction_rotations(
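The rule enforced here is simply that an edges run may not use more data than the run that produced the Cs it loads. A minimal sketch of the constraint, with hypothetical numbers:

cs_n_samples = 3000      # n_samples used when the Cs were calculated
edges_n_samples = 1000   # n_samples for the edges build
assert edges_n_samples <= cs_n_samples, (
    "Cannot use a larger n_samples for edges than to calculate the Cs"
)
# The same subset rule applies to return_set_frac when that is used instead.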
2 changes: 1 addition & 1 deletion rib_scripts/ablations/orthog_pythia-14m.yaml
@@ -11,7 +11,7 @@ dataset:
tokenizer_name: EleutherAI/pythia-14m
return_set: train # pile-10k only has train, so we take the first 90% for building and last 10% for ablations
return_set_frac: 0.01
return_set_n_samples: null
n_samples: null
return_set_portion: last
ablation_node_layers:
- mlp_out.0
2 changes: 1 addition & 1 deletion rib_scripts/ablations/rib_pythia-14m.yaml
@@ -12,7 +12,7 @@ dataset:
tokenizer_name: EleutherAI/pythia-14m
return_set: train # pile-10k only has train, so we take the first 90% for building and last 10% for ablations
return_set_frac: 0.01
return_set_n_samples: null
n_samples: null
return_set_portion: first
ablation_node_layers:
- mlp_out.0
2 changes: 1 addition & 1 deletion rib_scripts/rib_build/Cs_pythia-14m.yaml
@@ -9,7 +9,7 @@ dataset:
tokenizer_name: EleutherAI/pythia-14m
return_set: train # pile-10k only has train, so we take the first 90% for building and last 10% for ablations
return_set_frac: 0.9
return_set_n_samples: null
n_samples: null
return_set_portion: first
node_layers:
- mlp_out.0
6 changes: 4 additions & 2 deletions rib_scripts/rib_build/edges_pythia-14m.yaml
@@ -8,9 +8,11 @@ dataset:
name: NeelNanda/pile-10k
tokenizer_name: EleutherAI/pythia-14m
return_set: train # pile-10k only has train, so we take the first 90% for building and last 10% for ablations
return_set_frac: null
return_set_n_samples: 20
return_set_frac: 0.1
n_documents: null
n_samples: null
return_set_portion: first
n_ctx: 50
node_layers:
- mlp_out.0
- ln2.3
3 changes: 2 additions & 1 deletion rib_scripts/rib_build/tinystories.yaml
@@ -8,7 +8,8 @@ dataset:
tokenizer_name: EleutherAI/gpt-neo-125M
return_set: train
return_set_frac: null
return_set_n_samples: 5000 # avg ~235 toks / story
n_documents: 5000 # avg ~235 toks / story
n_samples: 3000
return_set_portion: first
n_ctx: 256 # needs to be <= 511 for the model to behave reasonably
node_layers:
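A rough back-of-the-envelope for these TinyStories settings, assuming the ~235 tokens/story average noted in the comment above:

n_documents, avg_tokens_per_story, n_ctx = 5000, 235, 256
total_tokens = n_documents * (avg_tokens_per_story + 1)  # +1 for the per-document EOS token, ~1.18M tokens
available_chunks = total_tokens // n_ctx                  # ~4,600 packed sequences of length 256
# n_samples: 3000 then randomly selects 3000 of these ~4,600 sequences.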