# Patch summary for python/python/ci_benchmarks/datagen/lineitems.py
# (text ahead of the "diff --git" header; the diff itself is unchanged below).
#
# What the hunks change:
#   * Adds `import shutil` and `import tempfile` above the existing imports.
#   * `_gen_data` gains a leading `tmpdir: str` parameter. The DuckDB connection
#     moves from `database=":memory:"` to a database file at
#     `{tmpdir}/tpch-scale-factor-{scale_factor}.db`, and the function returns
#     `res.fetch_arrow_reader()` instead of `res.to_arrow_table()`.
#   * `_create` calls `tempfile.mkdtemp(prefix=f"tpch-scale-factor-{scale_factor}-")`
#     before its `try`, passes the directory to both visible `_gen_data` call
#     sites (the mode="append" and mode="create" writes), and removes it with
#     `shutil.rmtree(tmpdir)` in a new `finally` clause ahead of `return ds`.
#
# Review notes:
#   * NOTE(review): presumably the on-disk database plus Arrow reader is meant to
#     avoid holding the generated table in memory -- confirm against the DuckDB docs.
#   * NOTE(review): `con` is not closed in the visible lines, so the database file
#     may still be open when `shutil.rmtree(tmpdir)` runs; verify this is safe on
#     every platform the benchmarks target.
#   * NOTE(review): the `_gen_data` result is consumed by `lance.write_dataset`
#     inside the `try`, i.e. before the `finally` deletes `tmpdir`.
#   * NOTE(review): the temp directory is created even on the path that hits the
#     bare `return` in the `try` body (its guard is outside the hunks); the
#     `finally` removes it on that path as well.
#   * NOTE(review): the context comment says the dataset comes from "a prebuilt
#     Parquet file" while the log message says DuckDB generates it; the patch
#     leaves that comment untouched.
#   * NOTE(review): `_gen_data`'s positional signature changes; only the two call
#     sites shown here are updated -- check for callers outside the hunks.
#   * This copy of the patch has its line breaks collapsed onto a single line, and
#     blank context lines are not recoverable from it.
diff --git a/python/python/ci_benchmarks/datagen/lineitems.py b/python/python/ci_benchmarks/datagen/lineitems.py index b91c1c3b422..39ad637fb95 100644 --- a/python/python/ci_benchmarks/datagen/lineitems.py +++ b/python/python/ci_benchmarks/datagen/lineitems.py @@ -3,6 +3,9 @@ # Creates a dataset containing the TPC-H lineitems table using a prebuilt Parquet file +import shutil +import tempfile + import duckdb import lance from lance.log import LOGGER @@ -12,16 +15,17 @@ NUM_ROWS = 59986052 -def _gen_data(scale_factor: int): +def _gen_data(tmpdir: str, scale_factor: int): LOGGER.info("Using DuckDB to generate TPC-H dataset") - con = duckdb.connect(database=":memory:") + con = duckdb.connect(f"{tmpdir}/tpch-scale-factor-{scale_factor}.db") con.execute("INSTALL tpch; LOAD tpch") con.execute(f"CALL dbgen(sf={scale_factor})") res = con.query("SELECT * FROM lineitem") - return res.to_arrow_table() + return res.fetch_arrow_reader() def _create(dataset_uri: str, data_storage_version: str, scale_factor: int = 10): + tmpdir = tempfile.mkdtemp(prefix=f"tpch-scale-factor-{scale_factor}-") try: ds = lance.dataset(dataset_uri) print(ds.count_rows()) @@ -29,7 +33,7 @@ def _create(dataset_uri: str, data_storage_version: str, scale_factor: int = 10) return elif ds.count_rows() == 0: ds = lance.write_dataset( - _gen_data(scale_factor), + _gen_data(tmpdir, scale_factor), dataset_uri, mode="append", data_storage_version=data_storage_version, @@ -42,11 +46,13 @@ def _create(dataset_uri: str, data_storage_version: str, scale_factor: int = 10) ) except ValueError: ds = lance.write_dataset( - _gen_data(scale_factor), + _gen_data(tmpdir, scale_factor), dataset_uri, mode="create", data_storage_version=data_storage_version, ) + finally: + shutil.rmtree(tmpdir) return ds