diff --git a/openvalidators/dataset.py b/openvalidators/dataset.py index 215f10f..96683c4 100644 --- a/openvalidators/dataset.py +++ b/openvalidators/dataset.py @@ -27,11 +27,17 @@ def __init__(self): self.openwebtext = iter( load_dataset("openwebtext", split="train", streaming=True).shuffle(seed=seed, buffer_size=10000) ) self.red_pajama = iter( load_dataset("togethercomputer/RedPajama-Data-1T", 'default', split='train', streaming=True).shuffle(seed=seed, buffer_size=10000) ) - def __next__(self): - if random.random() < 0.5: - return {"text": next(self.openwebtext)["text"]} - else: - return {"text": next(self.red_pajama)["text"]} + def __next__(self): + while True: + bt.logging.debug('Retrieving data from dataset...') + if random.random() < 0.5: + text = next(self.openwebtext)["text"] + else: + text = next(self.red_pajama)["text"] + + # Check if the text is not empty or does not consist only of newline characters + if text.strip(): + return {"text": text} class MockDataset(Iterator): diff --git a/tests/test_dataset.py b/tests/test_dataset.py new file mode 100644 index 0000000..c220768 --- /dev/null +++ b/tests/test_dataset.py @@ -0,0 +1,26 @@ +import unittest +from openvalidators.dataset import Dataset + + +class DatasetTestCase(unittest.TestCase): + def test_next_skips_empty_and_newline_only_strings(self): + mock_data = iter([{"text": ""}, {"text": "\n\n"}, {"text": "Non-empty text"}]) + dataset = Dataset() + dataset.openwebtext = mock_data + dataset.red_pajama = mock_data + + # Test that __next__ skips empty texts and texts that consist only of newline characters + self.assertEqual(dataset.__next__(), {"text": "Non-empty text"}) + + def test_next_returns_regular_strings(self): + mock_data = iter([{"text": "Non-empty text"}]) + dataset = Dataset() + dataset.openwebtext = mock_data + dataset.red_pajama = mock_data + + # Test that __next__ returns a non-empty text + self.assertEqual(dataset.__next__(), {"text": "Non-empty text"}) + + +if __name__ == '__main__': + unittest.main() \ No newline at end of file