diff --git a/.vscode/launch.json b/.vscode/launch.json index a6527357..5a81778d 100644 --- a/.vscode/launch.json +++ b/.vscode/launch.json @@ -85,6 +85,15 @@ "console": "integratedTerminal", "justMyCode": true }, + { + "name": "build tiny stories", + "type": "python", + "request": "launch", + "program": "${workspaceFolder}/rib_scripts/rib_build/run_rib_build.py", + "args": "${workspaceFolder}/rib_scripts/rib_build/tinystories.yaml", + "console": "integratedTerminal", + "justMyCode": true + }, { "name": "build edges_pythia-14m", "type": "python", diff --git a/rib/data.py b/rib/data.py index 4085ac20..d1aab7c8 100644 --- a/rib/data.py +++ b/rib/data.py @@ -1,4 +1,4 @@ -"""Define custom datasets.""" +"""Defines the dataset configs and datasets used in RIB.""" from typing import Literal, Optional @@ -22,39 +22,33 @@ class DatasetConfig(BaseModel): """Base class for dataset configs.""" model_config = ConfigDict(extra="forbid", frozen=True) - return_set: Literal["train", "test", "all"] = Field( + return_set: Literal["train", "test", "validation", "all"] = Field( "train", description="The dataset split to return. If 'all', returns the combined train and test " "datasets.", ) return_set_frac: Optional[float] = Field( None, - description="The fraction of the returned dataset (train/test/all) to use. Cannot be" - "used with return_set_n_samples.", + description="The fraction of the returned dataset (train/test/validation/all) to load. " + "This will be sampled from using n_samples if n_samples is not None.", ) - return_set_n_samples: Optional[int] = Field( + n_samples: Optional[int] = Field( None, - description="The number of raw samples to return from the dataset (train/test/all). " - "Cannot be used with return_set_frac.", + description="The number of n_ctx length tokenized samples to load from the dataset. This " + "will be sampled from either return_set_frac or n_documents if they are not None, or the " + "entire dataset if they are None. If n_samples is None, will load all samples in " + "return_set_frac (or n_documents if provided in a child class).", ) @model_validator(mode="after") - def verify_return_set_frac_and_n_samples(self) -> "DatasetConfig": - """Verify not both return_set_frac and return_set_n_samples are set and check values.""" + def verify_return_set_options(self) -> "DatasetConfig": + """Can't have both return_set_frac and n_samples be non-None for dataset with n_documents.""" frac = self.return_set_frac + if not hasattr(self, "n_documents") and (frac is not None and self.n_samples is not None): + raise ValueError( + "Cannot have both return_set_frac and n_samples be non-None for this dataset." + ) - if frac is not None: - if self.return_set_n_samples is not None: - raise ValueError( - "Cannot have both return_set_frac and return_set_n_samples be non-None." - ) - if isinstance(self, HFDatasetConfig) and (frac < 0.01 or frac > 1): - raise ValueError( - f"return_set_frac must be > 0.01 and < 1 since huggingface dataset `split` " - f"method does not correctly convert other values to perecentages." - ) - if frac <= 0 or frac > 1: - raise ValueError(f"return_set_frac must be > 0 and <= 1.") return self @@ -70,17 +64,39 @@ class HFDatasetConfig(DatasetConfig): description="The HuggingFace name for the tokenizer. Please check whether the tokenizer is " "compatible with the model you are using.", ) - return_set: Literal["train", "test"] = Field( + return_set: Literal["train", "test", "validation"] = Field( ..., description="The dataset split to return from HuggingFace." ) return_set_portion: Literal["first", "last"] = Field( "first", description="Whether to load the first or last portion of the return_set." ) + n_documents: Optional[int] = Field( + None, + description="The number of documents to load from the dataset before (optional) sampling " + "with n_samples. If None, will load all documents in return_set_frac (or all documents if " + "return_set_frac is None).", + ) n_ctx: Optional[int] = Field( None, description="Dataset will be packed to sequences of this length. Should be <1024 for gpt2." "<2048 for most other models.", ) + seed: Optional[int] = Field(0, description="The random seed value for reproducibility.") + + @model_validator(mode="after") + def verify_return_set_options(self) -> "HFDatasetConfig": + frac = self.return_set_frac + # Can't have both return_set_frac and n_documents be non-None + if frac is not None and self.n_documents is not None: + raise ValueError( + "Cannot have both return_set_frac and n_documents be non-None for HF datasets." + ) + if frac is not None and (frac < 0.01 or frac > 1): + raise ValueError( + f"return_set_frac must be > 0.01 and < 1 since huggingface dataset `split` " + f"method does not correctly convert other values to perecentages." + ) + return self class ModularArithmeticDatasetConfig(DatasetConfig): @@ -136,7 +152,7 @@ class VisionDatasetConfig(DatasetConfig): seed: Optional[int] = 0 return_set: Literal["train", "test"] = "train" return_set_frac: Optional[float] = None # Needed for some reason to avoid mypy errors - return_set_n_samples: Optional[int] = None # Needed for some reason to avoid mypy errors + n_samples: Optional[int] = None # Needed for some reason to avoid mypy errors class BlockVectorDatasetConfig(DatasetConfig): diff --git a/rib/loader.py b/rib/loader.py index 4701efe1..9161b217 100644 --- a/rib/loader.py +++ b/rib/loader.py @@ -208,7 +208,7 @@ def create_modular_arithmetic_dataset( dataset_subset = get_data_subset( dataset, frac=dataset_config.return_set_frac, - n_samples=dataset_config.return_set_n_samples, + n_samples=dataset_config.n_samples, seed=seed, ) return dataset_subset @@ -218,46 +218,73 @@ def tokenize_dataset( dataset: Dataset, tokenizer: AutoTokenizer, n_ctx: int, + n_samples: Optional[int] = None, + seed: Optional[int] = None, ) -> TensorDataset: """Tokenize a dataset using the provided tokenizer. Tokenizes the dataset and splits it into chunks that fit the context length. The labels are the input_ids shifted by one position. + The final chunk is not included in the dataset as it does not have a label for its final token. + Excluding it also means that we don't have to worry about padding. + Args: raw_dataset (Union[DatasetDict, Dataset, IterableDatasetDict, IterableDataset]): The raw dataset to tokenize. Created from `hf_load_dataset`. tokenizer (AutoTokenizer): The tokenizer to use. n_ctx (int): The context length to use. + n_samples (Optional[int]): The number of samples to use. If None, uses all samples. + seed (Optional[int]): The seed to use for sampling. Returns: TensorDataset: The tokenized dataset. """ - # Tokenize all samples and merge them together + assert tokenizer.eos_token_id is not None, "Tokenizer must have an eos token id" + # Tokenize all samples and merge them into one long list of tokens all_tokens = [] for example in dataset: # type: ignore tokens = tokenizer(example["text"])["input_ids"] - all_tokens.extend(tokens) - + # Add the eos token to the end of each sample as was done in the original training + # https://github.com/EleutherAI/pythia/issues/123#issuecomment-1791136253 + all_tokens.extend(tokens + [tokenizer.eos_token_id]) + + # There shouldn't be any padding tokens, so ensure that there are len(dataset) eos tokens + len_dataset = len(dataset) # type: ignore + assert all_tokens.count(tokenizer.eos_token_id) == len_dataset, ( + f"Number of eos tokens ({all_tokens.count(tokenizer.eos_token_id)}) does not match " + f"number of samples ({len_dataset})." + ) # Split the merged tokens into chunks that fit the context length - chunks = [all_tokens[i : i + n_ctx] for i in range(0, len(all_tokens), n_ctx)] - - # Convert chunks to input_ids and labels - # we ignore the final chunk, as it contains a token we don't have a label for - # and is also probably too short and we don't want to pad. - input_ids_list = [] - labels_list = [] - for i, chunk in enumerate(chunks[:-1]): - input_id = chunk - label = input_id[1:] + [chunks[i + 1][0]] # with first token from next chunk + raw_chunks = [all_tokens[i : i + n_ctx] for i in range(0, len(all_tokens), n_ctx)] + + # Note that we ignore the final raw_chunk, as we get the label for the final token in a chunk + # from the subsequent chunk. + n_raw_chunks = len(raw_chunks) - 1 + if n_samples is not None: + # Randomly select n_samples chunks + generator = torch.Generator() if seed is None else torch.Generator().manual_seed(seed) + raw_chunk_idxs = torch.randperm(n_raw_chunks, generator=generator) + assert len(raw_chunk_idxs) >= n_samples, ( + f"Cannot sample {n_samples} chunks from dataset with {len(raw_chunks)} chunks of " + f"length {n_ctx}." + ) + chunk_idxs = raw_chunk_idxs[:n_samples].tolist() + else: + chunk_idxs = list(range(n_raw_chunks)) - input_ids_list.append(input_id) - labels_list.append(label) + chunks = [raw_chunks[i] for i in chunk_idxs] - input_ids = torch.tensor(input_ids_list, dtype=torch.long) - labels = torch.tensor(labels_list, dtype=torch.long) + all_labels: list[list[int]] = [] + for i, chunk in enumerate(chunks): + # Get the label for the last token using the next chunk in raw_chunks + final_token_label = raw_chunks[chunk_idxs[i] + 1][0] + labels = chunk[1:] + [final_token_label] + all_labels.append(labels) - return TensorDataset(input_ids, labels) + return TensorDataset( + torch.tensor(chunks, dtype=torch.long), torch.tensor(all_labels, dtype=torch.long) + ) def create_hf_dataset( @@ -297,25 +324,41 @@ def create_hf_dataset( f"({model_n_ctx})." ) - assert dataset_config.return_set in ["train", "test"], "Only train and test sets are supported" + assert dataset_config.return_set in [ + "train", + "test", + "validation", + ], f"Invalid return_set: {dataset_config.return_set}. Must be one of train, test, validation." if dataset_config.return_set_frac: + # Sample from all documents in return_set_frac% of return_set_portion percent = int(dataset_config.return_set_frac * 100) if dataset_config.return_set_portion == "first": data_split = f"{dataset_config.return_set}[:{percent}%]" elif dataset_config.return_set_portion == "last": data_split = f"{dataset_config.return_set}[-{percent}%:]" - elif dataset_config.return_set_n_samples: + elif dataset_config.n_documents: + # Only load the first/last n documents from return_set and sample n_samples. if dataset_config.return_set_portion == "first": - data_split = f"{dataset_config.return_set}[:{dataset_config.return_set_n_samples}]" + data_split = f"{dataset_config.return_set}[:{dataset_config.n_documents}]" elif dataset_config.return_set_portion == "last": - data_split = f"{dataset_config.return_set}[-{dataset_config.return_set_n_samples}:]" + data_split = f"{dataset_config.return_set}[-{dataset_config.n_documents}:]" + else: + # Sample n_samples from all documents in return_set + data_split = dataset_config.return_set raw_dataset = hf_load_dataset(dataset_config.name, split=data_split) tokenizer = AutoTokenizer.from_pretrained(dataset_config.tokenizer_name) tokenizer.pad_token = tokenizer.eos_token - return tokenize_dataset(dataset=raw_dataset, tokenizer=tokenizer, n_ctx=n_ctx) + tokenized_dataset = tokenize_dataset( + dataset=raw_dataset, + tokenizer=tokenizer, + n_ctx=n_ctx, + n_samples=dataset_config.n_samples, + seed=dataset_config.seed, + ) + return tokenized_dataset def create_vision_dataset(dataset_config: VisionDatasetConfig) -> Dataset: @@ -331,7 +374,7 @@ def create_vision_dataset(dataset_config: VisionDatasetConfig) -> Dataset: dataset = get_data_subset( raw_dataset, frac=dataset_config.return_set_frac, - n_samples=dataset_config.return_set_n_samples, + n_samples=dataset_config.n_samples, seed=dataset_config.seed, ) return dataset @@ -343,7 +386,7 @@ def create_block_vector_dataset(dataset_config: BlockVectorDatasetConfig) -> Dat dataset = get_data_subset( raw_dataset, frac=dataset_config.return_set_frac, - n_samples=dataset_config.return_set_n_samples, + n_samples=dataset_config.n_samples, seed=dataset_config.seed, ) return dataset diff --git a/rib/rib_builder.py b/rib/rib_builder.py index 780dd3d7..f38b425a 100644 --- a/rib/rib_builder.py +++ b/rib/rib_builder.py @@ -299,11 +299,11 @@ def _verify_compatible_configs(config: RibBuildConfig, loaded_config: RibBuildCo assert ( config.dataset.return_set_frac <= loaded_config.dataset.return_set_frac ), "Cannot use a larger return_set_frac for edges than to calculate the Cs" - elif config.dataset.return_set_n_samples is not None: - assert loaded_config.dataset.return_set_n_samples is not None + elif config.dataset.n_samples is not None: + assert loaded_config.dataset.n_samples is not None assert ( - config.dataset.return_set_n_samples <= loaded_config.dataset.return_set_n_samples - ), "Cannot use a larger return_set_n_samples for edges than to calculate the Cs" + config.dataset.n_samples <= loaded_config.dataset.n_samples + ), "Cannot use a larger n_samples for edges than to calculate the Cs" def load_interaction_rotations( diff --git a/rib_scripts/ablations/orthog_pythia-14m.yaml b/rib_scripts/ablations/orthog_pythia-14m.yaml index 07a6acdb..12cc55eb 100644 --- a/rib_scripts/ablations/orthog_pythia-14m.yaml +++ b/rib_scripts/ablations/orthog_pythia-14m.yaml @@ -11,7 +11,7 @@ dataset: tokenizer_name: EleutherAI/pythia-14m return_set: train # pile-10k only has train, so we take the first 90% for building and last 10% for ablations return_set_frac: 0.01 - return_set_n_samples: null + n_samples: null return_set_portion: last ablation_node_layers: - mlp_out.0 diff --git a/rib_scripts/ablations/rib_pythia-14m.yaml b/rib_scripts/ablations/rib_pythia-14m.yaml index a08e3dca..1b5a30d1 100644 --- a/rib_scripts/ablations/rib_pythia-14m.yaml +++ b/rib_scripts/ablations/rib_pythia-14m.yaml @@ -12,7 +12,7 @@ dataset: tokenizer_name: EleutherAI/pythia-14m return_set: train # pile-10k only has train, so we take the first 90% for building and last 10% for ablations return_set_frac: 0.01 - return_set_n_samples: null + n_samples: null return_set_portion: first ablation_node_layers: - mlp_out.0 diff --git a/rib_scripts/rib_build/Cs_pythia-14m.yaml b/rib_scripts/rib_build/Cs_pythia-14m.yaml index 270eb881..ea55ae33 100644 --- a/rib_scripts/rib_build/Cs_pythia-14m.yaml +++ b/rib_scripts/rib_build/Cs_pythia-14m.yaml @@ -9,7 +9,7 @@ dataset: tokenizer_name: EleutherAI/pythia-14m return_set: train # pile-10k only has train, so we take the first 90% for building and last 10% for ablations return_set_frac: 0.9 - return_set_n_samples: null + n_samples: null return_set_portion: first node_layers: - mlp_out.0 diff --git a/rib_scripts/rib_build/edges_pythia-14m.yaml b/rib_scripts/rib_build/edges_pythia-14m.yaml index baf30d7e..363040a3 100644 --- a/rib_scripts/rib_build/edges_pythia-14m.yaml +++ b/rib_scripts/rib_build/edges_pythia-14m.yaml @@ -8,9 +8,11 @@ dataset: name: NeelNanda/pile-10k tokenizer_name: EleutherAI/pythia-14m return_set: train # pile-10k only has train, so we take the first 90% for building and last 10% for ablations - return_set_frac: null - return_set_n_samples: 20 + return_set_frac: 0.1 + n_documents: null + n_samples: null return_set_portion: first + n_ctx: 50 node_layers: - mlp_out.0 - ln2.3 diff --git a/rib_scripts/rib_build/tinystories.yaml b/rib_scripts/rib_build/tinystories.yaml index 739135df..32cdd605 100644 --- a/rib_scripts/rib_build/tinystories.yaml +++ b/rib_scripts/rib_build/tinystories.yaml @@ -8,7 +8,8 @@ dataset: tokenizer_name: EleutherAI/gpt-neo-125M return_set: train return_set_frac: null - return_set_n_samples: 5000 # avg ~235 toks / story + n_documents: 5000 # avg ~235 toks / story + n_samples: 3000 return_set_portion: first n_ctx: 256 # needs to be <= 511 for the model to behave reasonably node_layers: diff --git a/tests/test_ablations.py b/tests/test_ablations.py index 8637dac1..f4ebd5ea 100644 --- a/tests/test_ablations.py +++ b/tests/test_ablations.py @@ -112,7 +112,7 @@ def test_run_mnist_ablations(ablation_type, tmp_path): "node_layers": ["layers.1", "layers.2", "output"], "batch_size": 100, "dtype": "float32", - "dataset": {"return_set_n_samples": 100, "return_set_frac": None}, + "dataset": {"n_samples": 100, "return_set_frac": None}, } ) results = rib_build(build_config) @@ -135,7 +135,7 @@ def test_run_mnist_ablations(ablation_type, tmp_path): dataset: dataset_type: torchvision name: MNIST - return_set_n_samples: 100 + n_samples: 100 batch_size: 64 # 2 batches seed: 0 out_dir: null @@ -154,7 +154,7 @@ def test_run_modular_arithmetic_rib_ablations(ablation_type, tmp_path): { "node_layers": ["ln1.0", "ln2.0", "mlp_out.0", "unembed", "output"], "batch_size": 100, - "dataset": {"return_set_n_samples": 100}, + "dataset": {"n_samples": 100}, } ) results = rib_build(build_config) @@ -174,7 +174,7 @@ def test_run_modular_arithmetic_rib_ablations(ablation_type, tmp_path): dataset: dataset_type: modular_arithmetic return_set: train - return_set_n_samples: 100 + n_samples: 100 ablation_node_layers: - ln1.0 - ln2.0 @@ -199,7 +199,7 @@ def test_run_mnist_ablations_bisect(ablation_type, tmp_path): "node_layers": ["layers.1", "layers.2", "output"], "batch_size": 100, "dtype": "float32", - "dataset": {"return_set_n_samples": 100, "return_set_frac": None}, + "dataset": {"n_samples": 100, "return_set_frac": None}, } ) results = rib_build(build_config) @@ -220,7 +220,7 @@ def test_run_mnist_ablations_bisect(ablation_type, tmp_path): dataset: dataset_type: torchvision name: MNIST - return_set_n_samples: 100 + n_samples: 100 batch_size: 64 # two batches seed: 0 out_dir: null @@ -239,7 +239,7 @@ def test_run_modular_arithmetic_rib_ablations_bisect(ablation_type, tmp_path): { "node_layers": ["ln1.0", "ln2.0", "mlp_out.0", "unembed", "output"], "batch_size": 100, - "dataset": {"return_set_n_samples": 100}, + "dataset": {"n_samples": 100}, } ) results = rib_build(build_config) @@ -256,7 +256,7 @@ def test_run_modular_arithmetic_rib_ablations_bisect(ablation_type, tmp_path): dataset: dataset_type: modular_arithmetic return_set: train - return_set_n_samples: 100 + n_samples: 100 ablation_node_layers: - ln1.0 - ln2.0 diff --git a/tests/test_build_graph.py b/tests/test_build_graph.py index fdf42336..88527a20 100644 --- a/tests/test_build_graph.py +++ b/tests/test_build_graph.py @@ -206,26 +206,24 @@ def test_modular_arithmetic_build_graph(basis_formula, edge_formula): @pytest.mark.slow def test_pythia_14m_build_graph(): atol = 0 # Works with 1e-7 for float32 and 0 for float64 - config = get_pythia_config() + config = get_pythia_config({"dataset": {"n_ctx": None}}) results = graph_build_test(config=config, atol=atol) get_rib_acts_test(results, atol=0) @pytest.mark.slow -def test_pythia_14m_build_graph_jacobian(): +def test_pythia_14m_build_graph_jacobian_stochastic(): atol = 0 # Works with 0 for batch_size 900 but not 1800 - updates = [ - # Runs in around 30s on a5000 - {"basis_formula": "jacobian"}, - {"dataset": {"return_set_n_samples": 1}}, - {"dataset": {"n_ctx": 2}}, - {"batch_size": 900}, - {"node_layers": ["ln2.1", "mlp_out.5", "unembed"]}, - {"calculate_edges": True}, - {"edge_formula": "squared"}, - {"n_stochastic_sources_edges": 1}, - ] - config = get_pythia_config(*updates) + config = get_pythia_config( + { + "basis_formula": "jacobian", + "dataset": {"n_documents": 10, "n_samples": 1, "n_ctx": 2}, + "node_layers": ["ln2.1", "mlp_out.5", "unembed"], + "calculate_edges": True, + "edge_formula": "squared", + "n_stochastic_sources_edges": 1, + } + ) results = graph_build_test(config=config, atol=atol) get_rib_acts_test(results, atol=0) @@ -637,13 +635,13 @@ def test_stochastic_source_modadd_convergence(): NOTE: This is quite a weak test, but the runs a slow so we're taking a hit on the test quality. """ node_layers = ["mlp_in.0", "mlp_out.0"] - return_set_n_samples = 3 + n_samples = 3 batch_size = 3 # Calc squared edges config_squared = get_modular_arithmetic_config( { - "dataset": {"return_set_n_samples": return_set_n_samples}, + "dataset": {"n_samples": n_samples}, "batch_size": batch_size, "edge_formula": "squared", "node_layers": node_layers, @@ -660,7 +658,7 @@ def test_stochastic_source_modadd_convergence(): for n_stochastic_sources_edges in [1, 3, 7]: config_stochastic = get_modular_arithmetic_config( { - "dataset": {"return_set_n_samples": return_set_n_samples}, + "dataset": {"n_samples": n_samples}, "batch_size": batch_size, "edge_formula": "squared", "node_layers": node_layers, @@ -689,8 +687,7 @@ def no_stoc_result(): @pytest.mark.parametrize( ["pos_sources", "hidden_sources", "error"], [ - [None, 10, 0.2], - [None, 40, 0.07], + [None, 40, 0.1], [2, None, 0.07], [2, 40, 0.1], ], diff --git a/tests/test_data.py b/tests/test_data.py new file mode 100644 index 00000000..016a7ca2 --- /dev/null +++ b/tests/test_data.py @@ -0,0 +1,67 @@ +import pytest + +from rib.data import ( + BlockVectorDatasetConfig, + HFDatasetConfig, + ModularArithmeticDatasetConfig, + VisionDatasetConfig, +) +from rib.utils import replace_pydantic_model + + +def test_hf_dataset_config_validation(): + """Test the validation of the HFDatasetConfig model. + + For HF datasets, we can't have both return_set_frac and n_documents be non-None, but we can + have all other combinations. + """ + base_config = HFDatasetConfig( + dataset_type="huggingface", + name="test", + tokenizer_name="test", + return_set="train", + ) + valid_combinations = [ + {"return_set_frac": 0.5, "n_samples": None, "n_documents": None}, + {"return_set_frac": None, "n_samples": 10, "n_documents": None}, + {"return_set_frac": None, "n_samples": None, "n_documents": 10}, + {"return_set_frac": 0.1, "n_samples": 10, "n_documents": None}, + {"return_set_frac": None, "n_samples": 10, "n_documents": 10}, + ] + for combination in valid_combinations: + replace_pydantic_model(base_config, combination) + + with pytest.raises(ValueError): + # Can't have both return_set_frac and n_documents be non-None + replace_pydantic_model( + base_config, {"return_set_frac": 0.5, "n_samples": 10, "n_documents": 10} + ) + with pytest.raises(ValueError): + # Frac is < 0.01 + replace_pydantic_model(base_config, {"return_set_frac": 0.001, "n_samples": None}) + with pytest.raises(ValueError): + # Frac is > 1 + replace_pydantic_model(base_config, {"return_set_frac": 1.1, "n_samples": None}) + + +def test_non_hf_dataset_config_validation(): + """Test the validation of dataset configs that are not HFDatasetConfig. + + We can't have both return_set_frac and n_samples be non-None. + """ + for base_config in [ + BlockVectorDatasetConfig(dataset_type="block_vector"), + VisionDatasetConfig(dataset_type="torchvision"), + ModularArithmeticDatasetConfig(dataset_type="modular_arithmetic"), + ]: + valid_combinations = [ + {"return_set_frac": 0.5, "n_samples": None}, + {"return_set_frac": None, "n_samples": 10}, + {"return_set_frac": None, "n_samples": None}, + ] + for combination in valid_combinations: + replace_pydantic_model(base_config, combination) + + with pytest.raises(ValueError): + # invalid combination + replace_pydantic_model(base_config, {"return_set_frac": 0.5, "n_samples": 10}) diff --git a/tests/test_float_precision.py b/tests/test_float_precision.py index 35ec136a..6ab314e8 100644 --- a/tests/test_float_precision.py +++ b/tests/test_float_precision.py @@ -32,7 +32,8 @@ def rib_results(self, temp_object) -> dict[str, RibBuildResults]: tokenizer_name: EleutherAI/pythia-14m return_set: train # pile-10k only has train, so we take the first 90% for building and last 10% for ablations return_set_frac: null - return_set_n_samples: 10 + n_documents: 30 + n_samples: 3 return_set_portion: first node_layers: - mlp_out.0 @@ -140,9 +141,10 @@ def ablation_results(self, temp_object, rib_results) -> dict: dataset_type: huggingface name: NeelNanda/pile-10k tokenizer_name: EleutherAI/pythia-14m - return_set: train # pile-10k only has train, so we take the first 90% for building and last 10% for ablations + return_set: train return_set_frac: null - return_set_n_samples: 10 + n_documents: 30 + n_samples: 3 return_set_portion: first ablation_node_layers: - mlp_out.0 diff --git a/tests/test_loader.py b/tests/test_loader.py index ef7a09aa..1457cdca 100644 --- a/tests/test_loader.py +++ b/tests/test_loader.py @@ -2,8 +2,8 @@ import torch from torch.utils.data import Subset, TensorDataset -from rib.loader import load_sequential_transformer -from rib.utils import get_data_subset +from rib.loader import load_sequential_transformer, tokenize_dataset +from rib.utils import get_data_subset, set_seed @pytest.mark.parametrize( @@ -58,3 +58,89 @@ def test_load_transformer(model_str): device="cpu", fold_bias=True, ) + + +class MockTokenizer: + def __init__(self): + self.eos_token_id = 0 # Example EOS token ID + self.generator = torch.Generator().manual_seed(0) + + def __call__(self, text: str) -> dict[str, list[int]]: + # Generate between 5 to 10 token IDs as an example + num_tokens = torch.randint(5, 10, (1,), generator=self.generator).item() + token_ids = torch.randint(1, 100, (num_tokens,), generator=self.generator).tolist() + return {"input_ids": token_ids} + + +class TestTokenizeDataset: + @pytest.fixture(autouse=True) + def setup_class(self): + self.sample_texts = ["This is a test.", "Another test sentence."] + self.sample_dataset = [{"text": text} for text in self.sample_texts] + # Create a dummy tokenizer that spits out random tokens + set_seed(0) + self.tokenizer = MockTokenizer() + + def test_outputs_are_all_n_ctx_length(self): + n_ctx = 5 + tokenized_dataset = tokenize_dataset(self.sample_dataset, self.tokenizer, n_ctx) + for input_ids, labels in tokenized_dataset: + assert len(input_ids) == n_ctx + assert len(labels) == n_ctx + + def test_dataset_has_expected_size(self): + n_ctx = 5 + n_samples = 3 + tokenized_dataset = tokenize_dataset(self.sample_dataset, self.tokenizer, n_ctx, n_samples) + assert len(tokenized_dataset) == n_samples + + def test_seed_reproducibility(self): + n_ctx = 5 + n_samples = 2 + seed = 0 + dataset1 = tokenize_dataset(self.sample_dataset, self.tokenizer, n_ctx, n_samples, seed) + duplicate_tokenizer = MockTokenizer() + dataset2 = tokenize_dataset( + self.sample_dataset, duplicate_tokenizer, n_ctx, n_samples, seed + ) + assert torch.equal(dataset1.tensors[0], dataset2.tensors[0]) and torch.equal( + dataset1.tensors[1], dataset2.tensors[1] + ) + + def test_different_seeds(self): + n_ctx = 5 + n_samples = 2 + dataset1 = tokenize_dataset(self.sample_dataset, self.tokenizer, n_ctx, n_samples, 42) + duplicate_tokenizer = MockTokenizer() + dataset2 = tokenize_dataset(self.sample_dataset, duplicate_tokenizer, n_ctx, n_samples, 43) + assert not torch.equal(dataset1.tensors[0], dataset2.tensors[0]) or not torch.equal( + dataset1.tensors[1], dataset2.tensors[1] + ) + + def test_input_ids_equal_labels_no_sampling(self): + """If not sampling (i.e. n_samples is None), input_ids and labels differ by one token. + + Moreover, the final label of one chunk is the input_id of the first token in the next chunk. + So we can flatten the input_ids and labels and check that they are equal (offset by one). + """ + n_ctx = 5 + tokenized_dataset = tokenize_dataset(self.sample_dataset, self.tokenizer, n_ctx) + flattened_input_ids = [ + token_id for input_ids, _ in tokenized_dataset for token_id in input_ids + ] + flattened_labels = [token_id for _, labels in tokenized_dataset for token_id in labels] + assert len(flattened_input_ids) == len(flattened_labels) + assert flattened_input_ids[1:] == flattened_labels[:-1] + + def test_input_ids_equal_labels_sampling(self): + """Check that the labels match the input_ids except for the final token when sampling. + + When (randomly) sampling, the chunks will not be ordered, so we can't check that the final + token label is the input_id of the first token in the next chunk. + """ + n_ctx = 5 + n_samples = 3 + tokenized_dataset = tokenize_dataset(self.sample_dataset, self.tokenizer, n_ctx, n_samples) + for input_ids, labels in tokenized_dataset: + assert len(input_ids) == len(labels) + assert torch.equal(input_ids[1:], labels[:-1]) diff --git a/tests/utils.py b/tests/utils.py index 9fd652b4..948ad243 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -19,7 +19,7 @@ def get_modular_arithmetic_config(*updates: dict) -> RibBuildConfig: dataset: dataset_type: modular_arithmetic return_set: train - return_set_n_samples: 10 + n_samples: 10 node_layers: - ln1.0 - mlp_in.0 @@ -53,8 +53,11 @@ def get_pythia_config(*updates: dict) -> RibBuildConfig: tokenizer_name: EleutherAI/pythia-14m return_set: train return_set_frac: null - return_set_n_samples: 10 # 10 samples gives 3x2048 tokens + n_documents: 20 + n_samples: 3 return_set_portion: first + n_ctx: 128 + seed: 0 node_layers: - ln2.1 - unembed @@ -64,6 +67,7 @@ def get_pythia_config(*updates: dict) -> RibBuildConfig: n_intervals: 0 dtype: float64 calculate_edges: false + edge_formula: squared eval_type: ce_loss out_dir: null basis_formula: (1-0)*alpha @@ -85,7 +89,8 @@ def get_tinystories_config(*updates: dict) -> RibBuildConfig: tokenizer_name: EleutherAI/gpt-neo-125M return_set: train return_set_frac: null - return_set_n_samples: 1 # avg ~235 toks / story + n_documents: 1 # avg ~235 toks / story + n_samples: 15 return_set_portion: first n_ctx: 10 # needs to be <= 511 for the model to behave reasonably node_layers: @@ -213,9 +218,7 @@ def _assignment_permutations(sim: torch.Tensor) -> tuple[list[int], list[int]]: def assert_basis_similarity( - ir_A: InteractionRotation, - ir_B: InteractionRotation, - error: Optional[float] = 0.02, + ir_A: InteractionRotation, ir_B: InteractionRotation, error: Optional[float] = 0.02 ): """ Compare two InteractionRotations and assert similarity, allowing for permutations.