83 changes: 83 additions & 0 deletions src/speculators/train/data.py
@@ -269,3 +269,86 @@ def collate_fn(batch: list[BatchType]) -> BatchType:
return collated_data

return collate_fn


class Eagle3DummyDataset(Dataset):
    def __init__(
        self,
        num_samples: int,
        hidden_size: int,
        target_vocab_size: int,
        sample_length_range: tuple[int, int] = (128, 2048),
        hidden_states_dtype: torch.dtype = torch.float,
        seed: int = 0,
    ):
        """
        A dummy dataset that produces random samples for testing and debugging.

        Args:
            num_samples: The number of samples in the dataset.
            hidden_size: The hidden size of the model.
            target_vocab_size: Upper bound (exclusive) for generated input_ids.
            sample_length_range: Inclusive (min, max) range of sample lengths.
            hidden_states_dtype: The dtype of the generated hidden states.
            seed: Base seed used to make sample generation deterministic.
        """
self.seed = seed
self.num_samples = num_samples
self.hidden_size = hidden_size
self.target_vocab_size = target_vocab_size
self.sample_length_range = sample_length_range
self.hidden_states_dtype = hidden_states_dtype
self.approx_lengths = self._generate_fake_lengths()

def __len__(self):
return self.num_samples

def _generate_fake_lengths(self) -> list[int]:
"""Generate fake lengths for the dataset"""
rng = random.Random(self.seed)
return [
rng.randint(self.sample_length_range[0], self.sample_length_range[1])
for _ in range(self.num_samples)
]

    def __getitem__(self, index: int) -> BatchType:
# data structure: {
# "hidden_states": [seq_len, 3 * hidden_size],
# "input_ids": [seq_len],
# "verifier_last_hidden_states": [seq_len, hidden_size],
# "loss_mask": [seq_len],
# "lengths": [1],
# "position_ids": [seq_len],
# }

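        # Seeding a fresh generator with (seed + index) keeps each sample
        # deterministic and independent of access order or DataLoader workers.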
torch_rng = torch.Generator(device="cpu").manual_seed(self.seed + index)
seq_len = self.approx_lengths[index]

input_ids = torch.randint(
0, self.target_vocab_size, (seq_len,), generator=torch_rng
)

loss_mask = torch.randint(
0, 2, (seq_len,), generator=torch_rng, dtype=torch.long
)

hidden_states = torch.randn(
seq_len,
3 * self.hidden_size,
dtype=self.hidden_states_dtype,
generator=torch_rng,
)
verifier_last_hidden_states = torch.randn(
seq_len,
self.hidden_size,
dtype=self.hidden_states_dtype,
generator=torch_rng,
)

data = {
"hidden_states": hidden_states,
"input_ids": input_ids,
"verifier_last_hidden_states": verifier_last_hidden_states,
"loss_mask": loss_mask,
"lengths": torch.tensor([seq_len], dtype=torch.long),
"position_ids": torch.arange(1, seq_len + 1, dtype=torch.long),
}

return data
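Below is a minimal usage sketch of the new dataset, assuming it is importable as speculators.train.data (the file this diff extends); the constructor arguments are illustrative, and the module's collate-function factory is not shown in this hunk, so the sample is inspected directly rather than batched:

from speculators.train.data import Eagle3DummyDataset

dataset = Eagle3DummyDataset(
    num_samples=4,
    hidden_size=64,
    target_vocab_size=32_000,
    sample_length_range=(8, 16),
)

sample = dataset[0]
# Shapes follow the data-structure comment in __getitem__.
print(sample["hidden_states"].shape)                 # [seq_len, 3 * hidden_size]
print(sample["input_ids"].shape)                     # [seq_len]
print(sample["verifier_last_hidden_states"].shape)   # [seq_len, hidden_size]
print(sample["loss_mask"].shape, sample["lengths"])  # [seq_len], tensor([seq_len])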