From f21c56e7cceecb01e7e1bcfe20f7707a7bb0f5c7 Mon Sep 17 00:00:00 2001
From: Fynn Schmitt-Ulms
Date: Wed, 3 Dec 2025 21:17:16 +0000
Subject: [PATCH] Add Eagle3DummyDataset

This will be used to generate dummy data for future smoke tests.

Signed-off-by: Fynn Schmitt-Ulms
---
 src/speculators/train/data.py | 83 +++++++++++++++++++++++++++++++++++
 1 file changed, 83 insertions(+)

diff --git a/src/speculators/train/data.py b/src/speculators/train/data.py
index 8fe08070..60a659f3 100644
--- a/src/speculators/train/data.py
+++ b/src/speculators/train/data.py
@@ -269,3 +269,86 @@ def collate_fn(batch: list[BatchType]) -> BatchType:
         return collated_data
 
     return collate_fn
+
+
+class Eagle3DummyDataset(Dataset):
+    def __init__(
+        self,
+        num_samples: int,
+        hidden_size: int,
+        target_vocab_size: int,
+        sample_length_range: tuple[int, int] = (128, 2048),
+        hidden_states_dtype=torch.float,
+        seed: int = 0,
+    ):
+        """
+        A dummy dataset for testing and debugging.
+        Args:
+            hidden_size: The hidden size of the model.
+            target_vocab_size: Used to limit the range of input_ids.
+            sample_length_range: The range of sample lengths to generate.
+            hidden_states_dtype: The dtype of the hidden states.
+        """
+        self.seed = seed
+        self.num_samples = num_samples
+        self.hidden_size = hidden_size
+        self.target_vocab_size = target_vocab_size
+        self.sample_length_range = sample_length_range
+        self.hidden_states_dtype = hidden_states_dtype
+        self.approx_lengths = self._generate_fake_lengths()
+
+    def __len__(self):
+        return self.num_samples
+
+    def _generate_fake_lengths(self) -> list[int]:
+        """Generate fake lengths for the dataset"""
+        rng = random.Random(self.seed)
+        return [
+            rng.randint(self.sample_length_range[0], self.sample_length_range[1])
+            for _ in range(self.num_samples)
+        ]
+
+    def __getitem__(self, index) -> BatchType:
+        # data structure: {
+        #     "hidden_states": [seq_len, 3 * hidden_size],
+        #     "input_ids": [seq_len],
+        #     "verifier_last_hidden_states": [seq_len, hidden_size],
+        #     "loss_mask": [seq_len],
+        #     "lengths": [1],
+        #     "position_ids": [seq_len],
+        # }
+
+        torch_rng = torch.Generator(device="cpu").manual_seed(self.seed + index)
+        seq_len = self.approx_lengths[index]
+
+        input_ids = torch.randint(
+            0, self.target_vocab_size, (seq_len,), generator=torch_rng
+        )
+
+        loss_mask = torch.randint(
+            0, 2, (seq_len,), generator=torch_rng, dtype=torch.long
+        )
+
+        hidden_states = torch.randn(
+            seq_len,
+            3 * self.hidden_size,
+            dtype=self.hidden_states_dtype,
+            generator=torch_rng,
+        )
+        verifier_last_hidden_states = torch.randn(
+            seq_len,
+            self.hidden_size,
+            dtype=self.hidden_states_dtype,
+            generator=torch_rng,
+        )
+
+        data = {
+            "hidden_states": hidden_states,
+            "input_ids": input_ids,
+            "verifier_last_hidden_states": verifier_last_hidden_states,
+            "loss_mask": loss_mask,
+            "lengths": torch.tensor([seq_len], dtype=torch.long),
+            "position_ids": torch.arange(1, seq_len + 1, dtype=torch.long),
+        }
+
+        return data
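
A minimal sketch of the kind of smoke test this dataset is intended to support, assuming Eagle3DummyDataset is importable from speculators.train.data (the module this patch extends); the shape checks mirror the data-structure comment in __getitem__ and the example values (hidden_size=64, target_vocab_size=1000) are illustrative, not taken from the patch:

    # Illustrative smoke test, not part of the patch. Assumes the class above is
    # importable from speculators.train.data and that torch is installed.
    import torch

    from speculators.train.data import Eagle3DummyDataset

    dataset = Eagle3DummyDataset(
        num_samples=4,
        hidden_size=64,          # hidden size of a hypothetical draft model
        target_vocab_size=1000,  # caps the range of generated input_ids
        sample_length_range=(16, 32),
        seed=0,
    )

    assert len(dataset) == 4

    sample = dataset[0]
    seq_len = sample["input_ids"].shape[0]
    assert 16 <= seq_len <= 32

    # Shapes follow the data-structure comment in __getitem__.
    assert sample["hidden_states"].shape == (seq_len, 3 * 64)
    assert sample["verifier_last_hidden_states"].shape == (seq_len, 64)
    assert sample["loss_mask"].shape == (seq_len,)
    assert sample["lengths"].tolist() == [seq_len]
    assert sample["position_ids"].shape == (seq_len,)

    # The same seed reproduces the same sample, since each item is generated
    # from a generator seeded with (seed + index).
    replica = Eagle3DummyDataset(
        num_samples=4,
        hidden_size=64,
        target_vocab_size=1000,
        sample_length_range=(16, 32),
        seed=0,
    )
    assert torch.equal(sample["input_ids"], replica[0]["input_ids"])

Because sample lengths vary per index, batching these items would go through the variable-length collate path already present in src/speculators/train/data.py rather than PyTorch's default collation.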