From 8034f71a1cf9de79a783f0a57789f8f08b25dfee Mon Sep 17 00:00:00 2001 From: Yongda Fan Date: Sat, 20 Dec 2025 03:05:48 -0500 Subject: [PATCH 1/9] rename to avoid naming collision --- pyhealth/datasets/base_dataset.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/pyhealth/datasets/base_dataset.py b/pyhealth/datasets/base_dataset.py index 6cad1e3d8..877d3c710 100644 --- a/pyhealth/datasets/base_dataset.py +++ b/pyhealth/datasets/base_dataset.py @@ -26,7 +26,7 @@ import requests from tqdm import tqdm import dask.dataframe as dd -from dask.distributed import Client, LocalCluster, progress +from dask.distributed import Client as DaskClient, LocalCluster as DaskCluster, progress as dask_progress import narwhals as nw from ..data import Patient @@ -373,13 +373,13 @@ def global_event_df(self) -> pl.LazyFrame: dask_scratch_dir = self.cache_dir / "dask_scratch" dask_scratch_dir.mkdir(parents=True, exist_ok=True) - with LocalCluster( + with DaskCluster( n_workers=self.num_workers, threads_per_worker=1, processes=not in_notebook(), local_directory=str(dask_scratch_dir), ) as cluster: - with Client(cluster) as client: + with DaskClient(cluster) as client: df: dd.DataFrame = self.load_data() if self.dev: logger.info("Dev mode enabled: limiting to 1000 patients") @@ -394,7 +394,7 @@ def global_event_df(self) -> pl.LazyFrame: compute=False, ) handle = client.compute(collection) - progress(handle) + dask_progress(handle) handle.result() # type: ignore self._global_event_df = ret_path From 31ee08fa870526753c2b1efe75fdaf5226ad6a07 Mon Sep 17 00:00:00 2001 From: Yongda Fan Date: Sat, 20 Dec 2025 04:18:24 -0500 Subject: [PATCH 2/9] Support multi-worker for task transformation --- pyhealth/datasets/base_dataset.py | 144 +++++++++++++++++++++++------- 1 file changed, 114 insertions(+), 30 deletions(-) diff --git a/pyhealth/datasets/base_dataset.py b/pyhealth/datasets/base_dataset.py index 877d3c710..8e15215a2 100644 --- a/pyhealth/datasets/base_dataset.py +++ b/pyhealth/datasets/base_dataset.py @@ -2,9 +2,8 @@ import os import pickle from abc import ABC -from concurrent.futures import ThreadPoolExecutor, as_completed from pathlib import Path -from typing import Dict, Iterator, List, Optional, Any, Callable +from typing import Dict, Iterator, Iterable, List, Optional, Any, Callable import functools import operator from urllib.parse import urlparse, urlunparse @@ -13,7 +12,8 @@ import uuid import platformdirs import tempfile -import multiprocessing +from multiprocessing import current_process, Pool, Queue, Manager +import shutil import litdata from litdata.streaming.item_loader import ParquetLoader @@ -28,6 +28,7 @@ import dask.dataframe as dd from dask.distributed import Client as DaskClient, LocalCluster as DaskCluster, progress as dask_progress import narwhals as nw +import itertools from ..data import Patient from ..tasks import BaseTask @@ -113,6 +114,11 @@ def _csv_tsv_gz_path(path: str) -> str: def _uncollate(x: list[Any]) -> Any: return x[0] if isinstance(x, list) and len(x) == 1 else x +class _FakeQueue: + """A fake queue that does nothing. Used when multiprocessing is not needed.""" + + def put(self, item: Any) -> None: + pass class _ParquetWriter: """ @@ -191,6 +197,39 @@ def __exit__(self, exc_type, exc, tb): self.close() +def _task_transform_fn(args: tuple[int, BaseTask, Iterable[str], pl.LazyFrame, Path, Queue | _FakeQueue]) -> None: + """ + Worker function to apply task transformation on a chunk of patients. 
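+    Each worker writes the samples it produces to ``chunk_{worker_id:03d}.parquet``
+    under ``output_dir``; the parent process indexes the directory afterwards with
+    ``litdata.index_parquet_dataset``.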
+ + Args: + args (tuple): A tuple containing: + worker_id (int): The ID of the worker. + task (BaseTask): The task to apply. + patient_ids (Iterable[str]): The patient IDs to process. + global_event_df (pl.LazyFrame): The global event dataframe. + output_dir (Path): The output directory to save results. + queue (Queue | _FakeQueue): A multiprocessing queue for progress tracking. + """ + logger.info(f"Worker {args[0]} started processing {len(list(args[2]))} patients.") + + worker_id, task, patient_ids, global_event_df, output_dir, queue = args + with _ParquetWriter( + output_dir / f"chunk_{worker_id:03d}.parquet", + pa.schema([("sample", pa.binary())]), + ) as writer: + for patient_id in patient_ids: + patient_df = global_event_df.filter(pl.col("patient_id") == patient_id).collect( + engine="streaming" + ) + patient = Patient(patient_id=patient_id, data_source=patient_df) + + for sample in task(patient): + writer.append({"sample": pickle.dumps(sample)}) + + queue.put(1) + + logger.info(f"Worker {args[0]} finished processing patients.") + class BaseDataset(ABC): """Abstract base class for all PyHealth datasets. @@ -359,7 +398,7 @@ def global_event_df(self) -> pl.LazyFrame: Returns: Path: The path to the cached event dataframe. """ - if not multiprocessing.current_process().name == "MainProcess": + if not current_process().name == "MainProcess": logger.warning( "global_event_df property accessed from a non-main process. This may lead to unexpected behavior.\n" + "Consider use __name__ == '__main__' guard when using multiprocessing." @@ -590,6 +629,61 @@ def default_task(self) -> Optional[BaseTask]: """ return None + def _task_transform(self, task: BaseTask, output_dir: Path, num_workers: int) -> None: + self._main_guard(self._task_transform.__name__) + + try: + logger.info(f"Applying task transformations on data with {num_workers} workers...") + global_event_df = task.pre_filter(self.global_event_df) + patient_ids = ( + global_event_df.select("patient_id") + .unique() + .collect(engine="streaming") + .to_series() + ) + + if in_notebook(): + logger.info("Detected Jupyter notebook environment, setting num_workers to 1") + num_workers = 1 + + if num_workers == 1: + logger.info("Single worker mode, processing sequentially") + _task_transform_fn((0, task, patient_ids, global_event_df, output_dir, _FakeQueue())) + return + + num_workers = min(num_workers, len(patient_ids)) # Avoid spawning empty workers + batch_size = len(patient_ids) // num_workers + 1 + with Manager() as manager: + queue = manager.Queue() + args_list = [( + worker_id, + task, + pids, + global_event_df, + output_dir, + queue + ) for worker_id, pids in enumerate(itertools.batched(patient_ids, batch_size))] + with Pool(processes=num_workers) as pool: + result = pool.map_async(_task_transform_fn, args_list) # type: ignore + with tqdm(total=len(patient_ids)) as progress: + while not result.ready(): + while not queue.empty(): + queue.get() + progress.update(1) + + # remaining items + while not queue.empty(): + queue.get() + progress.update(1) + + litdata.index_parquet_dataset(str(output_dir)) + logger.info(f"Task transformation completed and saved to {output_dir}") + except Exception as e: + logger.error(f"Error during task transformation, cleaning up output directory: {output_dir}") + shutil.rmtree(output_dir) + raise e + + def set_task( self, task: Optional[BaseTask] = None, @@ -622,12 +716,7 @@ def set_task( Raises: AssertionError: If no default task is found and task is None. 
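
        Example (illustrative sketch; ``MyTask`` stands in for any concrete
        ``BaseTask`` subclass and the dataset path is a placeholder):

            from pyhealth.datasets import MIMIC4Dataset

            if __name__ == "__main__":  # guard needed: workers are spawned processes
                dataset = MIMIC4Dataset(root="/path/to/mimic-iv/")
                samples = dataset.set_task(MyTask(), num_workers=4)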
""" - if not multiprocessing.current_process().name == "MainProcess": - logger.warning( - "set_task method accessed from a non-main process. This may lead to unexpected behavior.\n" - + "Consider use __name__ == '__main__' guard when using multiprocessing." - ) - return None # type: ignore + self._main_guard(self.set_task.__name__) if task is None: assert self.default_task is not None, "No default tasks found" @@ -656,27 +745,12 @@ def set_task( # Check if index.json exists to verify cache integrity, this # is the standard file for litdata.StreamingDataset if not (path / "index.json").exists(): - global_event_df = task.pre_filter(self.global_event_df) - schema = pa.schema([("sample", pa.binary())]) with tempfile.TemporaryDirectory() as tmp_dir: - # Create Parquet file with samples - logger.info(f"Applying task transformations on data...") - with _ParquetWriter(f"{tmp_dir}/samples.parquet", schema) as writer: - # TODO: this can be further optimized. - patient_ids = ( - global_event_df.select("patient_id") - .unique() - .collect(engine="streaming") - .to_series() - ) - for patient_id in tqdm(patient_ids): - patient_df = global_event_df.filter( - pl.col("patient_id") == patient_id - ).collect(engine="streaming") - patient = Patient(patient_id=patient_id, data_source=patient_df) - for sample in task(patient): - writer.append({"sample": pickle.dumps(sample)}) - litdata.index_parquet_dataset(tmp_dir) + self._task_transform( + task, + Path(tmp_dir), + num_workers, + ) # Build processors and fit on the dataset logger.info(f"Fitting processors on the dataset...") @@ -718,3 +792,13 @@ def set_task( dataset_name=self.dataset_name, task_name=task.task_name, ) + + def _main_guard(self, func_name: str): + """Warn if method is accessed from a non-main process.""" + + if not current_process().name == "MainProcess": + logger.warning( + f"{func_name} method accessed from a non-main process. This may lead to unexpected behavior.\n" + + "Consider use __name__ == '__main__' guard when using multiprocessing." + ) + exit(1) \ No newline at end of file From 5bae8f0ad46601e02b4081ab293b2aff4521e914 Mon Sep 17 00:00:00 2001 From: Yongda Fan Date: Sat, 20 Dec 2025 04:28:14 -0500 Subject: [PATCH 3/9] Fix import --- pyhealth/datasets/base_dataset.py | 13 +++++++------ 1 file changed, 7 insertions(+), 6 deletions(-) diff --git a/pyhealth/datasets/base_dataset.py b/pyhealth/datasets/base_dataset.py index 8e15215a2..e25f68eae 100644 --- a/pyhealth/datasets/base_dataset.py +++ b/pyhealth/datasets/base_dataset.py @@ -12,7 +12,8 @@ import uuid import platformdirs import tempfile -from multiprocessing import current_process, Pool, Queue, Manager +import multiprocessing +import multiprocessing.queues import shutil import litdata @@ -197,7 +198,7 @@ def __exit__(self, exc_type, exc, tb): self.close() -def _task_transform_fn(args: tuple[int, BaseTask, Iterable[str], pl.LazyFrame, Path, Queue | _FakeQueue]) -> None: +def _task_transform_fn(args: tuple[int, BaseTask, Iterable[str], pl.LazyFrame, Path, multiprocessing.queues.Queue | _FakeQueue]) -> None: """ Worker function to apply task transformation on a chunk of patients. @@ -398,7 +399,7 @@ def global_event_df(self) -> pl.LazyFrame: Returns: Path: The path to the cached event dataframe. """ - if not current_process().name == "MainProcess": + if not multiprocessing.current_process().name == "MainProcess": logger.warning( "global_event_df property accessed from a non-main process. 
This may lead to unexpected behavior.\n" + "Consider use __name__ == '__main__' guard when using multiprocessing." @@ -653,7 +654,7 @@ def _task_transform(self, task: BaseTask, output_dir: Path, num_workers: int) -> num_workers = min(num_workers, len(patient_ids)) # Avoid spawning empty workers batch_size = len(patient_ids) // num_workers + 1 - with Manager() as manager: + with multiprocessing.Manager() as manager: queue = manager.Queue() args_list = [( worker_id, @@ -663,7 +664,7 @@ def _task_transform(self, task: BaseTask, output_dir: Path, num_workers: int) -> output_dir, queue ) for worker_id, pids in enumerate(itertools.batched(patient_ids, batch_size))] - with Pool(processes=num_workers) as pool: + with multiprocessing.Pool(processes=num_workers) as pool: result = pool.map_async(_task_transform_fn, args_list) # type: ignore with tqdm(total=len(patient_ids)) as progress: while not result.ready(): @@ -796,7 +797,7 @@ def set_task( def _main_guard(self, func_name: str): """Warn if method is accessed from a non-main process.""" - if not current_process().name == "MainProcess": + if not multiprocessing.current_process().name == "MainProcess": logger.warning( f"{func_name} method accessed from a non-main process. This may lead to unexpected behavior.\n" + "Consider use __name__ == '__main__' guard when using multiprocessing." From 2461e30f0bf7e170bc9936635dfffb1b6d80da97 Mon Sep 17 00:00:00 2001 From: Yongda Fan Date: Sat, 20 Dec 2025 04:39:43 -0500 Subject: [PATCH 4/9] Fix deadlock --- pyhealth/datasets/base_dataset.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/pyhealth/datasets/base_dataset.py b/pyhealth/datasets/base_dataset.py index e25f68eae..9401c31c8 100644 --- a/pyhealth/datasets/base_dataset.py +++ b/pyhealth/datasets/base_dataset.py @@ -654,7 +654,10 @@ def _task_transform(self, task: BaseTask, output_dir: Path, num_workers: int) -> num_workers = min(num_workers, len(patient_ids)) # Avoid spawning empty workers batch_size = len(patient_ids) // num_workers + 1 - with multiprocessing.Manager() as manager: + + # spwan is required for polars in multiprocessing, see https://docs.pola.rs/user-guide/misc/multiprocessing/#summary + ctx = multiprocessing.get_context("spawn") + with ctx.Manager() as manager: queue = manager.Queue() args_list = [( worker_id, @@ -664,7 +667,7 @@ def _task_transform(self, task: BaseTask, output_dir: Path, num_workers: int) -> output_dir, queue ) for worker_id, pids in enumerate(itertools.batched(patient_ids, batch_size))] - with multiprocessing.Pool(processes=num_workers) as pool: + with ctx.Pool(processes=num_workers) as pool: result = pool.map_async(_task_transform_fn, args_list) # type: ignore with tqdm(total=len(patient_ids)) as progress: while not result.ready(): From 725e4778ae7fe23affb7262d43fea6f00991f167 Mon Sep 17 00:00:00 2001 From: Yongda Fan Date: Sat, 20 Dec 2025 04:50:45 -0500 Subject: [PATCH 5/9] better IPC communication --- pyhealth/datasets/base_dataset.py | 19 +++++++++++++------ 1 file changed, 13 insertions(+), 6 deletions(-) diff --git a/pyhealth/datasets/base_dataset.py b/pyhealth/datasets/base_dataset.py index 9401c31c8..8ac95fe19 100644 --- a/pyhealth/datasets/base_dataset.py +++ b/pyhealth/datasets/base_dataset.py @@ -211,9 +211,12 @@ def _task_transform_fn(args: tuple[int, BaseTask, Iterable[str], pl.LazyFrame, P output_dir (Path): The output directory to save results. queue (Queue | _FakeQueue): A multiprocessing queue for progress tracking. 
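
        Progress counts are put on ``queue`` in batches of ``UPDATE_FREQUENCY``
        samples to keep inter-process communication overhead low.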
""" + UPDATE_FREQUENCY = 128 + logger.info(f"Worker {args[0]} started processing {len(list(args[2]))} patients.") worker_id, task, patient_ids, global_event_df, output_dir, queue = args + count = 0 with _ParquetWriter( output_dir / f"chunk_{worker_id:03d}.parquet", pa.schema([("sample", pa.binary())]), @@ -223,11 +226,17 @@ def _task_transform_fn(args: tuple[int, BaseTask, Iterable[str], pl.LazyFrame, P engine="streaming" ) patient = Patient(patient_id=patient_id, data_source=patient_df) - for sample in task(patient): writer.append({"sample": pickle.dumps(sample)}) - queue.put(1) + count += 1 + if count >= UPDATE_FREQUENCY: + queue.put(count) + count = 0 + + if count > 0: + queue.put(count) + count = 0 logger.info(f"Worker {args[0]} finished processing patients.") @@ -672,13 +681,11 @@ def _task_transform(self, task: BaseTask, output_dir: Path, num_workers: int) -> with tqdm(total=len(patient_ids)) as progress: while not result.ready(): while not queue.empty(): - queue.get() - progress.update(1) + progress.update(queue.get()) # remaining items while not queue.empty(): - queue.get() - progress.update(1) + progress.update(queue.get()) litdata.index_parquet_dataset(str(output_dir)) logger.info(f"Task transformation completed and saved to {output_dir}") From 7cb5683e278240a7308b4979a9f7a371616a6294 Mon Sep 17 00:00:00 2001 From: Yongda Fan Date: Sat, 20 Dec 2025 05:24:32 -0500 Subject: [PATCH 6/9] fix bug when num_workers == 1 --- pyhealth/datasets/base_dataset.py | 1 + 1 file changed, 1 insertion(+) diff --git a/pyhealth/datasets/base_dataset.py b/pyhealth/datasets/base_dataset.py index 8ac95fe19..244945827 100644 --- a/pyhealth/datasets/base_dataset.py +++ b/pyhealth/datasets/base_dataset.py @@ -659,6 +659,7 @@ def _task_transform(self, task: BaseTask, output_dir: Path, num_workers: int) -> if num_workers == 1: logger.info("Single worker mode, processing sequentially") _task_transform_fn((0, task, patient_ids, global_event_df, output_dir, _FakeQueue())) + litdata.index_parquet_dataset(str(output_dir)) return num_workers = min(num_workers, len(patient_ids)) # Avoid spawning empty workers From 707a0b98743bf8d46c6b4051708d4b114de06f03 Mon Sep 17 00:00:00 2001 From: Yongda Fan Date: Tue, 23 Dec 2025 06:45:05 -0500 Subject: [PATCH 7/9] Fix UserWarning: resource_tracker: There appear to be 1 leaked semaphore objects to clean up at shutdown --- pyhealth/datasets/base_dataset.py | 35 +++++++++++++++---------------- 1 file changed, 17 insertions(+), 18 deletions(-) diff --git a/pyhealth/datasets/base_dataset.py b/pyhealth/datasets/base_dataset.py index 244945827..8b4e3a2d4 100644 --- a/pyhealth/datasets/base_dataset.py +++ b/pyhealth/datasets/base_dataset.py @@ -667,26 +667,25 @@ def _task_transform(self, task: BaseTask, output_dir: Path, num_workers: int) -> # spwan is required for polars in multiprocessing, see https://docs.pola.rs/user-guide/misc/multiprocessing/#summary ctx = multiprocessing.get_context("spawn") - with ctx.Manager() as manager: - queue = manager.Queue() - args_list = [( - worker_id, - task, - pids, - global_event_df, - output_dir, - queue - ) for worker_id, pids in enumerate(itertools.batched(patient_ids, batch_size))] - with ctx.Pool(processes=num_workers) as pool: - result = pool.map_async(_task_transform_fn, args_list) # type: ignore - with tqdm(total=len(patient_ids)) as progress: - while not result.ready(): - while not queue.empty(): - progress.update(queue.get()) - - # remaining items + queue = ctx.Queue() + args_list = [( + worker_id, + task, + pids, + global_event_df, + 
output_dir, + queue + ) for worker_id, pids in enumerate(itertools.batched(patient_ids, batch_size))] + with ctx.Pool(processes=num_workers) as pool: + result = pool.map_async(_task_transform_fn, args_list) # type: ignore + with tqdm(total=len(patient_ids)) as progress: + while not result.ready(): while not queue.empty(): progress.update(queue.get()) + + # remaining items + while not queue.empty(): + progress.update(queue.get()) litdata.index_parquet_dataset(str(output_dir)) logger.info(f"Task transformation completed and saved to {output_dir}") From 24e3bd3299177101510410005fffd4435cffc6be Mon Sep 17 00:00:00 2001 From: Yongda Fan Date: Tue, 23 Dec 2025 07:31:22 -0500 Subject: [PATCH 8/9] Fix crash --- pyhealth/datasets/base_dataset.py | 30 +++++++++++++++++++----------- 1 file changed, 19 insertions(+), 11 deletions(-) diff --git a/pyhealth/datasets/base_dataset.py b/pyhealth/datasets/base_dataset.py index 8b4e3a2d4..36d843e2f 100644 --- a/pyhealth/datasets/base_dataset.py +++ b/pyhealth/datasets/base_dataset.py @@ -115,11 +115,6 @@ def _csv_tsv_gz_path(path: str) -> str: def _uncollate(x: list[Any]) -> Any: return x[0] if isinstance(x, list) and len(x) == 1 else x -class _FakeQueue: - """A fake queue that does nothing. Used when multiprocessing is not needed.""" - - def put(self, item: Any) -> None: - pass class _ParquetWriter: """ @@ -197,8 +192,19 @@ def __enter__(self): def __exit__(self, exc_type, exc, tb): self.close() +_task_transform_queue: multiprocessing.queues.Queue | None = None + +def _task_transform_init(queue: multiprocessing.queues.Queue) -> None: + """ + Initializer for worker processes to set up a global queue. -def _task_transform_fn(args: tuple[int, BaseTask, Iterable[str], pl.LazyFrame, Path, multiprocessing.queues.Queue | _FakeQueue]) -> None: + Args: + queue (multiprocessing.queues.Queue): The queue for progress tracking. + """ + global _task_transform_queue + _task_transform_queue = queue + +def _task_transform_fn(args: tuple[int, BaseTask, Iterable[str], pl.LazyFrame, Path]) -> None: """ Worker function to apply task transformation on a chunk of patients. @@ -209,13 +215,15 @@ def _task_transform_fn(args: tuple[int, BaseTask, Iterable[str], pl.LazyFrame, P patient_ids (Iterable[str]): The patient IDs to process. global_event_df (pl.LazyFrame): The global event dataframe. output_dir (Path): The output directory to save results. - queue (Queue | _FakeQueue): A multiprocessing queue for progress tracking. """ UPDATE_FREQUENCY = 128 logger.info(f"Worker {args[0]} started processing {len(list(args[2]))} patients.") - worker_id, task, patient_ids, global_event_df, output_dir, queue = args + worker_id, task, patient_ids, global_event_df, output_dir = args + queue = _task_transform_queue + assert queue is not None, "Queue not initialized in worker process." 
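+    # With the spawn start method, a multiprocessing.Queue cannot be passed to the
+    # workers through the map_async() arguments (queues may only be shared through
+    # inheritance), so the pool delivers it once via initializer/initargs and each
+    # worker reads the module-level _task_transform_queue instead.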
+ count = 0 with _ParquetWriter( output_dir / f"chunk_{worker_id:03d}.parquet", @@ -658,7 +666,7 @@ def _task_transform(self, task: BaseTask, output_dir: Path, num_workers: int) -> if num_workers == 1: logger.info("Single worker mode, processing sequentially") - _task_transform_fn((0, task, patient_ids, global_event_df, output_dir, _FakeQueue())) + _task_transform_fn((0, task, patient_ids, global_event_df, output_dir)) litdata.index_parquet_dataset(str(output_dir)) return @@ -674,9 +682,8 @@ def _task_transform(self, task: BaseTask, output_dir: Path, num_workers: int) -> pids, global_event_df, output_dir, - queue ) for worker_id, pids in enumerate(itertools.batched(patient_ids, batch_size))] - with ctx.Pool(processes=num_workers) as pool: + with ctx.Pool(processes=num_workers, initializer=_task_transform_init, initargs=(queue,)) as pool: result = pool.map_async(_task_transform_fn, args_list) # type: ignore with tqdm(total=len(patient_ids)) as progress: while not result.ready(): @@ -686,6 +693,7 @@ def _task_transform(self, task: BaseTask, output_dir: Path, num_workers: int) -> # remaining items while not queue.empty(): progress.update(queue.get()) + result.get() # ensure exceptions are raised litdata.index_parquet_dataset(str(output_dir)) logger.info(f"Task transformation completed and saved to {output_dir}") From 742df32629c90ba5a41b0325c04d1181185c74d9 Mon Sep 17 00:00:00 2001 From: Yongda Fan Date: Tue, 23 Dec 2025 07:43:03 -0500 Subject: [PATCH 9/9] Fix single worker mode. --- pyhealth/datasets/base_dataset.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/pyhealth/datasets/base_dataset.py b/pyhealth/datasets/base_dataset.py index 36d843e2f..46dd41029 100644 --- a/pyhealth/datasets/base_dataset.py +++ b/pyhealth/datasets/base_dataset.py @@ -216,13 +216,16 @@ def _task_transform_fn(args: tuple[int, BaseTask, Iterable[str], pl.LazyFrame, P global_event_df (pl.LazyFrame): The global event dataframe. output_dir (Path): The output directory to save results. """ + class _FakeQueue: + def put(self, x): + pass + UPDATE_FREQUENCY = 128 logger.info(f"Worker {args[0]} started processing {len(list(args[2]))} patients.") worker_id, task, patient_ids, global_event_df, output_dir = args - queue = _task_transform_queue - assert queue is not None, "Queue not initialized in worker process." + queue = _task_transform_queue or _FakeQueue() count = 0 with _ParquetWriter(