Merged
170 changes: 138 additions & 32 deletions pyhealth/datasets/base_dataset.py
@@ -2,9 +2,8 @@
import os
import pickle
from abc import ABC
from concurrent.futures import ThreadPoolExecutor, as_completed
from pathlib import Path
from typing import Dict, Iterator, List, Optional, Any, Callable
from typing import Dict, Iterator, Iterable, List, Optional, Any, Callable
import functools
import operator
from urllib.parse import urlparse, urlunparse
@@ -14,6 +13,8 @@
import platformdirs
import tempfile
import multiprocessing
import multiprocessing.queues
import shutil

import litdata
from litdata.streaming.item_loader import ParquetLoader
@@ -26,8 +27,9 @@
import requests
from tqdm import tqdm
import dask.dataframe as dd
from dask.distributed import Client, LocalCluster, progress
from dask.distributed import Client as DaskClient, LocalCluster as DaskCluster, progress as dask_progress
import narwhals as nw
import itertools

from ..data import Patient
from ..tasks import BaseTask
@@ -190,6 +192,64 @@ def __enter__(self):
def __exit__(self, exc_type, exc, tb):
self.close()

_task_transform_queue: multiprocessing.queues.Queue | None = None

def _task_transform_init(queue: multiprocessing.queues.Queue) -> None:
"""
Initializer for worker processes to set up a global queue.

Args:
queue (multiprocessing.queues.Queue): The queue for progress tracking.
"""
global _task_transform_queue
_task_transform_queue = queue

def _task_transform_fn(args: tuple[int, BaseTask, Iterable[str], pl.LazyFrame, Path]) -> None:
"""
Worker function to apply task transformation on a chunk of patients.

Args:
args (tuple): A tuple containing:
worker_id (int): The ID of the worker.
task (BaseTask): The task to apply.
patient_ids (Iterable[str]): The patient IDs to process.
global_event_df (pl.LazyFrame): The global event dataframe.
output_dir (Path): The output directory to save results.
"""
    class _FakeQueue:
        """No-op queue used when no shared progress queue has been initialized (single-worker mode)."""

        def put(self, x):
            pass

UPDATE_FREQUENCY = 128

logger.info(f"Worker {args[0]} started processing {len(list(args[2]))} patients.")

worker_id, task, patient_ids, global_event_df, output_dir = args
queue = _task_transform_queue or _FakeQueue()

count = 0
with _ParquetWriter(
output_dir / f"chunk_{worker_id:03d}.parquet",
pa.schema([("sample", pa.binary())]),
) as writer:
for patient_id in patient_ids:
patient_df = global_event_df.filter(pl.col("patient_id") == patient_id).collect(
engine="streaming"
)
patient = Patient(patient_id=patient_id, data_source=patient_df)
for sample in task(patient):
writer.append({"sample": pickle.dumps(sample)})

count += 1
if count >= UPDATE_FREQUENCY:
queue.put(count)
count = 0

        if count > 0:
            queue.put(count)

logger.info(f"Worker {args[0]} finished processing patients.")

class BaseDataset(ABC):
"""Abstract base class for all PyHealth datasets.
@@ -373,13 +433,13 @@ def global_event_df(self) -> pl.LazyFrame:
dask_scratch_dir = self.cache_dir / "dask_scratch"
dask_scratch_dir.mkdir(parents=True, exist_ok=True)

with LocalCluster(
with DaskCluster(
n_workers=self.num_workers,
threads_per_worker=1,
processes=not in_notebook(),
local_directory=str(dask_scratch_dir),
) as cluster:
with Client(cluster) as client:
with DaskClient(cluster) as client:
df: dd.DataFrame = self.load_data()
if self.dev:
logger.info("Dev mode enabled: limiting to 1000 patients")
@@ -394,7 +454,7 @@ def global_event_df(self) -> pl.LazyFrame:
compute=False,
)
handle = client.compute(collection)
progress(handle)
dask_progress(handle)
handle.result() # type: ignore
self._global_event_df = ret_path

@@ -590,6 +650,62 @@ def default_task(self) -> Optional[BaseTask]:
"""
return None

def _task_transform(self, task: BaseTask, output_dir: Path, num_workers: int) -> None:
self._main_guard(self._task_transform.__name__)

try:
logger.info(f"Applying task transformations on data with {num_workers} workers...")
global_event_df = task.pre_filter(self.global_event_df)
patient_ids = (
global_event_df.select("patient_id")
.unique()
.collect(engine="streaming")
.to_series()
)

if in_notebook():
logger.info("Detected Jupyter notebook environment, setting num_workers to 1")
num_workers = 1

if num_workers == 1:
logger.info("Single worker mode, processing sequentially")
_task_transform_fn((0, task, patient_ids, global_event_df, output_dir))
litdata.index_parquet_dataset(str(output_dir))
return

            num_workers = max(1, min(num_workers, len(patient_ids)))  # avoid spawning idle workers (and a zero division below)
            batch_size = len(patient_ids) // num_workers + 1

            # spawn is required for polars in multiprocessing, see https://docs.pola.rs/user-guide/misc/multiprocessing/#summary
ctx = multiprocessing.get_context("spawn")
queue = ctx.Queue()
            # NOTE: itertools.batched requires Python >= 3.12
            args_list = [
                (worker_id, task, pids, global_event_df, output_dir)
                for worker_id, pids in enumerate(itertools.batched(patient_ids, batch_size))
            ]
            with ctx.Pool(processes=num_workers, initializer=_task_transform_init, initargs=(queue,)) as pool:
                result = pool.map_async(_task_transform_fn, args_list)  # type: ignore
                with tqdm(total=len(patient_ids)) as progress:
                    while not result.ready():
                        while not queue.empty():
                            progress.update(queue.get())
                        result.wait(timeout=0.1)  # avoid busy-waiting between progress updates

                    # drain any remaining progress updates
                    while not queue.empty():
                        progress.update(queue.get())
                    result.get()  # ensure worker exceptions are re-raised

litdata.index_parquet_dataset(str(output_dir))
logger.info(f"Task transformation completed and saved to {output_dir}")
        except Exception:
            logger.error(f"Error during task transformation, cleaning up output directory: {output_dir}")
            shutil.rmtree(output_dir, ignore_errors=True)
            raise
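    # To sanity-check the cached chunks by hand, the per-worker parquet files can
    # be read back directly (a sketch; `chunk_path` is any chunk_XXX.parquet
    # written above, and polars/pickle are already imported in this module):
    #
    #   df = pl.read_parquet(chunk_path)
    #   samples = [pickle.loads(blob) for blob in df["sample"]]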


def set_task(
self,
task: Optional[BaseTask] = None,
@@ -622,12 +738,7 @@ def set_task(
Raises:
AssertionError: If no default task is found and task is None.
"""
if not multiprocessing.current_process().name == "MainProcess":
logger.warning(
"set_task method accessed from a non-main process. This may lead to unexpected behavior.\n"
+ "Consider use __name__ == '__main__' guard when using multiprocessing."
)
return None # type: ignore
self._main_guard(self.set_task.__name__)

if task is None:
assert self.default_task is not None, "No default tasks found"
@@ -656,27 +767,12 @@
# Check if index.json exists to verify cache integrity, this
# is the standard file for litdata.StreamingDataset
if not (path / "index.json").exists():
global_event_df = task.pre_filter(self.global_event_df)
schema = pa.schema([("sample", pa.binary())])
with tempfile.TemporaryDirectory() as tmp_dir:
# Create Parquet file with samples
logger.info(f"Applying task transformations on data...")
with _ParquetWriter(f"{tmp_dir}/samples.parquet", schema) as writer:
# TODO: this can be further optimized.
patient_ids = (
global_event_df.select("patient_id")
.unique()
.collect(engine="streaming")
.to_series()
)
for patient_id in tqdm(patient_ids):
patient_df = global_event_df.filter(
pl.col("patient_id") == patient_id
).collect(engine="streaming")
patient = Patient(patient_id=patient_id, data_source=patient_df)
for sample in task(patient):
writer.append({"sample": pickle.dumps(sample)})
litdata.index_parquet_dataset(tmp_dir)
self._task_transform(
task,
Path(tmp_dir),
num_workers,
)

# Build processors and fit on the dataset
logger.info(f"Fitting processors on the dataset...")
Expand Down Expand Up @@ -718,3 +814,13 @@ def set_task(
dataset_name=self.dataset_name,
task_name=task.task_name,
)
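    # The index.json written by litdata.index_parquet_dataset is what lets the
    # cached samples be streamed back lazily. A rough sketch of reading them
    # (assuming ParquetLoader returns each row as a {column: value} mapping;
    # the exact item format may differ across litdata versions):
    #
    #   ds = litdata.StreamingDataset(str(path), item_loader=ParquetLoader())
    #   first_sample = pickle.loads(ds[0]["sample"])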

    def _main_guard(self, func_name: str):
        """Warn and exit if the method is accessed from a non-main (e.g. spawned worker) process."""
        if multiprocessing.current_process().name != "MainProcess":
            logger.warning(
                f"{func_name} method accessed from a non-main process. This may lead to unexpected behavior.\n"
                + "Consider using an `if __name__ == '__main__':` guard when using multiprocessing."
            )
            exit(1)
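
For end users, the practical consequence of the spawn-based pool and the `_main_guard` above is that `set_task` must be called from the main process, under an `if __name__ == '__main__':` guard. A minimal usage sketch, where `MyDataset` and `MyTask` stand in for a concrete `BaseDataset` subclass and `BaseTask` (constructor arguments are illustrative, and the `num_workers` keyword on `set_task` is assumed from the `_task_transform` call in this diff):

if __name__ == "__main__":
    dataset = MyDataset(root="/path/to/data")
    sample_dataset = dataset.set_task(MyTask(), num_workers=4)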