From f58ec7d7bb5af0791032f7c1fc350b1f03855a95 Mon Sep 17 00:00:00 2001 From: Yongda Fan Date: Wed, 26 Nov 2025 22:47:51 -0500 Subject: [PATCH 01/82] Polars fix bug of OOM on large table join --- pyproject.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index ceedcdd0b..5eb0acaca 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -35,7 +35,7 @@ dependencies = [ "urllib3~=2.5.0", "numpy~=1.26.4", "tqdm", - "polars~=1.31.0", + "polars~=1.35.2", "pandas~=2.3.1", "pandarallel~=1.6.5", "pydantic~=2.11.7", From c196a8bde3c8332249385f6013ad6549178f2257 Mon Sep 17 00:00:00 2001 From: Yongda Fan Date: Wed, 26 Nov 2025 22:50:31 -0500 Subject: [PATCH 02/82] Fix type hint --- pyhealth/datasets/base_dataset.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/pyhealth/datasets/base_dataset.py b/pyhealth/datasets/base_dataset.py index 3390453ff..7bdf3f0b3 100644 --- a/pyhealth/datasets/base_dataset.py +++ b/pyhealth/datasets/base_dataset.py @@ -129,8 +129,8 @@ def __init__( self.root = root self.tables = tables self.dataset_name = dataset_name or self.__class__.__name__ - self.config = load_yaml_config(config_path) self.dev = dev + self.config = load_yaml_config(config_path) if config_path else None logger.info( f"Initializing {self.dataset_name} dataset from {self.root} (dev mode: {self.dev})" @@ -205,6 +205,8 @@ def load_table(self, table_name: str) -> pl.LazyFrame: ValueError: If the table is not found in the config. FileNotFoundError: If the CSV file for the table or join is not found. """ + assert self.config is not None, "Config must be provided to load tables" + if table_name not in self.config.tables: raise ValueError(f"Table {table_name} not found in config") @@ -243,7 +245,7 @@ def _to_lower(col_name: str) -> str: columns = join_cfg.columns how = join_cfg.how - df = df.join(join_df.select([join_key] + columns), on=join_key, how=how) + df = df.join(join_df.select([join_key] + columns), on=join_key, how=how) # type: ignore patient_id_col = table_cfg.patient_id timestamp_col = table_cfg.timestamp From 8602c0dfcbfe55aa922e5295a9a1492a8e027d65 Mon Sep 17 00:00:00 2001 From: Yongda Fan Date: Wed, 26 Nov 2025 22:55:05 -0500 Subject: [PATCH 03/82] Add cache_dir --- pyhealth/datasets/base_dataset.py | 28 ++++++++++++++++++++++++++++ 1 file changed, 28 insertions(+) diff --git a/pyhealth/datasets/base_dataset.py b/pyhealth/datasets/base_dataset.py index 7bdf3f0b3..6fb2f9d9a 100644 --- a/pyhealth/datasets/base_dataset.py +++ b/pyhealth/datasets/base_dataset.py @@ -6,6 +6,9 @@ from pathlib import Path from typing import Dict, Iterator, List, Optional from urllib.parse import urlparse, urlunparse +import json +import uuid +import platformdirs import polars as pl import requests @@ -112,6 +115,7 @@ def __init__( tables: List[str], dataset_name: Optional[str] = None, config_path: Optional[str] = None, + cache_dir: str | Path | None = None, dev: bool = False, ): """Initializes the BaseDataset. @@ -139,9 +143,33 @@ def __init__( self.global_event_df = self.load_data() # Cached attributes + self._cache_dir = cache_dir self._collected_global_event_df = None self._unique_patient_ids = None + @property + def cache_dir(self) -> Path: + """Returns the cache directory path. + Returns: + Path: The cache directory path. 
+ """ + if self._cache_dir is None: + id_str = json.dumps( + { + "root": self.root, + "tables": sorted(self.tables), + "dataset_name": self.dataset_name, + "dev": self.dev, + }, + sort_keys=True, + ) + cache_dir = Path(platformdirs.user_cache_dir(appname="pyhealth")) / str( + uuid.uuid5(uuid.NAMESPACE_DNS, id_str) + ) + print(f"No cache_dir provided. Using default cache dir: {cache_dir}") + self._cache_dir = cache_dir + return Path(self._cache_dir) + @property def collected_global_event_df(self) -> pl.DataFrame: """Collects and returns the global event data frame. From 0ceeea906cae7375b65d1d49d5856e4b80793d30 Mon Sep 17 00:00:00 2001 From: Yongda Fan Date: Wed, 26 Nov 2025 22:55:44 -0500 Subject: [PATCH 04/82] Remove to_lower as this is a no-op --- pyhealth/datasets/base_dataset.py | 12 +----------- 1 file changed, 1 insertion(+), 11 deletions(-) diff --git a/pyhealth/datasets/base_dataset.py b/pyhealth/datasets/base_dataset.py index 6fb2f9d9a..59636bea5 100644 --- a/pyhealth/datasets/base_dataset.py +++ b/pyhealth/datasets/base_dataset.py @@ -238,12 +238,6 @@ def load_table(self, table_name: str) -> pl.LazyFrame: if table_name not in self.config.tables: raise ValueError(f"Table {table_name} not found in config") - def _to_lower(col_name: str) -> str: - lower_name = col_name.lower() - if lower_name != col_name: - logger.warning("Renaming column %s to lowercase %s", col_name, lower_name) - return lower_name - table_cfg = self.config.tables[table_name] csv_path = f"{self.root}/{table_cfg.file_path}" csv_path = clean_path(csv_path) @@ -251,9 +245,6 @@ def _to_lower(col_name: str) -> str: logger.info(f"Scanning table: {table_name} from {csv_path}") df = scan_csv_gz_or_csv_tsv(csv_path) - # Convert column names to lowercase before calling preprocess_func - df = df.rename(_to_lower) - # Check if there is a preprocessing function for this table preprocess_func = getattr(self, f"preprocess_{table_name}", None) if preprocess_func is not None: @@ -268,7 +259,6 @@ def _to_lower(col_name: str) -> str: other_csv_path = clean_path(other_csv_path) logger.info(f"Joining with table: {other_csv_path}") join_df = scan_csv_gz_or_csv_tsv(other_csv_path) - join_df = join_df.rename(_to_lower) join_key = join_cfg.on columns = join_cfg.columns how = join_cfg.how @@ -311,7 +301,7 @@ def _to_lower(col_name: str) -> str: # Flatten attribute columns with event_type prefix attribute_columns = [ - pl.col(attr.lower()).alias(f"{table_name}/{attr}") for attr in attribute_cols + pl.col(attr).alias(f"{table_name}/{attr}") for attr in attribute_cols ] event_frame = df.select(base_columns + attribute_columns) From 996b35c31d46a3867f50a3ac65e54d6e8a4ef812 Mon Sep 17 00:00:00 2001 From: Yongda Fan Date: Wed, 26 Nov 2025 23:45:14 -0500 Subject: [PATCH 05/82] Add caching behaviour --- pyhealth/datasets/base_dataset.py | 111 +++++++++++++++++------------- 1 file changed, 64 insertions(+), 47 deletions(-) diff --git a/pyhealth/datasets/base_dataset.py b/pyhealth/datasets/base_dataset.py index 59636bea5..73766a9ef 100644 --- a/pyhealth/datasets/base_dataset.py +++ b/pyhealth/datasets/base_dataset.py @@ -73,7 +73,12 @@ def scan_csv_gz_or_csv_tsv(path: str) -> pl.LazyFrame: def scan_file(file_path: str) -> pl.LazyFrame: separator = "\t" if ".tsv" in file_path else "," - return pl.scan_csv(file_path, separator=separator, infer_schema=False) + return pl.scan_csv( + file_path, + separator=separator, + infer_schema=False, + low_memory=True, + ) if path_exists(path): return scan_file(path) @@ -144,6 +149,7 @@ def __init__( # Cached 
attributes self._cache_dir = cache_dir + self._event_df_path = None self._collected_global_event_df = None self._unique_patient_ids = None @@ -166,51 +172,41 @@ def cache_dir(self) -> Path: cache_dir = Path(platformdirs.user_cache_dir(appname="pyhealth")) / str( uuid.uuid5(uuid.NAMESPACE_DNS, id_str) ) + cache_dir.mkdir(parents=True, exist_ok=True) print(f"No cache_dir provided. Using default cache dir: {cache_dir}") self._cache_dir = cache_dir return Path(self._cache_dir) @property - def collected_global_event_df(self) -> pl.DataFrame: - """Collects and returns the global event data frame. + def event_df(self) -> pl.LazyFrame: + """Returns the path to the cached event dataframe. Returns: - pl.DataFrame: The collected global event data frame. + Path: The path to the cached event dataframe. """ - if self._collected_global_event_df is None: - logger.info("Collecting global event dataframe...") - - # Collect the dataframe - with dev mode limiting if applicable - df = self.global_event_df - # TODO: dev doesn't seem to improve the speed / memory usage - if self.dev: - # Limit the number of patients in dev mode - logger.info("Dev mode enabled: limiting to 1000 patients") - limited_patients = df.select(pl.col("patient_id")).unique().limit(1000) - df = df.join(limited_patients, on="patient_id", how="inner") - - self._collected_global_event_df = df.collect() - - # Profile the Polars collect() operation (commented out by default) - # self._collected_global_event_df, profile = df.profile() - # profile = profile.with_columns([ - # (pl.col("end") - pl.col("start")).alias("duration"), - # ]) - # profile = profile.with_columns([ - # (pl.col("duration") / profile["duration"].sum() * 100).alias("percentage") - # ]) - # profile = profile.sort("duration", descending=True) - # with pl.Config() as cfg: - # cfg.set_tbl_rows(-1) - # cfg.set_fmt_str_lengths(200) - # print(profile) - - logger.info( - f"Collected dataframe with shape: {self._collected_global_event_df.shape}" - ) - - return self._collected_global_event_df + if self._event_df_path is None: + path = self.cache_dir / "event_df.parquet" + if not path.exists(): + df = self.load_data() + if self.dev: + logger.info("Dev mode enabled: limiting to 1000 patients") + limited_patients = df.select(pl.col("patient_id")).unique().limit(1000) + df = df.join(limited_patients, on="patient_id", how="inner") + + logger.info(f"Caching event dataframe to {path}...") + df.sort("patient_id").sink_parquet( + path, + compression="lz4", # use lz4 compression for faster read/write + row_group_size=8_192, + maintain_order=True, # Important for sorted writes + ) + self._event_df_path = path + return pl.scan_parquet( + self._event_df_path, + low_memory=True, + ).set_sorted("patient_id") # Guarantee sorted read, see sink_parquet above + def load_data(self) -> pl.LazyFrame: """Loads data from the specified tables. 
@@ -317,8 +313,10 @@ def unique_patient_ids(self) -> List[str]: """ if self._unique_patient_ids is None: self._unique_patient_ids = ( - self.collected_global_event_df.select("patient_id") + self.event_df + .select("patient_id") .unique() + .collect(engine="streaming") .to_series() .to_list() ) @@ -340,8 +338,13 @@ def get_patient(self, patient_id: str) -> Patient: assert ( patient_id in self.unique_patient_ids ), f"Patient {patient_id} not found in dataset" - df = self.collected_global_event_df.filter(pl.col("patient_id") == patient_id) - return Patient(patient_id=patient_id, data_source=df) + + data_source = ( + self.event_df + .filter(pl.col("patient_id") == patient_id) + .collect(engine="streaming") + ) + return Patient(patient_id=patient_id, data_source=data_source) def iter_patients(self, df: Optional[pl.LazyFrame] = None) -> Iterator[Patient]: """Yields Patient objects for each unique patient in the dataset. @@ -350,20 +353,34 @@ def iter_patients(self, df: Optional[pl.LazyFrame] = None) -> Iterator[Patient]: Iterator[Patient]: An iterator over Patient objects. """ if df is None: - df = self.collected_global_event_df - grouped = df.group_by("patient_id") + df = self.event_df + patient_ids = ( + df.select("patient_id") + .unique(maintain_order=True) + .collect(engine="streaming") + .to_series() + ) - for patient_id, patient_df in grouped: - patient_id = patient_id[0] + for patient_id in patient_ids: + patient_df = ( + df.filter(pl.col("patient_id") == patient_id) + .collect(engine="streaming") + ) yield Patient(patient_id=patient_id, data_source=patient_df) def stats(self) -> None: """Prints statistics about the dataset.""" - df = self.collected_global_event_df + stats = ( + self.event_df.select( + pl.len().alias("n_events"), + pl.col("patient_id").n_unique().alias("n_patients"), + ) + .collect(engine="streaming") + ) print(f"Dataset: {self.dataset_name}") print(f"Dev mode: {self.dev}") - print(f"Number of patients: {df['patient_id'].n_unique()}") - print(f"Number of events: {df.height}") + print(f"Number of patients: {stats['n_patients'][0]}") + print(f"Number of events: {stats['n_events'][0]}") @property def default_task(self) -> Optional[BaseTask]: From 6991f262459485545bd4d450dee7bb9aa5d75fe4 Mon Sep 17 00:00:00 2001 From: Yongda Fan Date: Thu, 27 Nov 2025 00:01:24 -0500 Subject: [PATCH 06/82] Add test case --- tests/core/test_base_dataset.py | 126 ++++++++++++++++++++++++++++++++ 1 file changed, 126 insertions(+) create mode 100644 tests/core/test_base_dataset.py diff --git a/tests/core/test_base_dataset.py b/tests/core/test_base_dataset.py new file mode 100644 index 000000000..5d91bbf43 --- /dev/null +++ b/tests/core/test_base_dataset.py @@ -0,0 +1,126 @@ +import tempfile +import unittest +from unittest.mock import patch + +import polars as pl + +from pyhealth.datasets.base_dataset import BaseDataset + + +class InMemoryDataset(BaseDataset): + """Dataset that bypasses file loading for tests.""" + + def __init__(self, data: pl.DataFrame, **kwargs): + self._data = data + super().__init__(**kwargs) + + def load_data(self) -> pl.LazyFrame: + return self._data.lazy() + + +class TestBaseDataset(unittest.TestCase): + def _single_row_data(self) -> pl.DataFrame: + return pl.DataFrame( + { + "patient_id": ["1"], + "event_type": ["test"], + "timestamp": [None], + "test/value": [0], + } + ) + + def test_cache_dir_varies_with_core_identifiers(self): + base_kwargs = dict( + tables=["table_a"], + dataset_name="CacheDataset", + dev=False, + ) + + with tempfile.TemporaryDirectory() as cache_root, 
patch( + "pyhealth.datasets.base_dataset.platformdirs.user_cache_dir", + return_value=cache_root, + ): + datasets = [ + InMemoryDataset( + data=self._single_row_data(), + root="/data/root_a", + **base_kwargs, + ), + InMemoryDataset( + data=self._single_row_data(), + root="/data/root_b", # different root + **base_kwargs, + ), + InMemoryDataset( + data=self._single_row_data(), + root="/data/root_a", + tables=["table_b"], # different tables + dataset_name="CacheDataset", + dev=False, + ), + InMemoryDataset( + data=self._single_row_data(), + root="/data/root_a", + tables=["table_a"], + dataset_name="OtherDataset", # different dataset name + dev=False, + ), + InMemoryDataset( + data=self._single_row_data(), + root="/data/root_a", + tables=["table_a"], + dataset_name="CacheDataset", + dev=True, # different dev flag + ), + ] + + cache_dirs = [ds.cache_dir for ds in datasets] + self.assertEqual( + len(cache_dirs), + len(set(cache_dirs)), + "cache_dir should change when root/tables/dataset_name/dev change", + ) + + def test_event_df_cache_is_physically_sorted(self): + unsorted_data = pl.DataFrame( + { + "patient_id": ["3", "1", "2", "1"], + "event_type": ["test"] * 4, + "timestamp": [None] * 4, + "test/value": [10, 20, 30, 40], + } + ) + original_order = unsorted_data["patient_id"].to_list() + + with tempfile.TemporaryDirectory() as cache_root, patch( + "pyhealth.datasets.base_dataset.platformdirs.user_cache_dir", + return_value=cache_root, + ): + dataset = InMemoryDataset( + data=unsorted_data, + root="/data/root_sort", + tables=["table_a"], + dataset_name="SortingDataset", + dev=False, + ) + + # Trigger caching of event_df.parquet + _ = dataset.event_df + cache_path = dataset.cache_dir / "event_df.parquet" + self.assertTrue(cache_path.exists(), "event_df cache should be created") + + cached_df = pl.read_parquet(cache_path) + cached_order = cached_df["patient_id"].to_list() + + self.assertNotEqual( + cached_order, original_order, "cache should not keep the unsorted order" + ) + self.assertEqual( + cached_order, + sorted(cached_order), + "cached event_df parquet must be sorted by patient_id", + ) + + +if __name__ == "__main__": + unittest.main() From a36a8192373bf1149a6ecc149b60f75054adee5c Mon Sep 17 00:00:00 2001 From: Yongda Fan Date: Thu, 27 Nov 2025 01:02:56 -0500 Subject: [PATCH 07/82] Add StreamingParquetWriter --- pyhealth/datasets/base_dataset.py | 138 +++++++++++++++----- tests/core/test_streaming_parquet_writer.py | 55 ++++++++ 2 files changed, 163 insertions(+), 30 deletions(-) create mode 100644 tests/core/test_streaming_parquet_writer.py diff --git a/pyhealth/datasets/base_dataset.py b/pyhealth/datasets/base_dataset.py index 73766a9ef..17a5965f0 100644 --- a/pyhealth/datasets/base_dataset.py +++ b/pyhealth/datasets/base_dataset.py @@ -9,6 +9,10 @@ import json import uuid import platformdirs +import tempfile +import litdata +import pyarrow as pa +import pyarrow.parquet as pq import polars as pl import requests @@ -102,6 +106,80 @@ def scan_file(file_path: str) -> pl.LazyFrame: raise FileNotFoundError(f"Neither path exists: {path} or {alt_path}") +class StreamingParquetWriter: + """ + Stream-write rows into a Parquet file in chunked (row-group) fashion. 
+ + Usage: + writer = StreamingParquetWriter(Path("out.parquet"), schema, chunk_size=10000) + writer.append({"id": 1, "val": 3.14}) + writer.append({"id": 2, "val": 1.23}) + writer.close() + """ + + def __init__(self, path: Path, schema: pa.Schema, chunk_size: int = 50_000): + """ + Args: + path: output Parquet file path + schema: pyarrow.Schema (required) + chunk_size: flush buffer every N rows + """ + self.path = Path(path) + self.schema = schema + self.chunk_size = chunk_size + + if self.schema is None: + raise ValueError("schema must be provided — no automatic inference allowed.") + + self._writer: pq.ParquetWriter | None = None + self._buffer: list[dict] = [] + self._closed = False + + # -------------------------------------------------------------- + # Public API + # -------------------------------------------------------------- + def append(self, row: dict) -> None: + """Append a single row (a Python dict).""" + if self._closed: + raise RuntimeError("Cannot append to a closed StreamingParquetWriter") + + self._buffer.append(row) + if len(self._buffer) >= self.chunk_size: + self.flush() + + def flush(self) -> None: + """Flush buffered rows into a Parquet row-group.""" + if not self._buffer: + return + + # Convert list[dict] → Arrow RecordBatch + batch = pa.RecordBatch.from_pylist(self._buffer, schema=self.schema) + + # Lazy-initialize writer + if self._writer is None: + self._writer = pq.ParquetWriter(self.path, self.schema) + + self._writer.write_batch(batch) + self._buffer.clear() + + def close(self) -> None: + """Flush and close the Parquet writer.""" + if self._closed: + return + self.flush() + if self._writer is not None: + self._writer.close() + self._closed = True + + # -------------------------------------------------------------- + # Context manager support + # -------------------------------------------------------------- + def __enter__(self): + return self + + def __exit__(self, exc_type, exc, tb): + self.close() + class BaseDataset(ABC): """Abstract base class for all PyHealth datasets. @@ -395,7 +473,7 @@ def set_task( self, task: Optional[BaseTask] = None, num_workers: int = 1, - cache_dir: Optional[str] = None, + cache_dir: str | Path | None = None, cache_format: str = "parquet", input_processors: Optional[Dict[str, FeatureProcessor]] = None, output_processors: Optional[Dict[str, FeatureProcessor]] = None, @@ -409,8 +487,7 @@ def set_task( multi-threading may not speed up the task function. cache_dir (Optional[str]): Directory to cache processed samples. Default is None (no caching). - cache_format (str): Format for caching ('parquet' or 'pickle'). - Default is 'parquet'. + cache_format (str): Deprecated. Only "parquet" is supported now. input_processors (Optional[Dict[str, FeatureProcessor]]): Pre-fitted input processors. If provided, these will be used instead of creating new ones from task's input_schema. Defaults to None. @@ -428,38 +505,39 @@ def set_task( assert self.default_task is not None, "No default tasks found" task = self.default_task + if cache_format != "parquet": + logger.warning("Only 'parquet' cache_format is supported now. ") + logger.info( f"Setting task {task.task_name} for {self.dataset_name} base dataset..." 
) - # Check for cached data if cache_dir is provided - samples = None - if cache_dir is not None: - cache_filename = f"{task.task_name}.{cache_format}" - cache_path = Path(cache_dir) / cache_filename - if cache_path.exists(): - logger.info(f"Loading cached samples from {cache_path}") - try: - if cache_format == "parquet": - # Load samples from parquet file - cached_df = pl.read_parquet(cache_path) - samples = [ - _restore_from_cache(row) for row in cached_df.to_dicts() - ] - elif cache_format == "pickle": - # Load samples from pickle file - with open(cache_path, "rb") as f: - samples = pickle.load(f) - else: - msg = f"Unsupported cache format: {cache_format}" - raise ValueError(msg) - logger.info(f"Loaded {len(samples)} cached samples") - except Exception as e: - logger.warning( - "Failed to load cached data: %s. Regenerating...", - e, + if cache_dir is None: + cache_dir = self.cache_dir / "tasks" / task.task_name + cache_dir.mkdir(parents=True, exist_ok=True) + + path = Path(cache_dir) + + # Check if index.json exists to verify cache integrity, this + # is the standard file for litdata.StreamingDataset + if not (path / "index.json").exists(): + event_df = task.pre_filter(self.event_df) + with tempfile.TemporaryDirectory() as tmp_dir: + logger.info(f"Applying task transformations on data...") + patient_ids = ( + event_df.select("patient_id") + .unique() + .collect(engine="streaming") + .to_series() + ) + for patient_id in tqdm(patient_ids): + patient_df = ( + event_df.filter(pl.col("patient_id") == patient_id) + .collect(engine="streaming") ) - samples = None + patient = Patient(patient_id=patient_id, data_source=patient_df) + # TODO: save this to temp cache + sample = task(patient) # Generate samples if not loaded from cache if samples is None: diff --git a/tests/core/test_streaming_parquet_writer.py b/tests/core/test_streaming_parquet_writer.py new file mode 100644 index 000000000..bda6f15e8 --- /dev/null +++ b/tests/core/test_streaming_parquet_writer.py @@ -0,0 +1,55 @@ +import tempfile +from pathlib import Path +import unittest + +import pyarrow as pa +import pyarrow.parquet as pq + +from pyhealth.datasets.base_dataset import StreamingParquetWriter +from tests.base import BaseTestCase + + +class TestStreamingParquetWriter(BaseTestCase): + def setUp(self): + self.tmpdir = tempfile.TemporaryDirectory() + self.schema = pa.schema( + [ + ("id", pa.int64()), + ("value", pa.string()), + ] + ) + self.output_path = Path(self.tmpdir.name) / "stream.parquet" + + def tearDown(self): + self.tmpdir.cleanup() + + def test_append_flush_close_and_context_manager(self): + rows = [ + {"id": 1, "value": "a"}, + {"id": 2, "value": "b"}, + {"id": 3, "value": "c"}, + {"id": 4, "value": "d"}, + ] + + with StreamingParquetWriter( + self.output_path, self.schema, chunk_size=2 + ) as writer: + # First two appends trigger an automatic flush due to chunk_size=2. + writer.append(rows[0]) + writer.append(rows[1]) + + # Flush again after adding a third row to ensure flushing appends + # rather than overwriting previous row groups. + writer.append(rows[2]) + writer.flush() + + # Leave data in the buffer to verify close() flushes it. + writer.append(rows[3]) + + # Context manager should have closed and flushed remaining buffered rows. + self.assertTrue(self.output_path.exists()) + + written_rows = pq.read_table(self.output_path).to_pylist() + + # Every append should be present as a distinct row in order. 
+ self.assertEqual(written_rows, rows) From 4d95ce52af82024fb7207d2c9d147b1e11c38538 Mon Sep 17 00:00:00 2001 From: Yongda Fan Date: Thu, 27 Nov 2025 01:14:30 -0500 Subject: [PATCH 08/82] write samples --- pyhealth/datasets/base_dataset.py | 145 ++++++++++-------------------- pyproject.toml | 1 + 2 files changed, 46 insertions(+), 100 deletions(-) diff --git a/pyhealth/datasets/base_dataset.py b/pyhealth/datasets/base_dataset.py index 17a5965f0..418e3cb91 100644 --- a/pyhealth/datasets/base_dataset.py +++ b/pyhealth/datasets/base_dataset.py @@ -78,8 +78,8 @@ def scan_csv_gz_or_csv_tsv(path: str) -> pl.LazyFrame: def scan_file(file_path: str) -> pl.LazyFrame: separator = "\t" if ".tsv" in file_path else "," return pl.scan_csv( - file_path, - separator=separator, + file_path, + separator=separator, infer_schema=False, low_memory=True, ) @@ -117,7 +117,7 @@ class StreamingParquetWriter: writer.close() """ - def __init__(self, path: Path, schema: pa.Schema, chunk_size: int = 50_000): + def __init__(self, path: Path | str, schema: pa.Schema, chunk_size: int = 8_192): """ Args: path: output Parquet file path @@ -129,7 +129,9 @@ def __init__(self, path: Path, schema: pa.Schema, chunk_size: int = 50_000): self.chunk_size = chunk_size if self.schema is None: - raise ValueError("schema must be provided — no automatic inference allowed.") + raise ValueError( + "schema must be provided — no automatic inference allowed." + ) self._writer: pq.ParquetWriter | None = None self._buffer: list[dict] = [] @@ -180,6 +182,7 @@ def __enter__(self): def __exit__(self, exc_type, exc, tb): self.close() + class BaseDataset(ABC): """Abstract base class for all PyHealth datasets. @@ -268,23 +271,27 @@ def event_df(self) -> pl.LazyFrame: df = self.load_data() if self.dev: logger.info("Dev mode enabled: limiting to 1000 patients") - limited_patients = df.select(pl.col("patient_id")).unique().limit(1000) + limited_patients = ( + df.select(pl.col("patient_id")).unique().limit(1000) + ) df = df.join(limited_patients, on="patient_id", how="inner") logger.info(f"Caching event dataframe to {path}...") df.sort("patient_id").sink_parquet( path, - compression="lz4", # use lz4 compression for faster read/write + compression="lz4", # use lz4 compression for faster read/write row_group_size=8_192, - maintain_order=True, # Important for sorted writes + maintain_order=True, # Important for sorted writes ) self._event_df_path = path return pl.scan_parquet( self._event_df_path, low_memory=True, - ).set_sorted("patient_id") # Guarantee sorted read, see sink_parquet above - + ).set_sorted( + "patient_id" + ) # Guarantee sorted read, see sink_parquet above + def load_data(self) -> pl.LazyFrame: """Loads data from the specified tables. 
@@ -337,7 +344,7 @@ def load_table(self, table_name: str) -> pl.LazyFrame: columns = join_cfg.columns how = join_cfg.how - df = df.join(join_df.select([join_key] + columns), on=join_key, how=how) # type: ignore + df = df.join(join_df.select([join_key] + columns), on=join_key, how=how) # type: ignore patient_id_col = table_cfg.patient_id timestamp_col = table_cfg.timestamp @@ -391,8 +398,7 @@ def unique_patient_ids(self) -> List[str]: """ if self._unique_patient_ids is None: self._unique_patient_ids = ( - self.event_df - .select("patient_id") + self.event_df.select("patient_id") .unique() .collect(engine="streaming") .to_series() @@ -416,11 +422,9 @@ def get_patient(self, patient_id: str) -> Patient: assert ( patient_id in self.unique_patient_ids ), f"Patient {patient_id} not found in dataset" - - data_source = ( - self.event_df - .filter(pl.col("patient_id") == patient_id) - .collect(engine="streaming") + + data_source = self.event_df.filter(pl.col("patient_id") == patient_id).collect( + engine="streaming" ) return Patient(patient_id=patient_id, data_source=data_source) @@ -440,21 +444,17 @@ def iter_patients(self, df: Optional[pl.LazyFrame] = None) -> Iterator[Patient]: ) for patient_id in patient_ids: - patient_df = ( - df.filter(pl.col("patient_id") == patient_id) - .collect(engine="streaming") + patient_df = df.filter(pl.col("patient_id") == patient_id).collect( + engine="streaming" ) yield Patient(patient_id=patient_id, data_source=patient_df) def stats(self) -> None: """Prints statistics about the dataset.""" - stats = ( - self.event_df.select( - pl.len().alias("n_events"), - pl.col("patient_id").n_unique().alias("n_patients"), - ) - .collect(engine="streaming") - ) + stats = self.event_df.select( + pl.len().alias("n_events"), + pl.col("patient_id").n_unique().alias("n_patients"), + ).collect(engine="streaming") print(f"Dataset: {self.dataset_name}") print(f"Dev mode: {self.dev}") print(f"Number of patients: {stats['n_patients'][0]}") @@ -517,85 +517,30 @@ def set_task( cache_dir.mkdir(parents=True, exist_ok=True) path = Path(cache_dir) - - # Check if index.json exists to verify cache integrity, this + + # Check if index.json exists to verify cache integrity, this # is the standard file for litdata.StreamingDataset if not (path / "index.json").exists(): event_df = task.pre_filter(self.event_df) + schema = pa.schema({"sample", pa.binary()}) with tempfile.TemporaryDirectory() as tmp_dir: - logger.info(f"Applying task transformations on data...") - patient_ids = ( - event_df.select("patient_id") - .unique() - .collect(engine="streaming") - .to_series() - ) - for patient_id in tqdm(patient_ids): - patient_df = ( - event_df.filter(pl.col("patient_id") == patient_id) + with StreamingParquetWriter(f"{tmp_dir}/samples.parquet", schema) as writer: + logger.info(f"Applying task transformations on data...") + + patient_ids = ( + event_df.select("patient_id") + .unique() .collect(engine="streaming") + .to_series() ) - patient = Patient(patient_id=patient_id, data_source=patient_df) - # TODO: save this to temp cache - sample = task(patient) - - # Generate samples if not loaded from cache - if samples is None: - logger.info(f"Generating samples with {num_workers} worker(s)...") - filtered_global_event_df = task.pre_filter(self.collected_global_event_df) - samples = [] - - if num_workers == 1: - # single-threading (by default) - for patient in tqdm( - self.iter_patients(filtered_global_event_df), - total=filtered_global_event_df["patient_id"].n_unique(), - desc=(f"Generating samples for 
{task.task_name} " "with 1 worker"), - smoothing=0, - ): - samples.extend(task(patient)) - else: - # multi-threading (not recommended) - logger.info( - f"Generating samples for {task.task_name} with " - f"{num_workers} workers" - ) - patients = list(self.iter_patients(filtered_global_event_df)) - with ThreadPoolExecutor(max_workers=num_workers) as executor: - futures = [executor.submit(task, patient) for patient in patients] - for future in tqdm( - as_completed(futures), - total=len(futures), - desc=( - f"Collecting samples for {task.task_name} " - f"from {num_workers} workers" - ), - ): - samples.extend(future.result()) - - # Cache the samples if cache_dir is provided - if cache_dir is not None: - cache_path = Path(cache_dir) / cache_filename - cache_path.parent.mkdir(parents=True, exist_ok=True) - logger.info(f"Caching samples to {cache_path}") - try: - if cache_format == "parquet": - # Save samples as parquet file - samples_for_cache = [ - _convert_for_cache(sample) for sample in samples - ] - samples_df = pl.DataFrame(samples_for_cache) - samples_df.write_parquet(cache_path) - elif cache_format == "pickle": - # Save samples as pickle file - with open(cache_path, "wb") as f: - pickle.dump(samples, f) - else: - msg = f"Unsupported cache format: {cache_format}" - raise ValueError(msg) - logger.info(f"Successfully cached {len(samples)} samples") - except Exception as e: - logger.warning(f"Failed to cache samples: {e}") + for patient_id in tqdm(patient_ids): + patient_df = event_df.filter( + pl.col("patient_id") == patient_id + ).collect(engine="streaming") + patient = Patient(patient_id=patient_id, data_source=patient_df) + for sample in task(patient): + writer.append({"sample": pickle.dumps(sample)}) + litdata.index_parquet_dataset(tmp_dir) sample_dataset = SampleDataset( samples, diff --git a/pyproject.toml b/pyproject.toml index 5eb0acaca..c2ee0c1c0 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -39,6 +39,7 @@ dependencies = [ "pandas~=2.3.1", "pandarallel~=1.6.5", "pydantic~=2.11.7", + "litdata~=0.2.58", ] license = "BSD-3-Clause" license-files = ["LICENSE.md"] From 4f26c1d0ab8a93ffc44b2e4e3b56816ffb4c5e2e Mon Sep 17 00:00:00 2001 From: Yongda Fan Date: Thu, 27 Nov 2025 08:28:51 -0500 Subject: [PATCH 09/82] Add SampleBuilder --- pyhealth/datasets/sample_dataset.py | 181 +++++++++++++++++++++++++++- tests/core/test_sample_builder.py | 46 +++++++ 2 files changed, 221 insertions(+), 6 deletions(-) create mode 100644 tests/core/test_sample_builder.py diff --git a/pyhealth/datasets/sample_dataset.py b/pyhealth/datasets/sample_dataset.py index 54c40420c..5c3b18cac 100644 --- a/pyhealth/datasets/sample_dataset.py +++ b/pyhealth/datasets/sample_dataset.py @@ -1,4 +1,4 @@ -from typing import Any, Dict, List, Optional, Tuple, Union, Type +from typing import Any, Dict, Iterator, List, Optional, Tuple, Union, Type import inspect from torch.utils.data import Dataset @@ -8,6 +8,159 @@ from ..processors.base_processor import FeatureProcessor +class SampleBuilder: + """Utility to fit processors and transform samples without materializing a Dataset.""" + + def __init__( + self, + input_schema: Dict[ + str, + Union[ + str, + Type[FeatureProcessor], + FeatureProcessor, + Tuple[Union[str, Type[FeatureProcessor]], Dict[str, Any]], + ], + ], + output_schema: Dict[ + str, + Union[ + str, + Type[FeatureProcessor], + FeatureProcessor, + Tuple[Union[str, Type[FeatureProcessor]], Dict[str, Any]], + ], + ], + input_processors: Optional[Dict[str, FeatureProcessor]] = None, + output_processors: 
Optional[Dict[str, FeatureProcessor]] = None, + ) -> None: + self.input_schema = input_schema + self.output_schema = output_schema + self._input_processors = ( + input_processors if input_processors is not None else {} + ) + self._output_processors = ( + output_processors if output_processors is not None else {} + ) + self._patient_to_index: Dict[str, List[int]] = {} + self._record_to_index: Dict[str, List[int]] = {} + self._fitted = False + + @property + def input_processors(self) -> Dict[str, FeatureProcessor]: + if not self._fitted: + raise RuntimeError( + "SampleBuilder.fit must be called before accessing input_processors." + ) + return self._input_processors + + @property + def output_processors(self) -> Dict[str, FeatureProcessor]: + if not self._fitted: + raise RuntimeError( + "SampleBuilder.fit must be called before accessing output_processors." + ) + return self._output_processors + + @property + def patient_to_index(self) -> Dict[str, List[int]]: + if not self._fitted: + raise RuntimeError( + "SampleBuilder.fit must be called before accessing patient_to_index." + ) + return self._patient_to_index + + @property + def record_to_index(self) -> Dict[str, List[int]]: + if not self._fitted: + raise RuntimeError( + "SampleBuilder.fit must be called before accessing record_to_index." + ) + return self._record_to_index + + def _get_processor_instance(self, processor_spec): + """Instantiate a processor using the same resolution logic as SampleDataset.""" + if isinstance(processor_spec, tuple): + spec, kwargs = processor_spec + if isinstance(spec, str): + return get_processor(spec)(**kwargs) + if inspect.isclass(spec) and issubclass(spec, FeatureProcessor): + return spec(**kwargs) + raise ValueError( + "Processor spec in tuple must be either a string alias or a " + f"FeatureProcessor class, got {type(spec)}" + ) + if isinstance(processor_spec, str): + return get_processor(processor_spec)() + if inspect.isclass(processor_spec) and issubclass( + processor_spec, FeatureProcessor + ): + return processor_spec() + if isinstance(processor_spec, FeatureProcessor): + return processor_spec + raise ValueError( + "Processor spec must be either a string alias, a FeatureProcessor " + f"class, or a tuple (spec, kwargs_dict), got {type(processor_spec)}" + ) + + def _validate(self, samples: List[Dict[str, Any]]) -> None: + """Validate that provided samples contain the fields described in the schemas.""" + input_keys = set(self.input_schema.keys()) + output_keys = set(self.output_schema.keys()) + for sample in samples: + assert input_keys.issubset( + sample.keys() + ), "Input schema does not match samples." + assert output_keys.issubset( + sample.keys() + ), "Output schema does not match samples." 
+ + def fit(self, samples: Iterator[Dict[str, Any]]) -> None: + """Fit processors and build index mappings from an iterator of samples.""" + sample_list = list(samples) + self._validate(sample_list) + + # Build index mappings + self._patient_to_index = {} + self._record_to_index = {} + for i, sample in enumerate(sample_list): + patient_id = sample.get("patient_id") + if patient_id is not None: + self._patient_to_index.setdefault(patient_id, []).append(i) + record_id = sample.get("record_id", sample.get("visit_id")) + if record_id is not None: + self._record_to_index.setdefault(record_id, []).append(i) + + # Fit processors if they were not provided + if not self._input_processors: + for key, spec in self.input_schema.items(): + processor = self._get_processor_instance(spec) + processor.fit(sample_list, key) + self._input_processors[key] = processor + if not self._output_processors: + for key, spec in self.output_schema.items(): + processor = self._get_processor_instance(spec) + processor.fit(sample_list, key) + self._output_processors[key] = processor + + self._fitted = True + + def transform(self, sample: Dict[str, Any]) -> Dict[str, Any]: + """Transform a single sample dictionary using the fitted processors.""" + if not self._fitted: + raise RuntimeError("SampleBuilder.fit must be called before transform().") + + transformed: Dict[str, Any] = {} + for key, value in sample.items(): + if key in self._input_processors: + transformed[key] = self._input_processors[key].process(value) + elif key in self._output_processors: + transformed[key] = self._output_processors[key].process(value) + else: + transformed[key] = value + return transformed + + class SampleDataset(Dataset): """Sample dataset class for handling and processing data samples. @@ -24,8 +177,24 @@ class SampleDataset(Dataset): def __init__( self, samples: List[Dict], - input_schema: Dict[str, Union[str, Type[FeatureProcessor], FeatureProcessor, Tuple[Union[str, Type[FeatureProcessor]], Dict[str, Any]]]], - output_schema: Dict[str, Union[str, Type[FeatureProcessor], FeatureProcessor, Tuple[Union[str, Type[FeatureProcessor]], Dict[str, Any]]]], + input_schema: Dict[ + str, + Union[ + str, + Type[FeatureProcessor], + FeatureProcessor, + Tuple[Union[str, Type[FeatureProcessor]], Dict[str, Any]], + ], + ], + output_schema: Dict[ + str, + Union[ + str, + Type[FeatureProcessor], + FeatureProcessor, + Tuple[Union[str, Type[FeatureProcessor]], Dict[str, Any]], + ], + ], dataset_name: Optional[str] = None, task_name: Optional[str] = None, input_processors: Optional[Dict[str, FeatureProcessor]] = None, @@ -130,9 +299,9 @@ def validate(self) -> None: output_keys = set(self.output_schema.keys()) for s in self.samples: assert input_keys.issubset(s.keys()), "Input schema does not match samples." - assert output_keys.issubset(s.keys()), ( - "Output schema does not match samples." - ) + assert output_keys.issubset( + s.keys() + ), "Output schema does not match samples." 
return def build(self) -> None: diff --git a/tests/core/test_sample_builder.py b/tests/core/test_sample_builder.py new file mode 100644 index 000000000..33cd0bbca --- /dev/null +++ b/tests/core/test_sample_builder.py @@ -0,0 +1,46 @@ +import unittest + +from pyhealth.datasets.sample_dataset import SampleBuilder + + +class TestSampleBuilder(unittest.TestCase): + def setUp(self): + self.samples = [ + {"patient_id": "p1", "record_id": "r1", "feature": "a", "label": 1}, + {"patient_id": "p1", "record_id": "r2", "feature": "b", "label": 0}, + {"patient_id": "p2", "record_id": "r3", "feature": "c", "label": 1}, + ] + self.input_schema = {"feature": "raw"} + self.output_schema = {"label": "raw"} + + def test_fit_and_transform(self): + builder = SampleBuilder( + input_schema=self.input_schema, output_schema=self.output_schema + ) + + with self.assertRaises(RuntimeError): + _ = builder.input_processors # Access before fit should fail + + builder.fit(iter(self.samples)) + + self.assertIn("feature", builder.input_processors) + self.assertIn("label", builder.output_processors) + + self.assertEqual(builder.patient_to_index["p1"], [0, 1]) + self.assertEqual(builder.record_to_index["r3"], [2]) + + transformed = builder.transform(self.samples[0]) + self.assertEqual(transformed["feature"], "a") + self.assertEqual(transformed["label"], 1) + self.assertEqual(transformed["patient_id"], "p1") + + def test_transform_requires_fit(self): + builder = SampleBuilder( + input_schema=self.input_schema, output_schema=self.output_schema + ) + with self.assertRaises(RuntimeError): + builder.transform(self.samples[0]) + + +if __name__ == "__main__": + unittest.main() From 2ad809dbc1ab278f46ac388f20ab2c8f06b21aa9 Mon Sep 17 00:00:00 2001 From: Yongda Fan Date: Thu, 27 Nov 2025 08:29:50 -0500 Subject: [PATCH 10/82] Fix Mimic4 --- pyhealth/datasets/__init__.py | 2 +- pyhealth/datasets/base_dataset.py | 57 ++++++++++++++++++++----------- pyhealth/datasets/mimic4.py | 39 +++++++++------------ 3 files changed, 55 insertions(+), 43 deletions(-) diff --git a/pyhealth/datasets/__init__.py b/pyhealth/datasets/__init__.py index ced02afd7..df2598f02 100644 --- a/pyhealth/datasets/__init__.py +++ b/pyhealth/datasets/__init__.py @@ -59,7 +59,7 @@ def __init__(self, *args, **kwargs): from .mimic4 import MIMIC4CXRDataset, MIMIC4Dataset, MIMIC4EHRDataset, MIMIC4NoteDataset from .mimicextract import MIMICExtractDataset from .omop import OMOPDataset -from .sample_dataset import SampleDataset +from .sample_dataset import SampleBuilder, SampleDataset from .shhs import SHHSDataset from .sleepedf import SleepEDFDataset from .bmd_hs import BMDHSDataset diff --git a/pyhealth/datasets/base_dataset.py b/pyhealth/datasets/base_dataset.py index 418e3cb91..f85e9a1bc 100644 --- a/pyhealth/datasets/base_dataset.py +++ b/pyhealth/datasets/base_dataset.py @@ -4,16 +4,17 @@ from abc import ABC from concurrent.futures import ThreadPoolExecutor, as_completed from pathlib import Path -from typing import Dict, Iterator, List, Optional +from typing import Dict, Iterator, List, Optional, Any from urllib.parse import urlparse, urlunparse import json import uuid import platformdirs import tempfile + import litdata +from litdata.streaming.item_loader import ParquetLoader import pyarrow as pa import pyarrow.parquet as pq - import polars as pl import requests from tqdm import tqdm @@ -22,7 +23,7 @@ from ..tasks import BaseTask from ..processors.base_processor import FeatureProcessor from .configs import load_yaml_config -from .sample_dataset import SampleDataset 
+from .sample_dataset import SampleDataset, SampleBuilder from .utils import _convert_for_cache, _restore_from_cache logger = logging.getLogger(__name__) @@ -105,6 +106,8 @@ def scan_file(file_path: str) -> pl.LazyFrame: raise FileNotFoundError(f"Neither path exists: {path} or {alt_path}") +def unpickle_sample(sample_bytes: bytes) -> dict[str, Any]: + return pickle.loads(sample_bytes) class StreamingParquetWriter: """ @@ -226,12 +229,9 @@ def __init__( f"Initializing {self.dataset_name} dataset from {self.root} (dev mode: {self.dev})" ) - self.global_event_df = self.load_data() - # Cached attributes self._cache_dir = cache_dir self._event_df_path = None - self._collected_global_event_df = None self._unique_patient_ids = None @property @@ -524,6 +524,7 @@ def set_task( event_df = task.pre_filter(self.event_df) schema = pa.schema({"sample", pa.binary()}) with tempfile.TemporaryDirectory() as tmp_dir: + tmp_dir = "./test_task_cache" # For debugging purposes, keep the temp dir with StreamingParquetWriter(f"{tmp_dir}/samples.parquet", schema) as writer: logger.info(f"Applying task transformations on data...") @@ -541,16 +542,34 @@ def set_task( for sample in task(patient): writer.append({"sample": pickle.dumps(sample)}) litdata.index_parquet_dataset(tmp_dir) - - sample_dataset = SampleDataset( - samples, - input_schema=task.input_schema, - output_schema=task.output_schema, - dataset_name=self.dataset_name, - task_name=task, - input_processors=input_processors, - output_processors=output_processors, - ) - - logger.info(f"Generated {len(samples)} samples for task {task.task_name}") - return sample_dataset + # dataset = litdata.StreamingDataset( + # tmp_dir, + # transform=unpickle_sample, + # item_loader=ParquetLoader(), + # ) + # builder = SampleBuilder( + # input_schema=task.input_schema, # type: ignore + # output_schema=task.output_schema, # type: ignore + # input_processors=input_processors, + # output_processors=output_processors, + # ) + # builder.fit(iter(dataset)) + # litdata.optimize( + # fn=lambda x: builder.transform(x), + # inputs=Streadataset, + + # ) + + + # sample_dataset = SampleDataset( + # samples, + # input_schema=task.input_schema, + # output_schema=task.output_schema, + # dataset_name=self.dataset_name, + # task_name=task, + # input_processors=input_processors, + # output_processors=output_processors, + # ) + + # logger.info(f"Generated {len(samples)} samples for task {task.task_name}") + # return sample_dataset diff --git a/pyhealth/datasets/mimic4.py b/pyhealth/datasets/mimic4.py index 05321dedb..e3d2340c1 100644 --- a/pyhealth/datasets/mimic4.py +++ b/pyhealth/datasets/mimic4.py @@ -1,7 +1,7 @@ import logging import os import warnings -from typing import Dict, List, Optional +from typing import Dict, List, Optional, override import pandas as pd import polars as pl @@ -20,7 +20,7 @@ def log_memory_usage(tag=""): """Log current memory usage if psutil is available.""" if HAS_PSUTIL: - process = psutil.Process(os.getpid()) + process = psutil.Process(os.getpid()) # type: ignore mem_info = process.memory_info() logger.info(f"Memory usage {tag}: {mem_info.rss / (1024 * 1024):.1f} MB") else: @@ -214,16 +214,6 @@ def __init__( ): log_memory_usage("Starting MIMIC4Dataset init") - # Initialize child datasets - self.dataset_name = dataset_name - self.sub_datasets = {} - self.root = None - self.tables = None - self.config = None - # Dev flag is only used in the MIMIC4Dataset class - # to ensure the same set of patients are used for all sub-datasets. 
- self.dev = dev - # We need at least one root directory if not any([ehr_root, note_root, cxr_root]): raise ValueError("At least one root directory must be provided") @@ -233,6 +223,17 @@ def __init__( note_tables = note_tables or [] cxr_tables = cxr_tables or [] + super().__init__( + root=f"{ehr_root}|{note_root}|{cxr_root}", + tables=ehr_tables + note_tables + cxr_tables, + dataset_name=dataset_name, + config_path=None, + dev=dev, + ) + + # Initialize child datasets + self.sub_datasets: dict[str, BaseDataset] = {} + # Initialize EHR dataset if root is provided if ehr_root: logger.info(f"Initializing MIMIC4EHRDataset with tables: {ehr_tables} (dev mode: {dev})") @@ -263,18 +264,10 @@ def __init__( ) log_memory_usage("After CXR dataset initialization") - # Combine data from all sub-datasets - log_memory_usage("Before combining data") - self.global_event_df = self._combine_data() - log_memory_usage("After combining data") - - # Cache attributes - self._collected_global_event_df = None - self._unique_patient_ids = None - log_memory_usage("Completed MIMIC4Dataset init") - def _combine_data(self) -> pl.LazyFrame: + @override + def load_data(self) -> pl.LazyFrame: """ Combines data from all initialized sub-datasets into a unified global event dataframe. @@ -286,7 +279,7 @@ def _combine_data(self) -> pl.LazyFrame: # Collect global event dataframes from all sub-datasets for dataset_type, dataset in self.sub_datasets.items(): logger.info(f"Combining data from {dataset_type} dataset") - frames.append(dataset.global_event_df) + frames.append(dataset.load_data()) # Concatenate all frames logger.info("Creating combined dataframe") From 45ae3437c3c57a87821afd247cbe24415df6f3c0 Mon Sep 17 00:00:00 2001 From: Yongda Fan Date: Fri, 5 Dec 2025 17:18:52 -0500 Subject: [PATCH 11/82] fix incorrect dev mode --- pyhealth/datasets/base_dataset.py | 63 +++++++++++++++++-------------- 1 file changed, 35 insertions(+), 28 deletions(-) diff --git a/pyhealth/datasets/base_dataset.py b/pyhealth/datasets/base_dataset.py index f85e9a1bc..ca608cf36 100644 --- a/pyhealth/datasets/base_dataset.py +++ b/pyhealth/datasets/base_dataset.py @@ -106,8 +106,10 @@ def scan_file(file_path: str) -> pl.LazyFrame: raise FileNotFoundError(f"Neither path exists: {path} or {alt_path}") -def unpickle_sample(sample_bytes: bytes) -> dict[str, Any]: - return pickle.loads(sample_bytes) + +def unpickle_sample(sample_bytes: dict[str, bytes]) -> dict[str, Any]: + return pickle.loads(sample_bytes["sample"]) + class StreamingParquetWriter: """ @@ -259,7 +261,7 @@ def cache_dir(self) -> Path: return Path(self._cache_dir) @property - def event_df(self) -> pl.LazyFrame: + def global_event_df(self) -> pl.LazyFrame: """Returns the path to the cached event dataframe. 
Returns: @@ -272,7 +274,9 @@ def event_df(self) -> pl.LazyFrame: if self.dev: logger.info("Dev mode enabled: limiting to 1000 patients") limited_patients = ( - df.select(pl.col("patient_id")).unique().limit(1000) + df.select(pl.col("patient_id").shuffle(seed=0)) + .unique() + .limit(1000) ) df = df.join(limited_patients, on="patient_id", how="inner") @@ -398,7 +402,7 @@ def unique_patient_ids(self) -> List[str]: """ if self._unique_patient_ids is None: self._unique_patient_ids = ( - self.event_df.select("patient_id") + self.global_event_df.select("patient_id") .unique() .collect(engine="streaming") .to_series() @@ -423,9 +427,9 @@ def get_patient(self, patient_id: str) -> Patient: patient_id in self.unique_patient_ids ), f"Patient {patient_id} not found in dataset" - data_source = self.event_df.filter(pl.col("patient_id") == patient_id).collect( - engine="streaming" - ) + data_source = self.global_event_df.filter( + pl.col("patient_id") == patient_id + ).collect(engine="streaming") return Patient(patient_id=patient_id, data_source=data_source) def iter_patients(self, df: Optional[pl.LazyFrame] = None) -> Iterator[Patient]: @@ -435,7 +439,7 @@ def iter_patients(self, df: Optional[pl.LazyFrame] = None) -> Iterator[Patient]: Iterator[Patient]: An iterator over Patient objects. """ if df is None: - df = self.event_df + df = self.global_event_df patient_ids = ( df.select("patient_id") .unique(maintain_order=True) @@ -451,7 +455,7 @@ def iter_patients(self, df: Optional[pl.LazyFrame] = None) -> Iterator[Patient]: def stats(self) -> None: """Prints statistics about the dataset.""" - stats = self.event_df.select( + stats = self.global_event_df.select( pl.len().alias("n_events"), pl.col("patient_id").n_unique().alias("n_patients"), ).collect(engine="streaming") @@ -521,13 +525,17 @@ def set_task( # Check if index.json exists to verify cache integrity, this # is the standard file for litdata.StreamingDataset if not (path / "index.json").exists(): - event_df = task.pre_filter(self.event_df) - schema = pa.schema({"sample", pa.binary()}) + event_df = task.pre_filter(self.global_event_df) + schema = pa.schema([("sample", pa.binary())]) with tempfile.TemporaryDirectory() as tmp_dir: - tmp_dir = "./test_task_cache" # For debugging purposes, keep the temp dir - with StreamingParquetWriter(f"{tmp_dir}/samples.parquet", schema) as writer: + tmp_dir = ( + "./test_task_cache" # For debugging purposes, keep the temp dir + ) + with StreamingParquetWriter( + f"{tmp_dir}/samples.parquet", schema + ) as writer: logger.info(f"Applying task transformations on data...") - + patient_ids = ( event_df.select("patient_id") .unique() @@ -542,25 +550,24 @@ def set_task( for sample in task(patient): writer.append({"sample": pickle.dumps(sample)}) litdata.index_parquet_dataset(tmp_dir) - # dataset = litdata.StreamingDataset( - # tmp_dir, - # transform=unpickle_sample, - # item_loader=ParquetLoader(), - # ) - # builder = SampleBuilder( - # input_schema=task.input_schema, # type: ignore - # output_schema=task.output_schema, # type: ignore - # input_processors=input_processors, - # output_processors=output_processors, - # ) - # builder.fit(iter(dataset)) + dataset = litdata.StreamingDataset( + tmp_dir, + item_loader=ParquetLoader(), + ) + builder = SampleBuilder( + input_schema=task.input_schema, # type: ignore + output_schema=task.output_schema, # type: ignore + input_processors=input_processors, + output_processors=output_processors, + ) + builder.fit(map(unpickle_sample, iter(dataset))) + return dataset, builder # 
litdata.optimize( # fn=lambda x: builder.transform(x), # inputs=Streadataset, # ) - # sample_dataset = SampleDataset( # samples, # input_schema=task.input_schema, From b4949ec9cf144a107da411f0ef57a3f1541c7096 Mon Sep 17 00:00:00 2001 From: Yongda Fan Date: Fri, 5 Dec 2025 17:19:17 -0500 Subject: [PATCH 12/82] change fit to take Iterator --- pyhealth/processors/base_processor.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/pyhealth/processors/base_processor.py b/pyhealth/processors/base_processor.py index 050cb5357..d207f0220 100644 --- a/pyhealth/processors/base_processor.py +++ b/pyhealth/processors/base_processor.py @@ -1,5 +1,5 @@ from abc import ABC, abstractmethod -from typing import Any, Dict, List, Optional +from typing import Any, Dict, List, Iterator class Processor(ABC): @@ -33,7 +33,7 @@ class FeatureProcessor(Processor): Example: Tokenization, image loading, normalization. """ - def fit(self, samples: List[Dict[str, Any]], field: str) -> None: + def fit(self, samples: Iterator[Dict[str, Any]], field: str) -> None: """Fit the processor to the samples. Args: From 0494bca8d533b2e488946ace22ee328e1e1af5af Mon Sep 17 00:00:00 2001 From: Yongda Fan Date: Fri, 5 Dec 2025 17:19:35 -0500 Subject: [PATCH 13/82] update test --- tests/core/test_base_dataset.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/core/test_base_dataset.py b/tests/core/test_base_dataset.py index 5d91bbf43..a246fa518 100644 --- a/tests/core/test_base_dataset.py +++ b/tests/core/test_base_dataset.py @@ -105,7 +105,7 @@ def test_event_df_cache_is_physically_sorted(self): ) # Trigger caching of event_df.parquet - _ = dataset.event_df + _ = dataset.global_event_df cache_path = dataset.cache_dir / "event_df.parquet" self.assertTrue(cache_path.exists(), "event_df cache should be created") From e63d5003c41271007e12abfe3f18c25bc240d21d Mon Sep 17 00:00:00 2001 From: Yongda Fan Date: Fri, 5 Dec 2025 17:21:19 -0500 Subject: [PATCH 14/82] rename --- pyhealth/datasets/base_dataset.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/pyhealth/datasets/base_dataset.py b/pyhealth/datasets/base_dataset.py index ca608cf36..d0c126695 100644 --- a/pyhealth/datasets/base_dataset.py +++ b/pyhealth/datasets/base_dataset.py @@ -233,7 +233,7 @@ def __init__( # Cached attributes self._cache_dir = cache_dir - self._event_df_path = None + self._global_event_df = None self._unique_patient_ids = None @property @@ -267,8 +267,8 @@ def global_event_df(self) -> pl.LazyFrame: Returns: Path: The path to the cached event dataframe. 
""" - if self._event_df_path is None: - path = self.cache_dir / "event_df.parquet" + if self._global_event_df is None: + path = self.cache_dir / "global_event_df.parquet" if not path.exists(): df = self.load_data() if self.dev: @@ -287,10 +287,10 @@ def global_event_df(self) -> pl.LazyFrame: row_group_size=8_192, maintain_order=True, # Important for sorted writes ) - self._event_df_path = path + self._global_event_df = path return pl.scan_parquet( - self._event_df_path, + self._global_event_df, low_memory=True, ).set_sorted( "patient_id" From 525c526d806652d5224fcf56b825024323383a6a Mon Sep 17 00:00:00 2001 From: Yongda Fan Date: Fri, 5 Dec 2025 17:52:45 -0500 Subject: [PATCH 15/82] update fit to use Iterable --- pyhealth/processors/base_processor.py | 4 ++-- pyhealth/processors/label_processor.py | 8 ++++---- .../processors/nested_sequence_processor.py | 6 +++--- pyhealth/processors/stagenet_processor.py | 18 +++++++++--------- 4 files changed, 18 insertions(+), 18 deletions(-) diff --git a/pyhealth/processors/base_processor.py b/pyhealth/processors/base_processor.py index d207f0220..823fcafec 100644 --- a/pyhealth/processors/base_processor.py +++ b/pyhealth/processors/base_processor.py @@ -1,5 +1,5 @@ from abc import ABC, abstractmethod -from typing import Any, Dict, List, Iterator +from typing import Any, Dict, List, Iterable class Processor(ABC): @@ -33,7 +33,7 @@ class FeatureProcessor(Processor): Example: Tokenization, image loading, normalization. """ - def fit(self, samples: Iterator[Dict[str, Any]], field: str) -> None: + def fit(self, samples: Iterable[Dict[str, Any]], field: str) -> None: """Fit the processor to the samples. Args: diff --git a/pyhealth/processors/label_processor.py b/pyhealth/processors/label_processor.py index ad2df1897..ff32dabf8 100644 --- a/pyhealth/processors/label_processor.py +++ b/pyhealth/processors/label_processor.py @@ -1,5 +1,5 @@ import logging -from typing import Any, Dict, List +from typing import Any, Dict, List, Iterable import torch @@ -19,7 +19,7 @@ def __init__(self): super().__init__() self.label_vocab: Dict[Any, int] = {} - def fit(self, samples: List[Dict[str, Any]], field: str) -> None: + def fit(self, samples: Iterable[Dict[str, Any]], field: str) -> None: all_labels = set([sample[field] for sample in samples]) if len(all_labels) != 2: raise ValueError(f"Expected 2 unique labels, got {len(all_labels)}") @@ -54,7 +54,7 @@ def __init__(self): super().__init__() self.label_vocab: Dict[Any, int] = {} - def fit(self, samples: List[Dict[str, Any]], field: str) -> None: + def fit(self, samples: Iterable[Dict[str, Any]], field: str) -> None: all_labels = set([sample[field] for sample in samples]) num_classes = len(all_labels) if all_labels == set(range(num_classes)): @@ -89,7 +89,7 @@ def __init__(self): super().__init__() self.label_vocab: Dict[Any, int] = {} - def fit(self, samples: List[Dict[str, Any]], field: str) -> None: + def fit(self, samples: Iterable[Dict[str, Any]], field: str) -> None: all_labels = set() for sample in samples: for label in sample[field]: diff --git a/pyhealth/processors/nested_sequence_processor.py b/pyhealth/processors/nested_sequence_processor.py index bf7ed2055..ab7723b71 100644 --- a/pyhealth/processors/nested_sequence_processor.py +++ b/pyhealth/processors/nested_sequence_processor.py @@ -1,4 +1,4 @@ -from typing import Any, Dict, List +from typing import Any, Dict, List, Iterable import torch @@ -51,7 +51,7 @@ def __init__(self, padding: int = 0): self._max_inner_len = 1 # Maximum length of inner sequences 
self._padding = padding # Additional padding beyond observed max - def fit(self, samples: List[Dict[str, Any]], field: str) -> None: + def fit(self, samples: Iterable[Dict[str, Any]], field: str) -> None: """Build vocabulary and determine maximum inner sequence length. Args: @@ -183,7 +183,7 @@ def __init__(self, forward_fill: bool = True, padding: int = 0): self.forward_fill = forward_fill self._padding = padding # Additional padding beyond observed max - def fit(self, samples: List[Dict[str, Any]], field: str) -> None: + def fit(self, samples: Iterable[Dict[str, Any]], field: str) -> None: """Determine maximum inner sequence length. Args: diff --git a/pyhealth/processors/stagenet_processor.py b/pyhealth/processors/stagenet_processor.py index cbbafac94..3ea6a8b04 100644 --- a/pyhealth/processors/stagenet_processor.py +++ b/pyhealth/processors/stagenet_processor.py @@ -1,4 +1,4 @@ -from typing import Any, Dict, List, Optional, Tuple +from typing import Any, Dict, List, Optional, Tuple, Iterable import torch @@ -61,7 +61,7 @@ def __init__(self, padding: int = 0): self._max_nested_len = None # Max inner sequence length for nested codes self._padding = padding # Additional padding beyond observed max - def fit(self, samples: List[Dict], key: str) -> None: + def fit(self, samples: Iterable[Dict[str, Any]], field: str) -> None: """Build vocabulary and determine input structure. Args: @@ -70,9 +70,9 @@ def fit(self, samples: List[Dict], key: str) -> None: """ # Examine first non-None sample to determine structure for sample in samples: - if key in sample and sample[key] is not None: + if field in sample and sample[field] is not None: # Unpack tuple: (time, values) - time_data, value_data = sample[key] + time_data, value_data = sample[field] # Determine nesting level for codes if isinstance(value_data, list) and len(value_data) > 0: @@ -90,9 +90,9 @@ def fit(self, samples: List[Dict], key: str) -> None: # Build vocabulary for codes and find max nested length max_inner_len = 0 for sample in samples: - if key in sample and sample[key] is not None: + if field in sample and sample[field] is not None: # Unpack tuple: (time, values) - time_data, value_data = sample[key] + time_data, value_data = sample[field] if self._is_nested: # Nested codes @@ -256,7 +256,7 @@ def __init__(self): self._size = None # Feature dimension (set during fit) self._is_nested = None - def fit(self, samples: List[Dict], key: str) -> None: + def fit(self, samples: Iterable[Dict[str, Any]], field: str) -> None: """Determine input structure. 
Args: @@ -265,9 +265,9 @@ def fit(self, samples: List[Dict], key: str) -> None: """ # Examine first non-None sample to determine structure for sample in samples: - if key in sample and sample[key] is not None: + if field in sample and sample[field] is not None: # Unpack tuple: (time, values) - time_data, value_data = sample[key] + time_data, value_data = sample[field] # Determine nesting level for numerics if isinstance(value_data, list) and len(value_data) > 0: From 6caa917e2c802f2df6c4acdc8c3111eeecbe209c Mon Sep 17 00:00:00 2001 From: Yongda Fan Date: Fri, 5 Dec 2025 17:57:37 -0500 Subject: [PATCH 16/82] Fix SampleBuilder --- pyhealth/datasets/sample_dataset.py | 41 +++++++++-------------------- 1 file changed, 12 insertions(+), 29 deletions(-) diff --git a/pyhealth/datasets/sample_dataset.py b/pyhealth/datasets/sample_dataset.py index 5c3b18cac..bd90bc5d7 100644 --- a/pyhealth/datasets/sample_dataset.py +++ b/pyhealth/datasets/sample_dataset.py @@ -1,4 +1,4 @@ -from typing import Any, Dict, Iterator, List, Optional, Tuple, Union, Type +from typing import Any, Callable, Dict, Iterable, Iterator, List, Optional, Tuple, Union, Type import inspect from torch.utils.data import Dataset @@ -13,24 +13,8 @@ class SampleBuilder: def __init__( self, - input_schema: Dict[ - str, - Union[ - str, - Type[FeatureProcessor], - FeatureProcessor, - Tuple[Union[str, Type[FeatureProcessor]], Dict[str, Any]], - ], - ], - output_schema: Dict[ - str, - Union[ - str, - Type[FeatureProcessor], - FeatureProcessor, - Tuple[Union[str, Type[FeatureProcessor]], Dict[str, Any]], - ], - ], + input_schema: Dict[str, Any], + output_schema: Dict[str, Any], input_processors: Optional[Dict[str, FeatureProcessor]] = None, output_processors: Optional[Dict[str, FeatureProcessor]] = None, ) -> None: @@ -103,8 +87,12 @@ def _get_processor_instance(self, processor_spec): f"class, or a tuple (spec, kwargs_dict), got {type(processor_spec)}" ) - def _validate(self, samples: List[Dict[str, Any]]) -> None: - """Validate that provided samples contain the fields described in the schemas.""" + def fit( + self, + samples: Iterable[Dict[str, Any]], + ) -> None: + """Fit processors and build index mappings from an iterator of samples.""" + # Validate the samples input_keys = set(self.input_schema.keys()) output_keys = set(self.output_schema.keys()) for sample in samples: @@ -115,15 +103,10 @@ def _validate(self, samples: List[Dict[str, Any]]) -> None: sample.keys() ), "Output schema does not match samples." 
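        # Note: the validation loop above, the index-building loop below, and
        # each processor's fit() all traverse `samples` separately, so the
        # argument must be re-iterable (e.g. a list or a
        # litdata.StreamingDataset), not a one-shot generator.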
- def fit(self, samples: Iterator[Dict[str, Any]]) -> None: - """Fit processors and build index mappings from an iterator of samples.""" - sample_list = list(samples) - self._validate(sample_list) - # Build index mappings self._patient_to_index = {} self._record_to_index = {} - for i, sample in enumerate(sample_list): + for i, sample in enumerate(samples): patient_id = sample.get("patient_id") if patient_id is not None: self._patient_to_index.setdefault(patient_id, []).append(i) @@ -135,12 +118,12 @@ def fit(self, samples: Iterator[Dict[str, Any]]) -> None: if not self._input_processors: for key, spec in self.input_schema.items(): processor = self._get_processor_instance(spec) - processor.fit(sample_list, key) + processor.fit(samples, key) self._input_processors[key] = processor if not self._output_processors: for key, spec in self.output_schema.items(): processor = self._get_processor_instance(spec) - processor.fit(sample_list, key) + processor.fit(samples, key) self._output_processors[key] = processor self._fitted = True From 8a40d3b2d7afc97ed91c7fb15986683c1099f603 Mon Sep 17 00:00:00 2001 From: Yongda Fan Date: Fri, 5 Dec 2025 18:16:10 -0500 Subject: [PATCH 17/82] Fix tsv test --- tests/core/test_tsv_load.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/tests/core/test_tsv_load.py b/tests/core/test_tsv_load.py index 779b928b6..25c87f1db 100644 --- a/tests/core/test_tsv_load.py +++ b/tests/core/test_tsv_load.py @@ -156,7 +156,7 @@ def test_tsv_load(self): self.assertIsNotNone(dataset.config) # Test that we can collect the dataframe - collected_df = dataset.collected_global_event_df + collected_df = dataset.global_event_df.collect() self.assertIsInstance(collected_df, pl.DataFrame) self.assertGreater( collected_df.height, 0, "Dataset should have at least one row" @@ -201,7 +201,7 @@ def test_tsv_file_detection(self): dev=False, ) - collected_df = dataset.collected_global_event_df + collected_df = dataset.global_event_df.collect() # Verify we have the expected number of patients self.assertEqual(collected_df["patient_id"].n_unique(), 5) @@ -231,7 +231,7 @@ def test_multiple_tsv_tables(self): dev=False, ) - collected_df = dataset.collected_global_event_df + collected_df = dataset.global_event_df.collect() # Should have data from both tables self.assertGreater(collected_df.height, 5) # More than just patients table From 504aaa2cd458a98cd487d8ce78a8cf0932fccb58 Mon Sep 17 00:00:00 2001 From: Yongda Fan Date: Fri, 5 Dec 2025 18:17:45 -0500 Subject: [PATCH 18/82] Fix base dataset test --- tests/core/test_base_dataset.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/tests/core/test_base_dataset.py b/tests/core/test_base_dataset.py index a246fa518..2efd49d46 100644 --- a/tests/core/test_base_dataset.py +++ b/tests/core/test_base_dataset.py @@ -104,10 +104,10 @@ def test_event_df_cache_is_physically_sorted(self): dev=False, ) - # Trigger caching of event_df.parquet + # Trigger caching of global_event_df.parquet _ = dataset.global_event_df - cache_path = dataset.cache_dir / "event_df.parquet" - self.assertTrue(cache_path.exists(), "event_df cache should be created") + cache_path = dataset.cache_dir / "global_event_df.parquet" + self.assertTrue(cache_path.exists(), "global_event_df cache should be created") cached_df = pl.read_parquet(cache_path) cached_order = cached_df["patient_id"].to_list() @@ -118,7 +118,7 @@ def test_event_df_cache_is_physically_sorted(self): self.assertEqual( cached_order, sorted(cached_order), - "cached event_df parquet 
must be sorted by patient_id", + "cached global_event_df parquet must be sorted by patient_id", ) From 6c9363ad1d23660ea9836f9290bd7af5397f7382 Mon Sep 17 00:00:00 2001 From: Yongda Fan Date: Fri, 5 Dec 2025 18:35:52 -0500 Subject: [PATCH 19/82] cache processed data --- pyhealth/datasets/base_dataset.py | 44 ++++++++++++++++++------------- 1 file changed, 26 insertions(+), 18 deletions(-) diff --git a/pyhealth/datasets/base_dataset.py b/pyhealth/datasets/base_dataset.py index d0c126695..1d85aeebe 100644 --- a/pyhealth/datasets/base_dataset.py +++ b/pyhealth/datasets/base_dataset.py @@ -107,10 +107,6 @@ def scan_file(file_path: str) -> pl.LazyFrame: raise FileNotFoundError(f"Neither path exists: {path} or {alt_path}") -def unpickle_sample(sample_bytes: dict[str, bytes]) -> dict[str, Any]: - return pickle.loads(sample_bytes["sample"]) - - class StreamingParquetWriter: """ Stream-write rows into a Parquet file in chunked (row-group) fashion. @@ -525,34 +521,36 @@ def set_task( # Check if index.json exists to verify cache integrity, this # is the standard file for litdata.StreamingDataset if not (path / "index.json").exists(): - event_df = task.pre_filter(self.global_event_df) + global_event_df = task.pre_filter(self.global_event_df) schema = pa.schema([("sample", pa.binary())]) with tempfile.TemporaryDirectory() as tmp_dir: - tmp_dir = ( - "./test_task_cache" # For debugging purposes, keep the temp dir - ) + # Create Parquet file with samples + logger.info(f"Applying task transformations on data...") with StreamingParquetWriter( f"{tmp_dir}/samples.parquet", schema ) as writer: - logger.info(f"Applying task transformations on data...") - + # TODO: this can be further optimized. patient_ids = ( - event_df.select("patient_id") + global_event_df.select("patient_id") .unique() .collect(engine="streaming") .to_series() ) for patient_id in tqdm(patient_ids): - patient_df = event_df.filter( + patient_df = global_event_df.filter( pl.col("patient_id") == patient_id ).collect(engine="streaming") patient = Patient(patient_id=patient_id, data_source=patient_df) for sample in task(patient): writer.append({"sample": pickle.dumps(sample)}) litdata.index_parquet_dataset(tmp_dir) + + # Build processors and fit on the dataset + logger.info(f"Fitting processors on the dataset...") dataset = litdata.StreamingDataset( tmp_dir, item_loader=ParquetLoader(), + transform=lambda x: pickle.loads(x["sample"]), ) builder = SampleBuilder( input_schema=task.input_schema, # type: ignore @@ -560,13 +558,23 @@ def set_task( input_processors=input_processors, output_processors=output_processors, ) - builder.fit(map(unpickle_sample, iter(dataset))) - return dataset, builder - # litdata.optimize( - # fn=lambda x: builder.transform(x), - # inputs=Streadataset, + builder.fit(dataset) - # ) + # Apply processors and save final samples to cache_dir + logger.info(f"Processing samples and saving to {path}...") + dataset = litdata.StreamingDataset( + tmp_dir, + item_loader=ParquetLoader(), + transform=builder.transform, + ) + litdata.optimize( + fn=lambda x: x, + inputs=litdata.StreamingDataLoader(dataset), + output_dir=str(path), + chunk_bytes="64MB", + num_workers=num_workers, + ) + logger.info(f"Cached processed samples to {path}") # sample_dataset = SampleDataset( # samples, From a42c131424b856ab60cc28bc6e50d7c718451f2a Mon Sep 17 00:00:00 2001 From: Yongda Fan Date: Sat, 6 Dec 2025 11:15:24 -0500 Subject: [PATCH 20/82] save schema for SampleBuilder --- pyhealth/datasets/sample_dataset.py | 178 ++++++---------------------- 1 file 
changed, 36 insertions(+), 142 deletions(-) diff --git a/pyhealth/datasets/sample_dataset.py b/pyhealth/datasets/sample_dataset.py index bd90bc5d7..978489c27 100644 --- a/pyhealth/datasets/sample_dataset.py +++ b/pyhealth/datasets/sample_dataset.py @@ -1,8 +1,10 @@ -from typing import Any, Callable, Dict, Iterable, Iterator, List, Optional, Tuple, Union, Type +import pickle +from typing import Any, Dict, Iterable, List, Optional, Tuple, Union, Type, override import inspect from torch.utils.data import Dataset from tqdm import tqdm +from litdata import StreamingDataset from ..processors import get_processor from ..processors.base_processor import FeatureProcessor @@ -128,13 +130,13 @@ def fit( self._fitted = True - def transform(self, sample: Dict[str, Any]) -> Dict[str, Any]: - """Transform a single sample dictionary using the fitted processors.""" + def transform(self, sample: dict[str, bytes]) -> Dict[str, Any]: + """Transform a pickled sample using the fitted processors.""" if not self._fitted: raise RuntimeError("SampleBuilder.fit must be called before transform().") transformed: Dict[str, Any] = {} - for key, value in sample.items(): + for key, value in pickle.loads(sample["sample"]).items(): if key in self._input_processors: transformed[key] = self._input_processors[key].process(value) elif key in self._output_processors: @@ -143,8 +145,21 @@ def transform(self, sample: Dict[str, Any]) -> Dict[str, Any]: transformed[key] = value return transformed - -class SampleDataset(Dataset): + def save(self, path: str) -> None: + """Saves the fitted metadata to the specified path.""" + if not self._fitted: + raise RuntimeError("SampleBuilder.fit must be called before save().") + metadata = { + "input_processors": self._input_processors, + "output_processors": self._output_processors, + "patient_to_index": self._patient_to_index, + "record_to_index": self._record_to_index, + } + with open(path, "wb") as f: + pickle.dump(metadata, f) + + +class SampleDataset(StreamingDataset): """Sample dataset class for handling and processing data samples. Attributes: @@ -159,29 +174,14 @@ class SampleDataset(Dataset): def __init__( self, - samples: List[Dict], - input_schema: Dict[ - str, - Union[ - str, - Type[FeatureProcessor], - FeatureProcessor, - Tuple[Union[str, Type[FeatureProcessor]], Dict[str, Any]], - ], - ], - output_schema: Dict[ - str, - Union[ - str, - Type[FeatureProcessor], - FeatureProcessor, - Tuple[Union[str, Type[FeatureProcessor]], Dict[str, Any]], - ], - ], - dataset_name: Optional[str] = None, - task_name: Optional[str] = None, + path: str, + input_schema: Dict[str, Any], + output_schema: Dict[str, Any], input_processors: Optional[Dict[str, FeatureProcessor]] = None, output_processors: Optional[Dict[str, FeatureProcessor]] = None, + dataset_name: Optional[str] = None, + task_name: Optional[str] = None, + **kwargs, ) -> None: """Initializes the SampleDataset with samples and schemas. @@ -204,122 +204,24 @@ def __init__( will be used instead of creating new ones from output_schema. Defaults to None. 
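        Example:
            A minimal sketch, assuming ``cache_path`` is a directory that
            ``BaseDataset.set_task`` has already populated with ``index.json``
            and ``schema.pkl`` (the dataset and task names are illustrative).

                ds = SampleDataset(
                    path=cache_path,
                    dataset_name="MIMIC4Dataset",
                    task_name="mortality_prediction",
                )
                first_sample = next(iter(ds))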
""" + super().__init__(path, **kwargs) if dataset_name is None: dataset_name = "" if task_name is None: task_name = "" - self.samples = samples + + self.dataset_name = "" if dataset_name is None else dataset_name + self.task_name = "" if task_name is None else task_name + self.input_schema = input_schema self.output_schema = output_schema - self.input_processors = input_processors if input_processors is not None else {} - self.output_processors = ( - output_processors if output_processors is not None else {} - ) - self.dataset_name = dataset_name - self.task_name = task_name - # Create patient_to_index and record_to_index mappings + self.input_processors = input_processors + self.output_processors = output_processors + self.patient_to_index = {} self.record_to_index = {} - for i, sample in enumerate(samples): - # Create patient_to_index mapping - patient_id = sample.get("patient_id") - if patient_id is not None: - if patient_id not in self.patient_to_index: - self.patient_to_index[patient_id] = [] - self.patient_to_index[patient_id].append(i) - - # Create record_to_index mapping (optional) - record_id = sample.get("record_id", sample.get("visit_id")) - if record_id is not None: - if record_id not in self.record_to_index: - self.record_to_index[record_id] = [] - self.record_to_index[record_id].append(i) - - self.validate() - self.build() - - def _get_processor_instance(self, processor_spec): - """Get processor instance from either string alias, class reference, processor instance, or tuple with kwargs. - - Args: - processor_spec: Either a string alias, a processor class, a processor instance, or a tuple (spec, kwargs_dict) - - Returns: - Instance of the processor - """ - if isinstance(processor_spec, tuple): - spec, kwargs = processor_spec - if isinstance(spec, str): - return get_processor(spec)(**kwargs) - elif inspect.isclass(spec) and issubclass(spec, FeatureProcessor): - return spec(**kwargs) - else: - raise ValueError( - f"Processor spec in tuple must be either a string alias or a " - f"FeatureProcessor class, got {type(spec)}" - ) - if isinstance(processor_spec, str): - # Use existing registry system for string aliases - return get_processor(processor_spec)() - elif inspect.isclass(processor_spec) and issubclass( - processor_spec, FeatureProcessor - ): - # Direct class reference - return processor_spec() - elif isinstance(processor_spec, FeatureProcessor): - # Already an instance - return processor_spec - else: - raise ValueError( - f"Processor spec must be either a string alias, a " - f"FeatureProcessor class, or a tuple (spec, kwargs_dict), got {type(processor_spec)}" - ) - - def validate(self) -> None: - """Validates that the samples match the input and output schemas.""" - input_keys = set(self.input_schema.keys()) - output_keys = set(self.output_schema.keys()) - for s in self.samples: - assert input_keys.issubset(s.keys()), "Input schema does not match samples." - assert output_keys.issubset( - s.keys() - ), "Output schema does not match samples." 
- return - - def build(self) -> None: - """Builds the processors for input and output data based on schemas.""" - # Only fit if processors weren't provided - if not self.input_processors: - for k, v in self.input_schema.items(): - self.input_processors[k] = self._get_processor_instance(v) - self.input_processors[k].fit(self.samples, k) - if not self.output_processors: - for k, v in self.output_schema.items(): - self.output_processors[k] = self._get_processor_instance(v) - self.output_processors[k].fit(self.samples, k) - # Always process samples with the (fitted) processors - for sample in tqdm(self.samples, desc="Processing samples"): - for k, v in sample.items(): - if k in self.input_processors: - sample[k] = self.input_processors[k].process(v) - elif k in self.output_processors: - sample[k] = self.output_processors[k].process(v) - return - - def __getitem__(self, index: int) -> Dict: - """Returns a sample by index. - - Args: - index (int): Index of the sample to retrieve. - - Returns: - Dict: A dict with patient_id, visit_id/record_id, and other - task-specific attributes as key. Conversion to index/tensor - will be done in the model. - """ - return self.samples[index] - + @override def __str__(self) -> str: """Returns a string representation of the dataset. @@ -327,11 +229,3 @@ def __str__(self) -> str: str: A string with the dataset and task names. """ return f"Sample dataset {self.dataset_name} {self.task_name}" - - def __len__(self) -> int: - """Returns the number of samples in the dataset. - - Returns: - int: The number of samples. - """ - return len(self.samples) From 8f027aef312377ac89432e5be663120758675b32 Mon Sep 17 00:00:00 2001 From: Yongda Fan Date: Sat, 6 Dec 2025 11:22:24 -0500 Subject: [PATCH 21/82] Fix sampledataset --- examples/memtest.py | 113 +++++++++++++++++++++++----- pyhealth/datasets/base_dataset.py | 18 ++--- pyhealth/datasets/sample_dataset.py | 25 +++--- 3 files changed, 110 insertions(+), 46 deletions(-) diff --git a/examples/memtest.py b/examples/memtest.py index 8a63090e8..4403444e9 100644 --- a/examples/memtest.py +++ b/examples/memtest.py @@ -1,32 +1,105 @@ -# %% -import psutil, os, time, threading -PEAK_MEM_USAGE = 0 -SELF_PROC = psutil.Process(os.getpid()) - -def track_mem(): - global PEAK_MEM_USAGE - while True: - m = SELF_PROC.memory_info().rss - if m > PEAK_MEM_USAGE: - PEAK_MEM_USAGE = m - time.sleep(0.1) +""" +Example of using StageNet for mortality prediction on MIMIC-IV. -threading.Thread(target=track_mem, daemon=True).start() -print(f"[MEM] start={PEAK_MEM_USAGE / (1024**3)} GB") +This example demonstrates: +1. Loading MIMIC-IV data +2. Applying the MortalityPredictionStageNetMIMIC4 task +3. Creating a SampleDataset with StageNet processors +4. 
Training a StageNet model +""" # %% -from pyhealth.datasets import MIMIC4Dataset -DATASET_DIR = "/home/logic/physionet.org/files/mimiciv/3.1" -dataset = MIMIC4Dataset( - ehr_root=DATASET_DIR, +from pyhealth.datasets import ( + MIMIC4Dataset, + get_dataloader, + split_by_patient, +) +from pyhealth.models import StageNet +from pyhealth.tasks import MortalityPredictionStageNetMIMIC4 +from pyhealth.trainer import Trainer +import torch + +# %% STEP 1: Load MIMIC-IV base dataset +base_dataset = MIMIC4Dataset( + ehr_root="/home/logic/physionet.org/files/mimiciv/3.1/", ehr_tables=[ "patients", "admissions", "diagnoses_icd", "procedures_icd", - "prescriptions", "labevents", ], + dev=True, ) -print(f"[MEM] __init__={PEAK_MEM_USAGE / (1024**3):.3f} GB") + +# %% # STEP 2: Apply StageNet mortality prediction task +sample_dataset = base_dataset.set_task( + MortalityPredictionStageNetMIMIC4(), + num_workers=4, +) + +print(f"Total samples: {len(sample_dataset)}") +print(f"Input schema: {sample_dataset.input_schema}") +print(f"Output schema: {sample_dataset.output_schema}") + +# %% Inspect a sample +sample = next(iter(sample_dataset)) +print("\nSample structure:") +print(f" Patient ID: {sample['patient_id']}") +print(f"ICD Codes: {sample['icd_codes']}") +print(f" Labs shape: {len(sample['labs'][0])} timesteps") +print(f" Mortality: {sample['mortality']}") + +# # STEP 3: Split dataset +# train_dataset, val_dataset, test_dataset = split_by_patient( +# sample_dataset, [0.8, 0.1, 0.1] +# ) + +# # Create dataloaders +# train_loader = get_dataloader(train_dataset, batch_size=256, shuffle=True) +# val_loader = get_dataloader(val_dataset, batch_size=256, shuffle=False) +# test_loader = get_dataloader(test_dataset, batch_size=256, shuffle=False) + +# # STEP 4: Initialize StageNet model +# model = StageNet( +# dataset=sample_dataset, +# embedding_dim=128, +# chunk_size=128, +# levels=3, +# dropout=0.3, +# ) + +# num_params = sum(p.numel() for p in model.parameters()) +# print(f"\nModel initialized with {num_params} parameters") + +# # STEP 5: Train the model +# trainer = Trainer( +# model=model, +# device="cuda:5", # or "cpu" +# metrics=["pr_auc", "roc_auc", "accuracy", "f1"], +# ) + +# trainer.train( +# train_dataloader=train_loader, +# val_dataloader=val_loader, +# epochs=50, +# monitor="roc_auc", +# optimizer_params={"lr": 1e-5}, +# ) + +# # STEP 6: Evaluate on test set +# results = trainer.evaluate(test_loader) +# print("\nTest Results:") +# for metric, value in results.items(): +# print(f" {metric}: {value:.4f}") + +# # STEP 7: Inspect model predictions +# sample_batch = next(iter(test_loader)) +# with torch.no_grad(): +# output = model(**sample_batch) + +# print("\nSample predictions:") +# print(f" Predicted probabilities: {output['y_prob'][:5]}") +# print(f" True labels: {output['y_true'][:5]}") + # %% diff --git a/pyhealth/datasets/base_dataset.py b/pyhealth/datasets/base_dataset.py index 1d85aeebe..fad3a8e72 100644 --- a/pyhealth/datasets/base_dataset.py +++ b/pyhealth/datasets/base_dataset.py @@ -559,6 +559,7 @@ def set_task( output_processors=output_processors, ) builder.fit(dataset) + builder.save(str(path / "schema.pkl")) # Apply processors and save final samples to cache_dir logger.info(f"Processing samples and saving to {path}...") @@ -576,15 +577,8 @@ def set_task( ) logger.info(f"Cached processed samples to {path}") - # sample_dataset = SampleDataset( - # samples, - # input_schema=task.input_schema, - # output_schema=task.output_schema, - # dataset_name=self.dataset_name, - # task_name=task, - # 
input_processors=input_processors, - # output_processors=output_processors, - # ) - - # logger.info(f"Generated {len(samples)} samples for task {task.task_name}") - # return sample_dataset + return SampleDataset( + path=str(path), + dataset_name=self.dataset_name, + task_name=task.task_name, + ) diff --git a/pyhealth/datasets/sample_dataset.py b/pyhealth/datasets/sample_dataset.py index 978489c27..626910490 100644 --- a/pyhealth/datasets/sample_dataset.py +++ b/pyhealth/datasets/sample_dataset.py @@ -150,6 +150,8 @@ def save(self, path: str) -> None: if not self._fitted: raise RuntimeError("SampleBuilder.fit must be called before save().") metadata = { + "input_schema": self.input_schema, + "output_schema": self.output_schema, "input_processors": self._input_processors, "output_processors": self._output_processors, "patient_to_index": self._patient_to_index, @@ -175,10 +177,6 @@ class SampleDataset(StreamingDataset): def __init__( self, path: str, - input_schema: Dict[str, Any], - output_schema: Dict[str, Any], - input_processors: Optional[Dict[str, FeatureProcessor]] = None, - output_processors: Optional[Dict[str, FeatureProcessor]] = None, dataset_name: Optional[str] = None, task_name: Optional[str] = None, **kwargs, @@ -205,21 +203,20 @@ def __init__( Defaults to None. """ super().__init__(path, **kwargs) - if dataset_name is None: - dataset_name = "" - if task_name is None: - task_name = "" self.dataset_name = "" if dataset_name is None else dataset_name self.task_name = "" if task_name is None else task_name - self.input_schema = input_schema - self.output_schema = output_schema - self.input_processors = input_processors - self.output_processors = output_processors + with open(f"{path}/schema.pkl", "rb") as f: + metadata = pickle.load(f) + + self.input_schema = metadata["input_schema"] + self.output_schema = metadata["output_schema"] + self.input_processors = metadata["input_processors"] + self.output_processors = metadata["output_processors"] - self.patient_to_index = {} - self.record_to_index = {} + self.patient_to_index = metadata["patient_to_index"] + self.record_to_index = metadata["record_to_index"] @override def __str__(self) -> str: From 95784246d742a3275e932e0744691ec83d9477c5 Mon Sep 17 00:00:00 2001 From: Yongda Fan Date: Sat, 6 Dec 2025 11:48:44 -0500 Subject: [PATCH 22/82] Fix multi-worker crashes --- pyhealth/datasets/base_dataset.py | 12 +++++++----- tests/core/test_streaming_parquet_writer.py | 4 ++-- 2 files changed, 9 insertions(+), 7 deletions(-) diff --git a/pyhealth/datasets/base_dataset.py b/pyhealth/datasets/base_dataset.py index fad3a8e72..5937e94fe 100644 --- a/pyhealth/datasets/base_dataset.py +++ b/pyhealth/datasets/base_dataset.py @@ -107,7 +107,11 @@ def scan_file(file_path: str) -> pl.LazyFrame: raise FileNotFoundError(f"Neither path exists: {path} or {alt_path}") -class StreamingParquetWriter: +def _identity_fn(x): + return x + + +class _ParquetWriter: """ Stream-write rows into a Parquet file in chunked (row-group) fashion. @@ -526,9 +530,7 @@ def set_task( with tempfile.TemporaryDirectory() as tmp_dir: # Create Parquet file with samples logger.info(f"Applying task transformations on data...") - with StreamingParquetWriter( - f"{tmp_dir}/samples.parquet", schema - ) as writer: + with _ParquetWriter(f"{tmp_dir}/samples.parquet", schema) as writer: # TODO: this can be further optimized. 
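                # Current approach: collect the unique patient ids once, then
                # for each id stream-filter the global event frame, wrap the
                # rows in a Patient, and pickle every task sample into the
                # temporary parquet file via the writer above.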
patient_ids = ( global_event_df.select("patient_id") @@ -569,7 +571,7 @@ def set_task( transform=builder.transform, ) litdata.optimize( - fn=lambda x: x, + fn=_identity_fn, inputs=litdata.StreamingDataLoader(dataset), output_dir=str(path), chunk_bytes="64MB", diff --git a/tests/core/test_streaming_parquet_writer.py b/tests/core/test_streaming_parquet_writer.py index bda6f15e8..c7c212bd5 100644 --- a/tests/core/test_streaming_parquet_writer.py +++ b/tests/core/test_streaming_parquet_writer.py @@ -5,7 +5,7 @@ import pyarrow as pa import pyarrow.parquet as pq -from pyhealth.datasets.base_dataset import StreamingParquetWriter +from pyhealth.datasets.base_dataset import _ParquetWriter from tests.base import BaseTestCase @@ -31,7 +31,7 @@ def test_append_flush_close_and_context_manager(self): {"id": 4, "value": "d"}, ] - with StreamingParquetWriter( + with _ParquetWriter( self.output_path, self.schema, chunk_size=2 ) as writer: # First two appends trigger an automatic flush due to chunk_size=2. From 830259a2f78470cbeb8890199c6e1e256e336931 Mon Sep 17 00:00:00 2001 From: Yongda Fan Date: Sat, 6 Dec 2025 11:55:19 -0500 Subject: [PATCH 23/82] update test --- tests/core/test_caching.py | 243 +++++++++++-------------------------- 1 file changed, 72 insertions(+), 171 deletions(-) diff --git a/tests/core/test_caching.py b/tests/core/test_caching.py index b9b355bce..2ae9bfe3e 100644 --- a/tests/core/test_caching.py +++ b/tests/core/test_caching.py @@ -1,15 +1,15 @@ import unittest import tempfile -import os +import shutil from pathlib import Path from unittest.mock import patch import polars as pl +import torch from tests.base import BaseTestCase from pyhealth.datasets.base_dataset import BaseDataset from pyhealth.tasks.base_task import BaseTask from pyhealth.datasets.sample_dataset import SampleDataset -from pyhealth.data import Patient class MockTask(BaseTask): @@ -19,11 +19,13 @@ def __init__(self, task_name="test_task"): self.task_name = task_name self.input_schema = {"test_attribute": "raw"} self.output_schema = {"test_label": "binary"} + self.call_count = 0 def __call__(self, patient): """Return mock samples based on patient data.""" # Extract patient's test data from the patient's data source patient_data = patient.data_source + self.call_count += 1 samples = [] for row in patient_data.iter_rows(named=True): @@ -40,13 +42,17 @@ def __call__(self, patient): class MockDataset(BaseDataset): """Mock dataset for testing purposes.""" - def __init__(self): - # Initialize without calling parent __init__ to avoid file dependencies - self.dataset_name = "TestDataset" - self.dev = False + def __init__(self, cache_dir: str | Path | None = None): + super().__init__( + root="", + tables=[], + dataset_name="TestDataset", + cache_dir=cache_dir, + dev=False, + ) - # Create realistic test data with patient_id, test_attribute, and test_label - self._collected_global_event_df = pl.DataFrame( + def load_data(self) -> pl.LazyFrame: + return pl.LazyFrame( { "patient_id": ["1", "2", "1", "2"], "event_type": ["test", "test", "test", "test"], @@ -60,25 +66,6 @@ def __init__(self): "test/test_label": [0, 1, 1, 0], } ) - self._unique_patient_ids = ["1", "2"] - - @property - def collected_global_event_df(self): - return self._collected_global_event_df - - @property - def unique_patient_ids(self): - return self._unique_patient_ids - - def iter_patients(self, df=None): - """Mock patient iterator that returns real Patient objects.""" - if df is None: - df = self.collected_global_event_df - - grouped = df.group_by("patient_id") 
- for patient_id, patient_df in grouped: - patient_id = patient_id[0] - yield Patient(patient_id=patient_id, data_source=patient_df) class TestCachingFunctionality(BaseTestCase): @@ -86,17 +73,19 @@ class TestCachingFunctionality(BaseTestCase): def setUp(self): """Set up test fixtures.""" - self.dataset = MockDataset() + self.temp_dir = Path(tempfile.mkdtemp()) + self.dataset = MockDataset(cache_dir=self.temp_dir) self.task = MockTask() - self.temp_dir = tempfile.mkdtemp() def tearDown(self): """Clean up test fixtures.""" - # Clean up temporary directory - import shutil - shutil.rmtree(self.temp_dir, ignore_errors=True) + def _task_cache_dir(self) -> Path: + cache_dir = self.temp_dir / "task_cache" + cache_dir.mkdir(parents=True, exist_ok=True) + return cache_dir + def test_set_task_signature(self): """Test that set_task has the correct method signature.""" import inspect @@ -104,7 +93,15 @@ def test_set_task_signature(self): sig = inspect.signature(BaseDataset.set_task) params = list(sig.parameters.keys()) - expected_params = ["self", "task", "num_workers", "cache_dir", "cache_format", "input_processors", "output_processors"] + expected_params = [ + "self", + "task", + "num_workers", + "cache_dir", + "cache_format", + "input_processors", + "output_processors", + ] self.assertEqual(params, expected_params) # Check default values @@ -115,158 +112,62 @@ def test_set_task_signature(self): self.assertEqual(sig.parameters["input_processors"].default, None) self.assertEqual(sig.parameters["output_processors"].default, None) - def test_set_task_no_caching(self): - """Test set_task without caching (cache_dir=None).""" - sample_dataset = self.dataset.set_task(self.task) + def test_set_task_writes_cache_and_metadata(self): + """Ensure set_task materializes cache files and schema metadata.""" + cache_dir = self._task_cache_dir() + sample_dataset = self.dataset.set_task( + self.task, cache_dir=cache_dir, cache_format="parquet" + ) self.assertIsInstance(sample_dataset, SampleDataset) - self.assertEqual(len(sample_dataset), 4) # Two patients, two samples each self.assertEqual(sample_dataset.dataset_name, "TestDataset") + self.assertEqual(sample_dataset.task_name, self.task.task_name) + self.assertEqual(len(sample_dataset), 4) + self.assertEqual(self.task.call_count, 2) + + # Cache artifacts should be present for StreamingDataset + self.assertTrue((cache_dir / "index.json").exists()) + self.assertTrue((cache_dir / "schema.pkl").exists()) - # Check that samples have the correct structure + # Check processed sample structure and metadata persisted sample = sample_dataset[0] self.assertIn("test_attribute", sample) self.assertIn("test_label", sample) self.assertIn("patient_id", sample) - - def test_full_parquet_caching_cycle(self): - """Test complete save and load cycle with parquet caching.""" - cache_path = Path(self.temp_dir) / f"{self.task.task_name}.parquet" - - # Step 1: First call - should generate samples and save to cache - self.assertFalse(cache_path.exists(), "Cache file should not exist initially") - - sample_dataset_1 = self.dataset.set_task( - self.task, cache_dir=self.temp_dir, cache_format="parquet" - ) - - # Verify cache file was created + self.assertIsInstance(sample["test_label"], torch.Tensor) + self.assertIn("test_attribute", sample_dataset.input_processors) + self.assertIn("test_label", sample_dataset.output_processors) + self.assertEqual(set(sample_dataset.patient_to_index), {"1", "2"}) self.assertTrue( - cache_path.exists(), "Cache file should be created after first call" - ) - - # 
Verify the sample dataset is correct - self.assertIsInstance(sample_dataset_1, SampleDataset) - self.assertEqual( - len(sample_dataset_1), 4 - ) # Should have 4 samples from our mock data - - # Step 2: Second call - should load from cache (not regenerate) - sample_dataset_2 = self.dataset.set_task( - self.task, cache_dir=self.temp_dir, cache_format="parquet" - ) - - # Verify the loaded dataset matches the original - self.assertIsInstance(sample_dataset_2, SampleDataset) - self.assertEqual(len(sample_dataset_2), 4) - - # Step 3: Verify the actual cached data is correct - # Load the parquet file directly to check its contents - cached_df = pl.read_parquet(cache_path) - cached_samples = cached_df.to_dicts() - - self.assertEqual(len(cached_samples), 4) - - # Verify sample content matches expected structure - for sample in cached_samples: - self.assertIn("test_attribute", sample) - self.assertIn("test_label", sample) - self.assertIn("patient_id", sample) - self.assertIn(sample["patient_id"], ["1", "2"]) - self.assertIn(sample["test_label"], [0, 1]) - - def test_full_pickle_caching_cycle(self): - """Test complete save and load cycle with pickle caching.""" - cache_path = Path(self.temp_dir) / f"{self.task.task_name}.pickle" - - # Step 1: First call - should generate samples and save to cache - self.assertFalse(cache_path.exists(), "Cache file should not exist initially") - - sample_dataset_1 = self.dataset.set_task( - self.task, cache_dir=self.temp_dir, cache_format="pickle" + all(len(indexes) == 2 for indexes in sample_dataset.patient_to_index.values()) ) + self.assertEqual(sample_dataset.record_to_index, {}) - # Verify cache file was created - self.assertTrue( - cache_path.exists(), "Cache file should be created after first call" - ) - - # Verify the sample dataset is correct - self.assertIsInstance(sample_dataset_1, SampleDataset) - self.assertEqual( - len(sample_dataset_1), 4 - ) # Should have 4 samples from our mock data - - # Step 2: Second call - should load from cache (not regenerate) - sample_dataset_2 = self.dataset.set_task( - self.task, cache_dir=self.temp_dir, cache_format="pickle" - ) - - # Verify the loaded dataset matches the original - self.assertIsInstance(sample_dataset_2, SampleDataset) - self.assertEqual(len(sample_dataset_2), 4) - - # Step 3: Verify the actual cached data is correct - # Load the pickle file directly to check its contents - import pickle - - with open(cache_path, "rb") as f: - cached_samples = pickle.load(f) - - self.assertEqual(len(cached_samples), 4) - - # Verify sample content matches expected structure - for sample in cached_samples: - self.assertIn("test_attribute", sample) - self.assertIn("test_label", sample) - self.assertIn("patient_id", sample) - self.assertIn(sample["patient_id"], ["1", "2"]) - self.assertIn(sample["test_label"], [0, 1]) - - def test_set_task_invalid_cache_format(self): - """Test set_task with invalid cache format.""" - # This should not raise an error during set_task call, - # but should log a warning when trying to save - sample_dataset = self.dataset.set_task( - self.task, cache_dir=self.temp_dir, cache_format="invalid_format" - ) - - self.assertIsInstance(sample_dataset, SampleDataset) - self.assertEqual(len(sample_dataset), 4) # Generated samples - - @patch("polars.read_parquet") - def test_set_task_cache_load_failure_fallback(self, mock_read_parquet): - """Test fallback to generation when cache loading fails.""" - # Make read_parquet raise an exception - mock_read_parquet.side_effect = Exception("Failed to read cache") - - # 
Create a dummy cache file - cache_path = Path(self.temp_dir) / f"{self.task.task_name}.parquet" - cache_path.touch() - - sample_dataset = self.dataset.set_task( - self.task, cache_dir=self.temp_dir, cache_format="parquet" - ) - - # Should still work by falling back to generation - self.assertIsInstance(sample_dataset, SampleDataset) - self.assertEqual(len(sample_dataset), 4) # Generated samples - - def test_cache_directory_creation(self): - """Test that cache directory is created if it doesn't exist.""" - nested_cache_dir = os.path.join(self.temp_dir, "nested", "cache", "dir") - - # Ensure the nested directory doesn't exist - self.assertFalse(os.path.exists(nested_cache_dir)) + def test_default_cache_dir_is_used(self): + """When cache_dir is omitted, default cache dir should be used.""" + task_cache = self.dataset.cache_dir / "tasks" / self.task.task_name + sample_dataset = self.dataset.set_task(self.task) - with patch("polars.DataFrame.write_parquet"): - sample_dataset = self.dataset.set_task( - self.task, cache_dir=nested_cache_dir, cache_format="parquet" + self.assertTrue(task_cache.exists()) + self.assertTrue((task_cache / "index.json").exists()) + self.assertTrue((self.dataset.cache_dir / "global_event_df.parquet").exists()) + self.assertEqual(len(sample_dataset), 4) + + def test_reuses_existing_cache_without_regeneration(self): + """Second call should reuse cached samples instead of recomputing.""" + cache_dir = self._task_cache_dir() + _ = self.dataset.set_task(self.task, cache_dir=cache_dir) + self.assertEqual(self.task.call_count, 2) + + with patch.object( + self.task, "__call__", side_effect=AssertionError("Task should not rerun") + ): + cached_dataset = self.dataset.set_task( + self.task, cache_dir=cache_dir, cache_format="parquet" ) - # Directory should be created - self.assertTrue(os.path.exists(nested_cache_dir)) - self.assertIsInstance(sample_dataset, SampleDataset) + self.assertEqual(len(cached_dataset), 4) + self.assertEqual(self.task.call_count, 2) if __name__ == "__main__": From 3652e2d061cccf730d5d7fcf146b3361dc764b7f Mon Sep 17 00:00:00 2001 From: Yongda Fan Date: Sat, 6 Dec 2025 12:02:15 -0500 Subject: [PATCH 24/82] Fix non-pickable --- pyhealth/processors/image_processor.py | 10 +++++++++- 1 file changed, 9 insertions(+), 1 deletion(-) diff --git a/pyhealth/processors/image_processor.py b/pyhealth/processors/image_processor.py index 17c6d5e97..d174529a7 100644 --- a/pyhealth/processors/image_processor.py +++ b/pyhealth/processors/image_processor.py @@ -1,3 +1,4 @@ +from functools import partial from pathlib import Path from typing import Any, List, Optional, Union @@ -59,7 +60,9 @@ def __init__( def _build_transform(self) -> transforms.Compose: transform_list = [] if self.mode is not None: - transform_list.append(transforms.Lambda(lambda img: img.convert(self.mode))) + transform_list.append( + transforms.Lambda(partial(_convert_mode, mode=self.mode)) + ) if self.image_size is not None: transform_list.append( transforms.Resize((self.image_size, self.image_size)) @@ -98,3 +101,8 @@ def __repr__(self) -> str: f"to_tensor={self.to_tensor}, normalize={self.normalize}, " f"mean={self.mean}, std={self.std}, mode={self.mode})" ) + + +def _convert_mode(img: Image.Image, mode: str) -> Image.Image: + """Convert a PIL image to the requested mode.""" + return img.convert(mode) From 174d53f6d9df75fe4e21a9c31a8b3db9c4c82bdd Mon Sep 17 00:00:00 2001 From: Yongda Fan Date: Sat, 6 Dec 2025 12:49:22 -0500 Subject: [PATCH 25/82] Fix get_dataloader --- pyhealth/datasets/utils.py | 5 
+++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/pyhealth/datasets/utils.py b/pyhealth/datasets/utils.py index 63ca4152a..af7babe19 100644 --- a/pyhealth/datasets/utils.py +++ b/pyhealth/datasets/utils.py @@ -5,6 +5,7 @@ from typing import Any, Dict, List, Optional, Tuple import torch +import litdata from dateutil.parser import parse as dateutil_parse from torch.nn.utils.rnn import pad_sequence from torch.utils.data import DataLoader @@ -319,7 +320,7 @@ def collate_fn_dict_with_padding(batch: List[dict]) -> dict: def get_dataloader( - dataset: torch.utils.data.Dataset, batch_size: int, shuffle: bool = False + dataset: litdata.StreamingDataset, batch_size: int, shuffle: bool = False ) -> DataLoader: """Creates a DataLoader for a given dataset. @@ -331,10 +332,10 @@ def get_dataloader( Returns: A DataLoader instance for the dataset. """ + dataset.set_shuffle(shuffle) dataloader = DataLoader( dataset, batch_size=batch_size, - shuffle=shuffle, collate_fn=collate_fn_dict_with_padding, ) From 302cb22e47323e7ce6d54b4a2f81327daf9c6f4e Mon Sep 17 00:00:00 2001 From: Yongda Fan Date: Sat, 6 Dec 2025 12:49:28 -0500 Subject: [PATCH 26/82] Fix embedding --- pyhealth/models/embedding.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pyhealth/models/embedding.py b/pyhealth/models/embedding.py index 088282a66..f5dbc937c 100644 --- a/pyhealth/models/embedding.py +++ b/pyhealth/models/embedding.py @@ -103,7 +103,7 @@ def __init__(self, dataset: SampleDataset, embedding_dim: int = 128): # For tensor processor, we need to determine the input size # from the first sample in the dataset sample_tensor = None - for sample in dataset.samples: + for sample in dataset: if field_name in sample: sample_tensor = processor.process(sample[field_name]) break From b754b4f16e73ed12007c7260a72404083b72acb0 Mon Sep 17 00:00:00 2001 From: Yongda Fan Date: Sat, 6 Dec 2025 13:07:53 -0500 Subject: [PATCH 27/82] support split --- examples/memtest.py | 102 ++++++++++++++-------------- pyhealth/datasets/sample_dataset.py | 70 ++++++++++++++++++- pyhealth/datasets/splitter.py | 42 ++++++------ 3 files changed, 140 insertions(+), 74 deletions(-) diff --git a/examples/memtest.py b/examples/memtest.py index 4403444e9..ed8b96e8f 100644 --- a/examples/memtest.py +++ b/examples/memtest.py @@ -50,56 +50,56 @@ print(f" Labs shape: {len(sample['labs'][0])} timesteps") print(f" Mortality: {sample['mortality']}") -# # STEP 3: Split dataset -# train_dataset, val_dataset, test_dataset = split_by_patient( -# sample_dataset, [0.8, 0.1, 0.1] -# ) - -# # Create dataloaders -# train_loader = get_dataloader(train_dataset, batch_size=256, shuffle=True) -# val_loader = get_dataloader(val_dataset, batch_size=256, shuffle=False) -# test_loader = get_dataloader(test_dataset, batch_size=256, shuffle=False) - -# # STEP 4: Initialize StageNet model -# model = StageNet( -# dataset=sample_dataset, -# embedding_dim=128, -# chunk_size=128, -# levels=3, -# dropout=0.3, -# ) - -# num_params = sum(p.numel() for p in model.parameters()) -# print(f"\nModel initialized with {num_params} parameters") - -# # STEP 5: Train the model -# trainer = Trainer( -# model=model, -# device="cuda:5", # or "cpu" -# metrics=["pr_auc", "roc_auc", "accuracy", "f1"], -# ) - -# trainer.train( -# train_dataloader=train_loader, -# val_dataloader=val_loader, -# epochs=50, -# monitor="roc_auc", -# optimizer_params={"lr": 1e-5}, -# ) - -# # STEP 6: Evaluate on test set -# results = trainer.evaluate(test_loader) -# print("\nTest Results:") -# for metric, 
value in results.items(): -# print(f" {metric}: {value:.4f}") - -# # STEP 7: Inspect model predictions -# sample_batch = next(iter(test_loader)) -# with torch.no_grad(): -# output = model(**sample_batch) - -# print("\nSample predictions:") -# print(f" Predicted probabilities: {output['y_prob'][:5]}") -# print(f" True labels: {output['y_true'][:5]}") +# %% STEP 3: Split dataset +train_dataset, val_dataset, test_dataset = split_by_patient( + sample_dataset, [0.8, 0.1, 0.1] +) + +# Create dataloaders +train_loader = get_dataloader(train_dataset, batch_size=256, shuffle=True) +val_loader = get_dataloader(val_dataset, batch_size=256, shuffle=False) +test_loader = get_dataloader(test_dataset, batch_size=256, shuffle=False) + +# %% STEP 4: Initialize StageNet model +model = StageNet( + dataset=sample_dataset, + embedding_dim=128, + chunk_size=128, + levels=3, + dropout=0.3, +) + +num_params = sum(p.numel() for p in model.parameters()) +print(f"\nModel initialized with {num_params} parameters") + +# %% STEP 5: Train the model +trainer = Trainer( + model=model, + device="cpu", # or "cpu" + metrics=["pr_auc", "roc_auc", "accuracy", "f1"], +) + +trainer.train( + train_dataloader=train_loader, + val_dataloader=val_loader, + epochs=50, + monitor="roc_auc", + optimizer_params={"lr": 1e-5}, +) + +# %% STEP 6: Evaluate on test set +results = trainer.evaluate(test_loader) +print("\nTest Results:") +for metric, value in results.items(): + print(f" {metric}: {value:.4f}") + +# %% STEP 7: Inspect model predictions +sample_batch = next(iter(test_loader)) +with torch.no_grad(): + output = model(**sample_batch) + +print("\nSample predictions:") +print(f" Predicted probabilities: {output['y_prob'][:5]}") +print(f" True labels: {output['y_true'][:5]}") # %% diff --git a/pyhealth/datasets/sample_dataset.py b/pyhealth/datasets/sample_dataset.py index 626910490..1be1cc99f 100644 --- a/pyhealth/datasets/sample_dataset.py +++ b/pyhealth/datasets/sample_dataset.py @@ -1,10 +1,11 @@ +from collections.abc import Sequence import pickle from typing import Any, Dict, Iterable, List, Optional, Tuple, Union, Type, override import inspect -from torch.utils.data import Dataset -from tqdm import tqdm +from bisect import bisect_right from litdata import StreamingDataset +from litdata.utilities.train_test_split import deepcopy_dataset from ..processors import get_processor from ..processors.base_processor import FeatureProcessor @@ -226,3 +227,68 @@ def __str__(self) -> str: str: A string with the dataset and task names. """ return f"Sample dataset {self.dataset_name} {self.task_name}" + + def subset( + self, indices: Sequence[int] + ) -> "SampleDataset": + """Create a StreamingDataset restricted to the provided indices.""" + + new_dataset = deepcopy_dataset(self) + + if len(new_dataset.subsampled_files) != len(new_dataset.region_of_interest): + raise ValueError( + "The provided dataset has mismatched subsampled_files and region_of_interest lengths." + ) + + dataset_length = sum(end - start for start, end in new_dataset.region_of_interest) + if any(idx < 0 or idx >= dataset_length for idx in indices): + raise ValueError( + f"Subset indices must be in [0, {dataset_length - 1}] for the provided dataset." + ) + + # Build chunk boundaries so we can translate global indices into + # chunk-local (start, end) pairs that litdata understands. 
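        # Worked illustration (hypothetical sizes): with two chunks holding 3
        # and 4 samples, chunk_starts becomes [0, 3]; a global index 5 lands in
        # the second chunk (bisect_right([0, 3], 5) - 1 == 1) at local offset
        # 5 - 3 = 2, which maps to roi_start + 2 inside that chunk.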
+ chunk_starts: List[int] = [] + chunk_boundaries: List[Tuple[str, int, int, int, int]] = [] + cursor = 0 + for filename, (roi_start, roi_end) in zip( + new_dataset.subsampled_files, new_dataset.region_of_interest + ): + chunk_len = roi_end - roi_start + if chunk_len <= 0: + continue + chunk_starts.append(cursor) + chunk_boundaries.append( + (filename, roi_start, roi_end, cursor, cursor + chunk_len) + ) + cursor += chunk_len + + new_subsampled_files: List[str] = [] + new_roi: List[Tuple[int, int]] = [] + prev_chunk_idx: Optional[int] = None + + for idx in indices: + chunk_idx = bisect_right(chunk_starts, idx) - 1 + if chunk_idx < 0 or idx >= chunk_boundaries[chunk_idx][4]: + raise ValueError(f"Index {idx} is out of bounds for the dataset.") + + filename, roi_start, _, global_start, _ = chunk_boundaries[chunk_idx] + offset_in_chunk = roi_start + (idx - global_start) + + if ( + new_roi + and prev_chunk_idx == chunk_idx + and offset_in_chunk == new_roi[-1][1] + ): + new_roi[-1] = (new_roi[-1][0], new_roi[-1][1] + 1) + else: + new_subsampled_files.append(filename) + new_roi.append((offset_in_chunk, offset_in_chunk + 1)) + + prev_chunk_idx = chunk_idx + + new_dataset.subsampled_files = new_subsampled_files + new_dataset.region_of_interest = new_roi + new_dataset.reset() + + return new_dataset diff --git a/pyhealth/datasets/splitter.py b/pyhealth/datasets/splitter.py index c70df5660..cbaaca7c0 100644 --- a/pyhealth/datasets/splitter.py +++ b/pyhealth/datasets/splitter.py @@ -40,9 +40,9 @@ def split_by_visit( int(len(dataset) * ratios[0]) : int(len(dataset) * (ratios[0] + ratios[1])) ] test_index = index[int(len(dataset) * (ratios[0] + ratios[1])) :] - train_dataset = torch.utils.data.Subset(dataset, train_index) - val_dataset = torch.utils.data.Subset(dataset, val_index) - test_dataset = torch.utils.data.Subset(dataset, test_index) + train_dataset = dataset.subset(train_index) # type: ignore + val_dataset = dataset.subset(val_index) # type: ignore + test_dataset = dataset.subset(test_index) # type: ignore return train_dataset, val_dataset, test_dataset @@ -82,9 +82,9 @@ def split_by_patient( ) val_index = list(chain(*[dataset.patient_to_index[i] for i in val_patient_indx])) test_index = list(chain(*[dataset.patient_to_index[i] for i in test_patient_indx])) - train_dataset = torch.utils.data.Subset(dataset, train_index) - val_dataset = torch.utils.data.Subset(dataset, val_index) - test_dataset = torch.utils.data.Subset(dataset, test_index) + train_dataset = dataset.subset(train_index) # type: ignore + val_dataset = dataset.subset(val_index) # type: ignore + test_dataset = dataset.subset(test_index) # type: ignore return train_dataset, val_dataset, test_dataset @@ -119,9 +119,9 @@ def split_by_sample( int(len(dataset) * ratios[0]) : int(len(dataset) * (ratios[0] + ratios[1])) ] test_index = index[int(len(dataset) * (ratios[0] + ratios[1])) :] - train_dataset = torch.utils.data.Subset(dataset, train_index) - val_dataset = torch.utils.data.Subset(dataset, val_index) - test_dataset = torch.utils.data.Subset(dataset, test_index) + train_dataset = dataset.subset(train_index) # type: ignore + val_dataset = dataset.subset(val_index) # type: ignore + test_dataset = dataset.subset(test_index) # type: ignore if get_index: return ( @@ -172,10 +172,10 @@ def split_by_visit_conformal( cal_index = index[val_end:cal_end] test_index = index[cal_end:] - train_dataset = torch.utils.data.Subset(dataset, train_index) - val_dataset = torch.utils.data.Subset(dataset, val_index) - cal_dataset = 
torch.utils.data.Subset(dataset, cal_index) - test_dataset = torch.utils.data.Subset(dataset, test_index) + train_dataset = dataset.subset(train_index) # type: ignore + val_dataset = dataset.subset(val_index) # type: ignore + cal_dataset = dataset.subset(cal_index) # type: ignore + test_dataset = dataset.subset(test_index) # type: ignore return train_dataset, val_dataset, cal_dataset, test_dataset @@ -227,10 +227,10 @@ def split_by_patient_conformal( cal_index = list(chain(*[dataset.patient_to_index[i] for i in cal_patient_indx])) test_index = list(chain(*[dataset.patient_to_index[i] for i in test_patient_indx])) - train_dataset = torch.utils.data.Subset(dataset, train_index) - val_dataset = torch.utils.data.Subset(dataset, val_index) - cal_dataset = torch.utils.data.Subset(dataset, cal_index) - test_dataset = torch.utils.data.Subset(dataset, test_index) + train_dataset = dataset.subset(train_index) # type: ignore + val_dataset = dataset.subset(val_index) # type: ignore + cal_dataset = dataset.subset(cal_index) # type: ignore + test_dataset = dataset.subset(test_index) # type: ignore return train_dataset, val_dataset, cal_dataset, test_dataset @@ -285,8 +285,8 @@ def split_by_sample_conformal( torch.tensor(test_index), ) else: - train_dataset = torch.utils.data.Subset(dataset, train_index) - val_dataset = torch.utils.data.Subset(dataset, val_index) - cal_dataset = torch.utils.data.Subset(dataset, cal_index) - test_dataset = torch.utils.data.Subset(dataset, test_index) + train_dataset = dataset.subset(train_index) # type: ignore + val_dataset = dataset.subset(val_index) # type: ignore + cal_dataset = dataset.subset(cal_index) # type: ignore + test_dataset = dataset.subset(test_index) # type: ignore return train_dataset, val_dataset, cal_dataset, test_dataset From 34ba3f250e4c28722cd51b0117c9c8bd8b0eea7f Mon Sep 17 00:00:00 2001 From: Yongda Fan Date: Sat, 6 Dec 2025 13:35:45 -0500 Subject: [PATCH 28/82] Fix test --- tests/core/test_sample_builder.py | 43 ++++++++++++++++++++++++++++--- 1 file changed, 40 insertions(+), 3 deletions(-) diff --git a/tests/core/test_sample_builder.py b/tests/core/test_sample_builder.py index 33cd0bbca..446320ccf 100644 --- a/tests/core/test_sample_builder.py +++ b/tests/core/test_sample_builder.py @@ -1,3 +1,6 @@ +import os +import pickle +import tempfile import unittest from pyhealth.datasets.sample_dataset import SampleBuilder @@ -21,7 +24,7 @@ def test_fit_and_transform(self): with self.assertRaises(RuntimeError): _ = builder.input_processors # Access before fit should fail - builder.fit(iter(self.samples)) + builder.fit(self.samples) self.assertIn("feature", builder.input_processors) self.assertIn("label", builder.output_processors) @@ -29,7 +32,7 @@ def test_fit_and_transform(self): self.assertEqual(builder.patient_to_index["p1"], [0, 1]) self.assertEqual(builder.record_to_index["r3"], [2]) - transformed = builder.transform(self.samples[0]) + transformed = builder.transform({"sample": pickle.dumps(self.samples[0])}) self.assertEqual(transformed["feature"], "a") self.assertEqual(transformed["label"], 1) self.assertEqual(transformed["patient_id"], "p1") @@ -39,7 +42,41 @@ def test_transform_requires_fit(self): input_schema=self.input_schema, output_schema=self.output_schema ) with self.assertRaises(RuntimeError): - builder.transform(self.samples[0]) + builder.transform({"sample": pickle.dumps(self.samples[0])}) + + def test_index_mappings(self): + builder = SampleBuilder( + input_schema=self.input_schema, output_schema=self.output_schema + ) + 
builder.fit(self.samples) + + expected_patient = {"p1": [0, 1], "p2": [2]} + expected_record = {"r1": [0], "r2": [1], "r3": [2]} + self.assertEqual(builder.patient_to_index, expected_patient) + self.assertEqual(builder.record_to_index, expected_record) + + def test_save_persists_fitted_state(self): + builder = SampleBuilder( + input_schema=self.input_schema, output_schema=self.output_schema + ) + builder.fit(self.samples) + + with tempfile.TemporaryDirectory() as tmpdir: + path = os.path.join(tmpdir, "schema.pkl") + builder.save(path) + + with open(path, "rb") as f: + metadata = pickle.load(f) + + self.assertEqual(metadata["input_schema"], self.input_schema) + self.assertEqual(metadata["output_schema"], self.output_schema) + self.assertEqual(metadata["patient_to_index"], builder.patient_to_index) + self.assertEqual(metadata["record_to_index"], builder.record_to_index) + + feature_processor = metadata["input_processors"]["feature"] + label_processor = metadata["output_processors"]["label"] + self.assertEqual(feature_processor.process("foo"), "foo") + self.assertEqual(label_processor.process(1), 1) if __name__ == "__main__": From 340902bdcfbd662aaa200994ed3264355e4feb45 Mon Sep 17 00:00:00 2001 From: Yongda Fan Date: Sat, 6 Dec 2025 15:06:14 -0500 Subject: [PATCH 29/82] Fix collate_fn --- pyhealth/datasets/base_dataset.py | 13 ++++++++----- 1 file changed, 8 insertions(+), 5 deletions(-) diff --git a/pyhealth/datasets/base_dataset.py b/pyhealth/datasets/base_dataset.py index 5937e94fe..af3109009 100644 --- a/pyhealth/datasets/base_dataset.py +++ b/pyhealth/datasets/base_dataset.py @@ -107,8 +107,8 @@ def scan_file(file_path: str) -> pl.LazyFrame: raise FileNotFoundError(f"Neither path exists: {path} or {alt_path}") -def _identity_fn(x): - return x +def _uncollate(x: list[Any]) -> Any: + return x[0] if isinstance(x, list) and len(x) == 1 else x class _ParquetWriter: @@ -568,11 +568,14 @@ def set_task( dataset = litdata.StreamingDataset( tmp_dir, item_loader=ParquetLoader(), - transform=builder.transform, ) litdata.optimize( - fn=_identity_fn, - inputs=litdata.StreamingDataLoader(dataset), + fn=builder.transform, + inputs=litdata.StreamingDataLoader( + dataset, + batch_size=1, + collate_fn=_uncollate, + ), output_dir=str(path), chunk_bytes="64MB", num_workers=num_workers, From ba348c10f059381105547878ab579b68aed40999 Mon Sep 17 00:00:00 2001 From: Yongda Fan Date: Sat, 6 Dec 2025 15:07:03 -0500 Subject: [PATCH 30/82] fix conflicting cache dir --- tests/core/test_chestxray14.py | 10 +++++++--- 1 file changed, 7 insertions(+), 3 deletions(-) diff --git a/tests/core/test_chestxray14.py b/tests/core/test_chestxray14.py index 77454c1ba..ad2907cb9 100644 --- a/tests/core/test_chestxray14.py +++ b/tests/core/test_chestxray14.py @@ -7,6 +7,7 @@ import os import shutil import unittest +import tempfile import numpy as np from PIL import Image @@ -124,19 +125,22 @@ def test_task_given_invalid_disease(self): _ = ChestXray14BinaryClassification(disease="toothache") def test_task_classify_cardiomegaly(self): + cache_dir = tempfile.mkdtemp() task = ChestXray14BinaryClassification(disease="cardiomegaly") - samples = self.dataset.set_task(task) + samples = self.dataset.set_task(task, cache_dir=cache_dir) self.assertEqual(len(samples), 10) self.assertEqual(sum(sample["label"] for sample in samples), 3) def test_task_classify_hernia(self): + cache_dir = tempfile.mkdtemp() task = ChestXray14BinaryClassification(disease="hernia") - samples = self.dataset.set_task(task) + samples = self.dataset.set_task(task, 
cache_dir=cache_dir) self.assertEqual(len(samples), 10) self.assertEqual(sum(sample["label"] for sample in samples), 6) def test_task_classify_all(self): - samples = self.dataset.set_task() + cache_dir = tempfile.mkdtemp() + samples = self.dataset.set_task(cache_dir=cache_dir) self.assertEqual(len(samples), 10) actual_labels = [sample["labels"].tolist() for sample in samples] From c0fd0cc30189a73d719d84ca0ebb2396481d522d Mon Sep 17 00:00:00 2001 From: Yongda Fan Date: Sun, 7 Dec 2025 02:22:47 -0500 Subject: [PATCH 31/82] update test --- tests/core/test_support2.py | 11 +++++------ 1 file changed, 5 insertions(+), 6 deletions(-) diff --git a/tests/core/test_support2.py b/tests/core/test_support2.py index 6a21c84f6..303e89c90 100644 --- a/tests/core/test_support2.py +++ b/tests/core/test_support2.py @@ -137,11 +137,10 @@ def test_survival_preprocess_2m(self): sample_dataset = dataset.set_task(task) self.assertIsNotNone(sample_dataset) - self.assertTrue(hasattr(sample_dataset, "samples")) - self.assertEqual(len(sample_dataset.samples), 3) + self.assertEqual(len(sample_dataset), 3) # Check first sample structure - sample = sample_dataset.samples[0] + sample = next(iter(sample_dataset)) required_keys = [ "patient_id", "demographics", @@ -156,7 +155,7 @@ def test_survival_preprocess_2m(self): self.assertIn(key, sample, f"Sample should contain key: {key}") # Verify survival probabilities are in valid range [0, 1] - for s in sample_dataset.samples: + for s in sample_dataset: survival_prob = s["survival_probability"] self.assertIsInstance(survival_prob, torch.Tensor) prob_value = survival_prob.item() @@ -186,10 +185,10 @@ def test_survival_preprocess_6m(self): sample_dataset = dataset.set_task(task) self.assertIsNotNone(sample_dataset) - self.assertEqual(len(sample_dataset.samples), 3) + self.assertEqual(len(sample_dataset), 3) # Verify all samples have valid survival probabilities - for s in sample_dataset.samples: + for s in sample_dataset: survival_prob = s["survival_probability"] self.assertIsInstance(survival_prob, torch.Tensor) prob_value = survival_prob.item() From 0010409f361eaf149daf0f613005d132f4ea36f7 Mon Sep 17 00:00:00 2001 From: Yongda Fan Date: Sun, 7 Dec 2025 02:25:21 -0500 Subject: [PATCH 32/82] add create_sample_dataset to convert list of samples to SampleDataset --- pyhealth/datasets/__init__.py | 2 +- pyhealth/datasets/sample_dataset.py | 39 +++++++++++++++++++++++++++-- 2 files changed, 38 insertions(+), 3 deletions(-) diff --git a/pyhealth/datasets/__init__.py b/pyhealth/datasets/__init__.py index df2598f02..b47e28801 100644 --- a/pyhealth/datasets/__init__.py +++ b/pyhealth/datasets/__init__.py @@ -59,7 +59,7 @@ def __init__(self, *args, **kwargs): from .mimic4 import MIMIC4CXRDataset, MIMIC4Dataset, MIMIC4EHRDataset, MIMIC4NoteDataset from .mimicextract import MIMICExtractDataset from .omop import OMOPDataset -from .sample_dataset import SampleBuilder, SampleDataset +from .sample_dataset import SampleBuilder, SampleDataset, create_sample_dataset from .shhs import SHHSDataset from .sleepedf import SleepEDFDataset from .bmd_hs import BMDHSDataset diff --git a/pyhealth/datasets/sample_dataset.py b/pyhealth/datasets/sample_dataset.py index 1be1cc99f..558a19397 100644 --- a/pyhealth/datasets/sample_dataset.py +++ b/pyhealth/datasets/sample_dataset.py @@ -1,10 +1,12 @@ from collections.abc import Sequence +from pathlib import Path import pickle +import tempfile from typing import Any, Dict, Iterable, List, Optional, Tuple, Union, Type, override import inspect from bisect import 
bisect_right -from litdata import StreamingDataset +import litdata from litdata.utilities.train_test_split import deepcopy_dataset from ..processors import get_processor @@ -162,7 +164,7 @@ def save(self, path: str) -> None: pickle.dump(metadata, f) -class SampleDataset(StreamingDataset): +class SampleDataset(litdata.StreamingDataset): """Sample dataset class for handling and processing data samples. Attributes: @@ -292,3 +294,36 @@ def subset( new_dataset.reset() return new_dataset + +def create_sample_dataset( + samples: List[Dict[str, Any]], + input_schema: Dict[str, Any], + output_schema: Dict[str, Any], + dataset_name: Optional[str] = None, + task_name: Optional[str] = None, + input_processors: Optional[Dict[str, FeatureProcessor]] = None, + output_processors: Optional[Dict[str, FeatureProcessor]] = None, +): + path = Path(tempfile.mkdtemp()) + + builder = SampleBuilder( + input_schema=input_schema, # type: ignore + output_schema=output_schema, # type: ignore + input_processors=input_processors, + output_processors=output_processors, + ) + builder.fit(samples) + builder.save(str(path / "schema.pkl")) + litdata.optimize( + fn=builder.transform, + inputs=[{"sample": pickle.dumps(x)} for x in samples], + output_dir=str(path), + chunk_bytes="64MB", + num_workers=0, + ) + + return SampleDataset( + path=str(path), + dataset_name=dataset_name, + task_name=task_name, + ) \ No newline at end of file From a3798f84d8367944df8b4c2b1b79a5673243ee8f Mon Sep 17 00:00:00 2001 From: Yongda Fan Date: Sun, 7 Dec 2025 02:25:45 -0500 Subject: [PATCH 33/82] test new dataset --- examples/memtest.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/memtest.py b/examples/memtest.py index ed8b96e8f..6eb55ede0 100644 --- a/examples/memtest.py +++ b/examples/memtest.py @@ -82,7 +82,7 @@ trainer.train( train_dataloader=train_loader, val_dataloader=val_loader, - epochs=50, + epochs=5, monitor="roc_auc", optimizer_params={"lr": 1e-5}, ) From 53079e83260889a7d37b44e3d0113864093f8feb Mon Sep 17 00:00:00 2001 From: Yongda Fan Date: Sun, 7 Dec 2025 02:33:55 -0500 Subject: [PATCH 34/82] Update docs --- pyhealth/datasets/sample_dataset.py | 144 ++++++++++++++++++++++------ 1 file changed, 114 insertions(+), 30 deletions(-) diff --git a/pyhealth/datasets/sample_dataset.py b/pyhealth/datasets/sample_dataset.py index 558a19397..a14906f2b 100644 --- a/pyhealth/datasets/sample_dataset.py +++ b/pyhealth/datasets/sample_dataset.py @@ -14,7 +14,24 @@ class SampleBuilder: - """Utility to fit processors and transform samples without materializing a Dataset.""" + """Fit feature processors and transform pickled samples without materializing a dataset. + + SampleBuilder is a lightweight helper used to: + - Fit feature processors from provided `input_schema` and `output_schema` on an + iterable of raw Python sample dictionaries. + - Build mappings from patient IDs and record IDs to sample indices. + - Transform pickled sample records into processed feature dictionaries using + the fitted processors. + + Typical usage: + builder = SampleBuilder(input_schema, output_schema) + builder.fit(samples) + builder.save(path) # writes a schema.pkl metadata file + + After saving the schema, `litdata.optimize` can be used with `builder.transform` + to serialize and chunk pickled sample items into a directory that can be + loaded via SampleDataset. 
+ """ def __init__( self, @@ -96,7 +113,24 @@ def fit( self, samples: Iterable[Dict[str, Any]], ) -> None: - """Fit processors and build index mappings from an iterator of samples.""" + """Fit processors and build mapping indices from an iterator of samples. + + Args: + samples: Iterable of sample dictionaries (e.g., python dicts). Each + sample should contain keys covering both the configured + `input_schema` and `output_schema`. These samples are not + required to be pickled; `fit` operates on in-memory dicts. + + Behavior: + - Validates the samples contain all keys specified by the input + and output schemas. + - Builds `patient_to_index` and `record_to_index` mappings by + recording the sample indices associated with `patient_id` and + `record_id`/`visit_id` fields. + - Instantiates and fits input/output processors from the provided + schemas (unless pre-fitted processors were supplied to the + constructor). + """ # Validate the samples input_keys = set(self.input_schema.keys()) output_keys = set(self.output_schema.keys()) @@ -134,7 +168,21 @@ def fit( self._fitted = True def transform(self, sample: dict[str, bytes]) -> Dict[str, Any]: - """Transform a pickled sample using the fitted processors.""" + """Transform a single serialized (pickled) sample using fitted processors. + + Args: + sample: A mapping with a single key `"sample"` whose value is a + pickled Python dictionary (produced by `pickle.dumps`). The + pickled dictionary should mirror the schema that was used to + fit this builder. + + Returns: + A Python dictionary where each key is either an input or output + feature name. Values for keys present in the corresponding fitted + processors have been processed through their FeatureProcessor and + are returned as the output of that processor. Keys not covered by + the input/output processors are returned unchanged. + """ if not self._fitted: raise RuntimeError("SampleBuilder.fit must be called before transform().") @@ -149,7 +197,15 @@ def transform(self, sample: dict[str, bytes]) -> Dict[str, Any]: return transformed def save(self, path: str) -> None: - """Saves the fitted metadata to the specified path.""" + """Save fitted metadata to the given path as a pickled file. + + Args: + path: Location where the builder will write a pickled metadata file + (commonly named `schema.pkl`). The saved metadata contains + the fitted input/output schemas, processors, and index + mappings. This file is read by `SampleDataset` during + construction. + """ if not self._fitted: raise RuntimeError("SampleBuilder.fit must be called before save().") metadata = { @@ -165,16 +221,28 @@ def save(self, path: str) -> None: class SampleDataset(litdata.StreamingDataset): - """Sample dataset class for handling and processing data samples. + """A streaming dataset that loads sample metadata and processors from disk. + + SampleDataset expects the `path` directory to contain a `schema.pkl` + file created by a `SampleBuilder.save(...)` call. The `schema.pkl` must + include the fitted `input_schema`, `output_schema`, `input_processors`, + `output_processors`, `patient_to_index` and `record_to_index` mappings. Attributes: - samples (List[Dict]): List of data samples. - input_schema (Dict[str, Union[str, Type[FeatureProcessor], FeatureProcessor, Tuple[Union[str, Type[FeatureProcessor]], Dict[str, Any]]]]): - Schema for input data. Values can be string aliases, processor classes, processor instances, or tuples of (spec, kwargs_dict). 
- output_schema (Dict[str, Union[str, Type[FeatureProcessor], FeatureProcessor, Tuple[Union[str, Type[FeatureProcessor]], Dict[str, Any]]]]): - Schema for output data. Values can be string aliases, processor classes, processor instances, or tuples of (spec, kwargs_dict). - dataset_name (Optional[str]): Name of the dataset. - task_name (Optional[str]): Name of the task. + input_schema: The configuration used to instantiate processors for + input features (string aliases or processor specs). + output_schema: The configuration used to instantiate processors for + output features. + input_processors: A mapping of input feature names to fitted + FeatureProcessor instances. + output_processors: A mapping of output feature names to fitted + FeatureProcessor instances. + patient_to_index: Dictionary mapping patient IDs to the list of + sample indices associated with that patient. + record_to_index: Dictionary mapping record/visit IDs to the list of + sample indices associated with that record. + dataset_name: Optional human friendly dataset name. + task_name: Optional human friendly task name. """ def __init__( @@ -184,26 +252,15 @@ def __init__( task_name: Optional[str] = None, **kwargs, ) -> None: - """Initializes the SampleDataset with samples and schemas. + """Initialize a SampleDataset pointing at a directory created by SampleBuilder. Args: - samples (List[Dict]): List of data samples. - input_schema (Dict[str, Union[str, Type[FeatureProcessor], FeatureProcessor, Tuple[Union[str, Type[FeatureProcessor]], Dict[str, Any]]]]): - Schema for input data. Values can be string aliases, processor classes, processor instances, or tuples of (spec, kwargs_dict) for instantiation. - output_schema (Dict[str, Union[str, Type[FeatureProcessor], FeatureProcessor, Tuple[Union[str, Type[FeatureProcessor]], Dict[str, Any]]]]): - Schema for output data. Values can be string aliases, processor classes, processor instances, or tuples of (spec, kwargs_dict) for instantiation. - dataset_name (Optional[str], optional): Name of the dataset. - Defaults to None. - task_name (Optional[str], optional): Name of the task. - Defaults to None. - input_processors (Optional[Dict[str, FeatureProcessor]], - optional): Pre-fitted input processors. If provided, these - will be used instead of creating new ones from input_schema. - Defaults to None. - output_processors (Optional[Dict[str, FeatureProcessor]], - optional): Pre-fitted output processors. If provided, these - will be used instead of creating new ones from output_schema. - Defaults to None. + path: Path to a directory containing a `schema.pkl` produced by + `SampleBuilder.save` and associated pickled sample files. + dataset_name: Optional human-friendly dataset name. + task_name: Optional human-friendly task name. + **kwargs: Extra keyword arguments forwarded to + `litdata.StreamingDataset` (such as streaming options). """ super().__init__(path, **kwargs) @@ -304,6 +361,33 @@ def create_sample_dataset( input_processors: Optional[Dict[str, FeatureProcessor]] = None, output_processors: Optional[Dict[str, FeatureProcessor]] = None, ): + """Convenience helper to create an on-disk SampleDataset from in-memory samples. + + This helper will: + - Create a temporary directory for the dataset output. + - Fit a `SampleBuilder` with the provided schemas and samples. + - Save the fitted `schema.pkl` to the temporary directory. + - Use `litdata.optimize` with `builder.transform` to write serialized + and chunked sample files into the directory. 
+ - Return a `SampleDataset` instance pointed at the temporary directory. + + Args: + samples: A list of Python dictionaries representing raw samples. + input_schema: Schema describing how input keys should be handled. + output_schema: Schema describing how output keys should be handled. + dataset_name: Optional dataset name to attach to the returned + SampleDataset instance. + task_name: Optional task name to attach to the returned SampleDataset + instance. + input_processors: Optional pre-fitted input processors to use instead + of creating new ones from the input_schema. + output_processors: Optional pre-fitted output processors to use + instead of creating new ones from the output_schema. + + Returns: + An instance of `SampleDataset` loaded from the temporary directory + containing the optimized, chunked samples and `schema.pkl` metadata. + """ path = Path(tempfile.mkdtemp()) builder = SampleBuilder( From e36eae3ffca663c2dfc8782a6d4dbf1fb9f179dc Mon Sep 17 00:00:00 2001 From: Yongda Fan Date: Sun, 7 Dec 2025 02:40:34 -0500 Subject: [PATCH 35/82] Fix test --- tests/core/test_stagenet.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/tests/core/test_stagenet.py b/tests/core/test_stagenet.py index e1a491b49..b3fd4b5a2 100644 --- a/tests/core/test_stagenet.py +++ b/tests/core/test_stagenet.py @@ -1,7 +1,7 @@ import unittest import torch -from pyhealth.datasets import SampleDataset, get_dataloader +from pyhealth.datasets import create_sample_dataset, get_dataloader from pyhealth.models import StageNet @@ -61,7 +61,7 @@ def setUp(self): self.output_schema = {"label": "binary"} # Create dataset - self.dataset = SampleDataset( + self.dataset = create_sample_dataset( samples=self.samples, input_schema=self.input_schema, output_schema=self.output_schema, @@ -223,7 +223,7 @@ def test_time_handling_with_none(self): }, ] - dataset_no_time = SampleDataset( + dataset_no_time = create_sample_dataset( samples=samples_no_time, input_schema={"codes": "stagenet"}, output_schema={"label": "binary"}, @@ -311,7 +311,7 @@ def test_single_feature_input(self): }, ] - dataset_single = SampleDataset( + dataset_single = create_sample_dataset( samples=samples_single, input_schema={"codes": "stagenet"}, output_schema={"label": "binary"}, @@ -451,7 +451,7 @@ def test_processor_padding_with_model_integration(self): # Create dataset - it will use default processor, but we can verify padding works # by checking the processor's configuration - dataset = SampleDataset( + dataset = create_sample_dataset( samples=samples_padding, input_schema={"procedures": "stagenet"}, output_schema={"label": "binary"}, From 2bf734276497e0c42c72932139a05db0eb5f2648 Mon Sep 17 00:00:00 2001 From: Yongda Fan Date: Sun, 7 Dec 2025 18:10:13 -0500 Subject: [PATCH 36/82] Fix test --- pyhealth/models/cnn.py | 2 +- tests/core/test_cnn.py | 8 ++++---- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/pyhealth/models/cnn.py b/pyhealth/models/cnn.py index 74d5fd522..dc9853d08 100644 --- a/pyhealth/models/cnn.py +++ b/pyhealth/models/cnn.py @@ -262,7 +262,7 @@ def _determine_input_channels(self, feature_key: str, spatial_dim: int) -> int: return self.embedding_dim min_dim = spatial_dim + 1 - for sample in self.dataset.samples: + for sample in self.dataset: if feature_key not in sample: continue feature = self._extract_feature_tensor(sample[feature_key]) diff --git a/tests/core/test_cnn.py b/tests/core/test_cnn.py index 4bf0bfe89..4a53e5444 100644 --- a/tests/core/test_cnn.py +++ b/tests/core/test_cnn.py @@ 
-1,7 +1,7 @@ import unittest import torch -from pyhealth.datasets import SampleDataset, get_dataloader +from pyhealth.datasets import create_sample_dataset, get_dataloader from pyhealth.models import CNN @@ -33,7 +33,7 @@ def setUp(self): } self.output_schema = {"label": "binary"} - self.dataset = SampleDataset( + self.dataset = create_sample_dataset( samples=self.samples, input_schema=self.input_schema, output_schema=self.output_schema, @@ -147,7 +147,7 @@ def test_model_with_image_input(self): input_schema = {"image": "image"} output_schema = {"label": "binary"} - dataset = SampleDataset( + dataset = create_sample_dataset( samples=samples, input_schema=input_schema, output_schema=output_schema, @@ -209,7 +209,7 @@ def test_model_with_mixed_inputs(self): } output_schema = {"label": "binary"} - dataset = SampleDataset( + dataset = create_sample_dataset( samples=samples, input_schema=input_schema, output_schema=output_schema, From 091a22d1ff3c084600bf0f2b6d2d964c546ea345 Mon Sep 17 00:00:00 2001 From: Yongda Fan Date: Sun, 7 Dec 2025 18:14:57 -0500 Subject: [PATCH 37/82] Fix test --- pyhealth/models/contrawr.py | 2 +- tests/core/test_contrawr.py | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/pyhealth/models/contrawr.py b/pyhealth/models/contrawr.py index 227e8b1b5..6caa139d3 100644 --- a/pyhealth/models/contrawr.py +++ b/pyhealth/models/contrawr.py @@ -173,7 +173,7 @@ def __init__( self.fc = nn.Linear(emb_size, output_size) def _determine_input_channels_length(self) -> int: - for sample in self.dataset.samples: + for sample in self.dataset: if self.feature_keys[0] not in sample: continue diff --git a/tests/core/test_contrawr.py b/tests/core/test_contrawr.py index 94ad0f8da..25f1bb6f8 100644 --- a/tests/core/test_contrawr.py +++ b/tests/core/test_contrawr.py @@ -2,7 +2,7 @@ import numpy as np import torch -from pyhealth.datasets import SampleDataset, get_dataloader +from pyhealth.datasets import create_sample_dataset, get_dataloader from pyhealth.models import ContraWR @@ -33,7 +33,7 @@ def setUp(self): self.input_schema = {"epoch_signal": "tensor"} self.output_schema = {"label": "multiclass"} - self.dataset = SampleDataset( + self.dataset = create_sample_dataset( samples=self.samples, input_schema=self.input_schema, output_schema=self.output_schema, From f858748c7f518fa252626de7f4fbf0991d4b58a0 Mon Sep 17 00:00:00 2001 From: Yongda Fan Date: Sun, 7 Dec 2025 18:20:49 -0500 Subject: [PATCH 38/82] Fix test --- tests/core/test_early_stopping.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/tests/core/test_early_stopping.py b/tests/core/test_early_stopping.py index 16e9afe75..7a2540e2f 100644 --- a/tests/core/test_early_stopping.py +++ b/tests/core/test_early_stopping.py @@ -1,6 +1,6 @@ import unittest -from pyhealth.datasets import SampleDataset, get_dataloader +from pyhealth.datasets import create_sample_dataset, get_dataloader from pyhealth.models import MLP from pyhealth.trainer import Trainer @@ -29,14 +29,14 @@ def setUp(self): self.output_schema = {"label": "binary"} # Split into train and val - self.train_dataset = SampleDataset( + self.train_dataset = create_sample_dataset( samples=self.samples[:80], input_schema=self.input_schema, output_schema=self.output_schema, dataset_name="train", ) - self.val_dataset = SampleDataset( + self.val_dataset = create_sample_dataset( samples=self.samples[80:], input_schema=self.input_schema, output_schema=self.output_schema, From eae72125d8454868c42910969de67053f3f2175d Mon Sep 17 00:00:00 2001 From: 
Yongda Fan Date: Sun, 7 Dec 2025 18:23:19 -0500 Subject: [PATCH 39/82] Fix test --- pyhealth/models/gamenet.py | 2 +- tests/core/test_gamenet.py | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/pyhealth/models/gamenet.py b/pyhealth/models/gamenet.py index feacb59f9..dd5e27314 100644 --- a/pyhealth/models/gamenet.py +++ b/pyhealth/models/gamenet.py @@ -360,7 +360,7 @@ def generate_ehr_adj(self) -> torch.tensor: label_vocab = self.dataset.output_processors[self.label_key].label_vocab label_size = len(label_vocab) ehr_adj = torch.zeros((label_size, label_size)) - for sample in self.dataset.samples: + for sample in self.dataset: curr_drugs = sample["drugs"] if isinstance(curr_drugs, torch.Tensor): continue diff --git a/tests/core/test_gamenet.py b/tests/core/test_gamenet.py index 1bee65d8f..f456189c7 100644 --- a/tests/core/test_gamenet.py +++ b/tests/core/test_gamenet.py @@ -1,7 +1,7 @@ import unittest import torch -from pyhealth.datasets import SampleDataset, get_dataloader +from pyhealth.datasets import create_sample_dataset, get_dataloader from pyhealth.models import GAMENet @@ -38,7 +38,7 @@ def setUp(self): } self.output_schema = {"drugs": "multilabel"} - self.dataset = SampleDataset( + self.dataset = create_sample_dataset( samples=self.samples, input_schema=self.input_schema, output_schema=self.output_schema, From a4b176fbf17ce7792d997c98445af4dcd38c4321 Mon Sep 17 00:00:00 2001 From: Yongda Fan Date: Sun, 7 Dec 2025 18:32:36 -0500 Subject: [PATCH 40/82] add InMemorySampleDataset --- pyhealth/datasets/sample_dataset.py | 116 ++++++++++++++++++++++++++-- 1 file changed, 110 insertions(+), 6 deletions(-) diff --git a/pyhealth/datasets/sample_dataset.py b/pyhealth/datasets/sample_dataset.py index a14906f2b..56308febd 100644 --- a/pyhealth/datasets/sample_dataset.py +++ b/pyhealth/datasets/sample_dataset.py @@ -287,9 +287,7 @@ def __str__(self) -> str: """ return f"Sample dataset {self.dataset_name} {self.task_name}" - def subset( - self, indices: Sequence[int] - ) -> "SampleDataset": + def subset(self, indices: Sequence[int]) -> "SampleDataset": """Create a StreamingDataset restricted to the provided indices.""" new_dataset = deepcopy_dataset(self) @@ -299,7 +297,9 @@ def subset( "The provided dataset has mismatched subsampled_files and region_of_interest lengths." ) - dataset_length = sum(end - start for start, end in new_dataset.region_of_interest) + dataset_length = sum( + end - start for start, end in new_dataset.region_of_interest + ) if any(idx < 0 or idx >= dataset_length for idx in indices): raise ValueError( f"Subset indices must be in [0, {dataset_length - 1}] for the provided dataset." @@ -351,7 +351,111 @@ def subset( new_dataset.reset() return new_dataset - + + +class InMemorySampleDataset(SampleDataset): + """A SampleDataset that loads all samples into memory for fast access. + + InMemorySampleDataset extends SampleDataset by eagerly loading and + transforming all samples into memory during initialization. This allows + for fast, repeated access to samples without disk I/O, at the cost of + higher memory usage. + + Note: + This class is intended for testing and debugging purposes where + dataset sizes are small enough to fit into memory. 
+ """ + + def __init__( + self, + samples: List[Dict[str, Any]], + input_schema: Dict[str, Any], + output_schema: Dict[str, Any], + dataset_name: Optional[str] = None, + task_name: Optional[str] = None, + input_processors: Optional[Dict[str, FeatureProcessor]] = None, + output_processors: Optional[Dict[str, FeatureProcessor]] = None, + ) -> None: + """Initialize an InMemorySampleDataset from in-memory samples. + + This constructor fits a SampleBuilder on the provided samples, + transforms all samples into memory, and sets up the dataset attributes. + + Args: + samples: A list of sample dictionaries (in-memory). + input_schema: Schema describing how input keys should be handled. + output_schema: Schema describing how output keys should be handled. + dataset_name: Optional human-friendly dataset name. + task_name: Optional human-friendly task name. + input_processors: Optional pre-fitted input processors to use instead + of creating new ones from the input_schema. + output_processors: Optional pre-fitted output processors to use + instead of creating new ones from the output_schema. + """ + builder = SampleBuilder( + input_schema=input_schema, + output_schema=output_schema, + input_processors=input_processors, + output_processors=output_processors, + ) + builder.fit(samples) + + self.dataset_name = "" if dataset_name is None else dataset_name + self.task_name = "" if task_name is None else task_name + + self.input_schema = builder.input_schema + self.output_schema = builder.output_schema + self.input_processors = builder.input_processors + self.output_processors = builder.output_processors + + self.patient_to_index = builder.patient_to_index + self.record_to_index = builder.record_to_index + + self._data = [builder.transform({"sample": pickle.dumps(s)}) for s in samples] + + @override + def __len__(self) -> int: + """Returns the number of samples in the dataset. + + Returns: + int: The total number of samples. + """ + return len(self._data) + + @override + def __getitem__(self, index: int) -> Dict[str, Any]: # type: ignore + """Retrieve a processed sample by index. + + Args: + index: The index of the sample to retrieve. + + Returns: + A dictionary containing processed input and output features. + """ + return self._data[index] + + @override + def __iter__(self) -> Iterable[Dict[str, Any]]: # type: ignore + """Returns an iterator over all samples in the dataset. + + Returns: + An iterator yielding processed sample dictionaries. 
+ """ + return iter(self._data) + + @override + def subset(self, indices: Sequence[int]) -> SampleDataset: + return InMemorySampleDataset( + samples=[self._data[i] for i in indices], + input_schema=self.input_schema, + output_schema=self.output_schema, + dataset_name=self.dataset_name, + task_name=self.task_name, + input_processors=self.input_processors, + output_processors=self.output_processors, + ) + + def create_sample_dataset( samples: List[Dict[str, Any]], input_schema: Dict[str, Any], @@ -410,4 +514,4 @@ def create_sample_dataset( path=str(path), dataset_name=dataset_name, task_name=task_name, - ) \ No newline at end of file + ) From 316b09c8b47f3ef3a3309700295b288ac1ceb221 Mon Sep 17 00:00:00 2001 From: Yongda Fan Date: Sun, 7 Dec 2025 18:40:12 -0500 Subject: [PATCH 41/82] support InMemorySampleDataset --- pyhealth/datasets/sample_dataset.py | 75 ++++++++++++++-------- tests/core/test_sample_dataset.py | 97 +++++++++++++++++++++++++++++ 2 files changed, 146 insertions(+), 26 deletions(-) create mode 100644 tests/core/test_sample_dataset.py diff --git a/pyhealth/datasets/sample_dataset.py b/pyhealth/datasets/sample_dataset.py index 56308febd..84b0e34c5 100644 --- a/pyhealth/datasets/sample_dataset.py +++ b/pyhealth/datasets/sample_dataset.py @@ -287,7 +287,7 @@ def __str__(self) -> str: """ return f"Sample dataset {self.dataset_name} {self.task_name}" - def subset(self, indices: Sequence[int]) -> "SampleDataset": + def subset(self, indices: Union[Sequence[int], slice]) -> "SampleDataset": """Create a StreamingDataset restricted to the provided indices.""" new_dataset = deepcopy_dataset(self) @@ -300,6 +300,10 @@ def subset(self, indices: Sequence[int]) -> "SampleDataset": dataset_length = sum( end - start for start, end in new_dataset.region_of_interest ) + + if isinstance(indices, slice): + indices = range(*indices.indices(dataset_length)) + if any(idx < 0 or idx >= dataset_length for idx in indices): raise ValueError( f"Subset indices must be in [0, {dataset_length - 1}] for the provided dataset." @@ -444,9 +448,14 @@ def __iter__(self) -> Iterable[Dict[str, Any]]: # type: ignore return iter(self._data) @override - def subset(self, indices: Sequence[int]) -> SampleDataset: + def subset(self, indices: Union[Sequence[int], slice]) -> SampleDataset: + if isinstance(indices, slice): + samples = self._data[indices] + else: + samples = [self._data[i] for i in indices] + return InMemorySampleDataset( - samples=[self._data[i] for i in indices], + samples=samples, input_schema=self.input_schema, output_schema=self.output_schema, dataset_name=self.dataset_name, @@ -464,6 +473,7 @@ def create_sample_dataset( task_name: Optional[str] = None, input_processors: Optional[Dict[str, FeatureProcessor]] = None, output_processors: Optional[Dict[str, FeatureProcessor]] = None, + in_memory: bool = True, ): """Convenience helper to create an on-disk SampleDataset from in-memory samples. @@ -487,31 +497,44 @@ def create_sample_dataset( of creating new ones from the input_schema. output_processors: Optional pre-fitted output processors to use instead of creating new ones from the output_schema. + in_memory: If True, returns an InMemorySampleDataset instead of + a disk-backed SampleDataset. Returns: An instance of `SampleDataset` loaded from the temporary directory containing the optimized, chunked samples and `schema.pkl` metadata. 
""" - path = Path(tempfile.mkdtemp()) - - builder = SampleBuilder( - input_schema=input_schema, # type: ignore - output_schema=output_schema, # type: ignore - input_processors=input_processors, - output_processors=output_processors, - ) - builder.fit(samples) - builder.save(str(path / "schema.pkl")) - litdata.optimize( - fn=builder.transform, - inputs=[{"sample": pickle.dumps(x)} for x in samples], - output_dir=str(path), - chunk_bytes="64MB", - num_workers=0, - ) - - return SampleDataset( - path=str(path), - dataset_name=dataset_name, - task_name=task_name, - ) + if in_memory: + return InMemorySampleDataset( + samples=samples, + input_schema=input_schema, + output_schema=output_schema, + dataset_name=dataset_name, + task_name=task_name, + input_processors=input_processors, + output_processors=output_processors, + ) + else: + path = Path(tempfile.mkdtemp()) + + builder = SampleBuilder( + input_schema=input_schema, # type: ignore + output_schema=output_schema, # type: ignore + input_processors=input_processors, + output_processors=output_processors, + ) + builder.fit(samples) + builder.save(str(path / "schema.pkl")) + litdata.optimize( + fn=builder.transform, + inputs=[{"sample": pickle.dumps(x)} for x in samples], + output_dir=str(path), + chunk_bytes="64MB", + num_workers=0, + ) + + return SampleDataset( + path=str(path), + dataset_name=dataset_name, + task_name=task_name, + ) diff --git a/tests/core/test_sample_dataset.py b/tests/core/test_sample_dataset.py new file mode 100644 index 000000000..842e6a4e7 --- /dev/null +++ b/tests/core/test_sample_dataset.py @@ -0,0 +1,97 @@ +import unittest +import pickle +from pyhealth.datasets.sample_dataset import create_sample_dataset + +class TestSampleDataset(unittest.TestCase): + def setUp(self): + self.samples = [ + {"patient_id": "p1", "record_id": "r1", "feature": "a", "label": 1}, + {"patient_id": "p1", "record_id": "r2", "feature": "b", "label": 0}, + {"patient_id": "p2", "record_id": "r3", "feature": "c", "label": 1}, + {"patient_id": "p3", "record_id": "r4", "feature": "d", "label": 0}, + {"patient_id": "p3", "record_id": "r5", "feature": "e", "label": 1}, + ] + self.input_schema = {"feature": "raw"} + self.output_schema = {"label": "raw"} + + def test_sample_dataset_subset_slice(self): + # Create SampleDataset (disk-based) + dataset = create_sample_dataset( + samples=self.samples, + input_schema=self.input_schema, + output_schema=self.output_schema, + in_memory=False + ) + + # Define a slice object + s = slice(1, 4) # Slice [1:4] -> 1, 2, 3 + + subset = dataset.subset(s) + + self.assertEqual(len(subset), 3) + + # Check content + subset_data = list(subset) + self.assertEqual(subset_data[0]["feature"], "b") + self.assertEqual(subset_data[1]["feature"], "c") + self.assertEqual(subset_data[2]["feature"], "d") + + def test_in_memory_sample_dataset_behavior(self): + # Create both datasets + ds_disk = create_sample_dataset( + samples=self.samples, + input_schema=self.input_schema, + output_schema=self.output_schema, + in_memory=False + ) + + ds_mem = create_sample_dataset( + samples=self.samples, + input_schema=self.input_schema, + output_schema=self.output_schema, + in_memory=True + ) + + # 1. Test len + self.assertEqual(len(ds_disk), len(ds_mem)) + self.assertEqual(len(ds_disk), 5) + + # 2. 
Test iter + iter_disk = list(ds_disk) + iter_mem = list(ds_mem) + + for d, m in zip(iter_disk, iter_mem): + self.assertEqual(d["feature"], m["feature"]) + self.assertEqual(d["label"], m["label"]) + self.assertEqual(d["patient_id"], m["patient_id"]) + self.assertEqual(d["record_id"], m["record_id"]) + + # 3. Test getitem + for i in range(len(ds_disk)): + d = ds_disk[i] + m = ds_mem[i] + self.assertEqual(d["feature"], m["feature"]) + self.assertEqual(d["label"], m["label"]) + + # 4. Test subset with list + indices = [0, 2, 4] + sub_disk = ds_disk.subset(indices) + sub_mem = ds_mem.subset(indices) + + self.assertEqual(len(sub_disk), len(sub_mem)) + + for d, m in zip(sub_disk, sub_mem): + self.assertEqual(d["feature"], m["feature"]) + self.assertEqual(d["label"], m["label"]) + + # 5. Test subset with slice + s = slice(0, 3) + sub_disk_slice = ds_disk.subset(s) + sub_mem_slice = ds_mem.subset(s) + + self.assertEqual(len(sub_disk_slice), len(sub_mem_slice)) + for d, m in zip(sub_disk_slice, sub_mem_slice): + self.assertEqual(d["feature"], m["feature"]) + +if __name__ == "__main__": + unittest.main() From d091341bcc79b5d609580cf77064a6939e9daf7f Mon Sep 17 00:00:00 2001 From: Yongda Fan Date: Sun, 7 Dec 2025 18:43:00 -0500 Subject: [PATCH 42/82] Fix test --- tests/core/test_integrated_gradients.py | 6 +++--- tests/core/test_interp_metrics.py | 4 ++-- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/tests/core/test_integrated_gradients.py b/tests/core/test_integrated_gradients.py index 82365b235..8cbea6600 100644 --- a/tests/core/test_integrated_gradients.py +++ b/tests/core/test_integrated_gradients.py @@ -1,7 +1,7 @@ import unittest import torch -from pyhealth.datasets import SampleDataset, get_dataloader +from pyhealth.datasets import create_sample_dataset, get_dataloader from pyhealth.models import MLP, StageNet from pyhealth.interpret.methods import IntegratedGradients @@ -43,7 +43,7 @@ def setUp(self): self.output_schema = {"label": "binary"} # Create dataset - self.dataset = SampleDataset( + self.dataset = create_sample_dataset( samples=self.samples, input_schema=self.input_schema, output_schema=self.output_schema, @@ -249,7 +249,7 @@ def setUp(self): self.output_schema = {"label": "binary"} # Create dataset - self.dataset = SampleDataset( + self.dataset = create_sample_dataset( samples=self.samples, input_schema=self.input_schema, output_schema=self.output_schema, diff --git a/tests/core/test_interp_metrics.py b/tests/core/test_interp_metrics.py index 01a22184b..9c46cf4ee 100644 --- a/tests/core/test_interp_metrics.py +++ b/tests/core/test_interp_metrics.py @@ -11,7 +11,7 @@ import torch -from pyhealth.datasets import SampleDataset, get_dataloader +from pyhealth.datasets import create_sample_dataset, get_dataloader from pyhealth.interpret.methods import IntegratedGradients from pyhealth.metrics.interpretability import ( ComprehensivenessMetric, @@ -99,7 +99,7 @@ def setUp(self): self.output_schema = {"label": "binary"} # Create dataset - self.dataset = SampleDataset( + self.dataset = create_sample_dataset( samples=self.samples, input_schema=self.input_schema, output_schema=self.output_schema, From fe5a7693830797f71fdf3336fd7b5702d09717f6 Mon Sep 17 00:00:00 2001 From: Yongda Fan Date: Sun, 7 Dec 2025 18:47:19 -0500 Subject: [PATCH 43/82] Fix tests --- tests/core/test_gnn.py | 6 ++-- tests/core/test_processor_transfer.py | 42 +++++++++++++-------------- tests/core/test_safedrug.py | 4 +-- tests/core/test_transformer.py | 4 +-- 4 files changed, 28 insertions(+), 28 deletions(-) 
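The test changes below swap direct SampleDataset construction for the create_sample_dataset helper introduced earlier in this series. For reference, a minimal sketch of the new call, reusing schema aliases that already appear in these tests (the sample values are illustrative only):

    from pyhealth.datasets import create_sample_dataset

    dataset = create_sample_dataset(
        samples=[
            {"patient_id": "p1", "record_id": "r1", "conditions": ["c1", "c2"], "label": 1},
        ],
        input_schema={"conditions": "sequence"},
        output_schema={"label": "binary"},
        in_memory=True,  # default; False writes optimized chunks to a temp dir instead
    )

With in_memory=False the helper fits a SampleBuilder, saves schema.pkl, and returns a disk-backed SampleDataset built from the optimized output directory.
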
diff --git a/tests/core/test_gnn.py b/tests/core/test_gnn.py index 97cdbce5c..ef1468ca9 100644 --- a/tests/core/test_gnn.py +++ b/tests/core/test_gnn.py @@ -3,7 +3,7 @@ import unittest import torch -from pyhealth.datasets import SampleDataset, get_dataloader +from pyhealth.datasets import create_sample_dataset, get_dataloader from pyhealth.models import GCN, GAT @@ -37,7 +37,7 @@ def setUp(self): self.output_schema = {"label": "binary"} # binary classification # Create dataset - self.dataset = SampleDataset( + self.dataset = create_sample_dataset( samples=self.samples, input_schema=self.input_schema, output_schema=self.output_schema, @@ -258,7 +258,7 @@ def setUp(self): self.output_schema = {"label": "binary"} # binary classification # Create dataset - self.dataset = SampleDataset( + self.dataset = create_sample_dataset( samples=self.samples, input_schema=self.input_schema, output_schema=self.output_schema, diff --git a/tests/core/test_processor_transfer.py b/tests/core/test_processor_transfer.py index d7bb41627..07ed867b1 100644 --- a/tests/core/test_processor_transfer.py +++ b/tests/core/test_processor_transfer.py @@ -1,7 +1,7 @@ import unittest import torch -from pyhealth.datasets import SampleDataset, get_dataloader +from pyhealth.datasets import create_sample_dataset, get_dataloader from pyhealth.models import MLP @@ -63,7 +63,7 @@ def setUp(self): def test_basic_processor_transfer(self): """Test basic processor transfer from train to test dataset.""" # Create training dataset (processors will be fitted) - train_dataset = SampleDataset( + train_dataset = create_sample_dataset( samples=self.train_samples, input_schema=self.input_schema, output_schema=self.output_schema, @@ -71,7 +71,7 @@ def test_basic_processor_transfer(self): ) # Create test dataset with transferred processors - test_dataset = SampleDataset( + test_dataset = create_sample_dataset( samples=self.test_samples, input_schema=self.input_schema, output_schema=self.output_schema, @@ -97,7 +97,7 @@ def test_basic_processor_transfer(self): def test_processor_vocabulary_consistency(self): """Test that transferred processors maintain vocabulary consistency.""" # Create training dataset - train_dataset = SampleDataset( + train_dataset = create_sample_dataset( samples=self.train_samples, input_schema=self.input_schema, output_schema=self.output_schema, @@ -109,7 +109,7 @@ def test_processor_vocabulary_consistency(self): train_vocab = train_dataset.input_processors["conditions"].code_vocab # Create test dataset with transferred processors - test_dataset = SampleDataset( + test_dataset = create_sample_dataset( samples=self.test_samples, input_schema=self.input_schema, output_schema=self.output_schema, @@ -135,7 +135,7 @@ def test_processor_vocabulary_consistency(self): def test_model_training_with_transferred_processors(self): """Test end-to-end training and inference with processor transfer.""" # Create training dataset - train_dataset = SampleDataset( + train_dataset = create_sample_dataset( samples=self.train_samples, input_schema=self.input_schema, output_schema=self.output_schema, @@ -143,7 +143,7 @@ def test_model_training_with_transferred_processors(self): ) # Create test dataset with transferred processors - test_dataset = SampleDataset( + test_dataset = create_sample_dataset( samples=self.test_samples, input_schema=self.input_schema, output_schema=self.output_schema, @@ -176,7 +176,7 @@ def test_model_training_with_transferred_processors(self): def test_without_processor_transfer(self): """Test that without transfer, each 
dataset fits its own processors.""" # Create training dataset - train_dataset = SampleDataset( + train_dataset = create_sample_dataset( samples=self.train_samples, input_schema=self.input_schema, output_schema=self.output_schema, @@ -184,7 +184,7 @@ def test_without_processor_transfer(self): ) # Create test dataset WITHOUT transferred processors - test_dataset = SampleDataset( + test_dataset = create_sample_dataset( samples=self.test_samples, input_schema=self.input_schema, output_schema=self.output_schema, @@ -210,7 +210,7 @@ def test_without_processor_transfer(self): def test_partial_processor_transfer(self): """Test transferring only input processors, not output processors.""" # Create training dataset - train_dataset = SampleDataset( + train_dataset = create_sample_dataset( samples=self.train_samples, input_schema=self.input_schema, output_schema=self.output_schema, @@ -218,7 +218,7 @@ def test_partial_processor_transfer(self): ) # Create test dataset with only input processors transferred - test_dataset = SampleDataset( + test_dataset = create_sample_dataset( samples=self.test_samples, input_schema=self.input_schema, output_schema=self.output_schema, @@ -242,7 +242,7 @@ def test_partial_processor_transfer(self): def test_empty_processor_dict_transfer(self): """Test passing empty processor dictionaries.""" # Create dataset with empty processor dicts (should fit new ones) - dataset = SampleDataset( + dataset = create_sample_dataset( samples=self.train_samples, input_schema=self.input_schema, output_schema=self.output_schema, @@ -272,7 +272,7 @@ def test_cross_validation_scenario(self): ] # has 1 and 0 # Fold 1 as train - train_dataset = SampleDataset( + train_dataset = create_sample_dataset( samples=fold1, input_schema=self.input_schema, output_schema=self.output_schema, @@ -280,7 +280,7 @@ def test_cross_validation_scenario(self): ) # Fold 2 as validation with transferred processors - val_dataset = SampleDataset( + val_dataset = create_sample_dataset( samples=fold2, input_schema=self.input_schema, output_schema=self.output_schema, @@ -343,7 +343,7 @@ def test_multimodal_processor_transfer(self): } # Create train dataset - train_dataset = SampleDataset( + train_dataset = create_sample_dataset( samples=multi_train, input_schema=multi_input_schema, output_schema=self.output_schema, @@ -351,7 +351,7 @@ def test_multimodal_processor_transfer(self): ) # Create test dataset with transferred processors - test_dataset = SampleDataset( + test_dataset = create_sample_dataset( samples=multi_test, input_schema=multi_input_schema, output_schema=self.output_schema, @@ -377,13 +377,13 @@ def test_multimodal_processor_transfer(self): def test_backward_compatibility(self): """Test that existing code without processor transfer still works.""" # Old-style usage without any processor parameters - dataset1 = SampleDataset( + dataset1 = create_sample_dataset( samples=self.train_samples, input_schema=self.input_schema, output_schema=self.output_schema, ) - dataset2 = SampleDataset( + dataset2 = create_sample_dataset( samples=self.test_samples, input_schema=self.input_schema, output_schema=self.output_schema, @@ -428,7 +428,7 @@ def test_none_processors_explicitly(self): output_schema = {"label": "binary"} # Explicitly pass None - dataset = SampleDataset( + dataset = create_sample_dataset( samples=samples, input_schema=input_schema, output_schema=output_schema, @@ -473,14 +473,14 @@ def test_mixed_transfer_and_schema(self): output_schema = {"label": "binary"} # Create train dataset - train_dataset = 
SampleDataset( + train_dataset = create_sample_dataset( samples=train_samples, input_schema=input_schema, output_schema=output_schema, ) # Create test dataset with same schema - test_dataset = SampleDataset( + test_dataset = create_sample_dataset( samples=test_samples, input_schema=input_schema, output_schema=output_schema, diff --git a/tests/core/test_safedrug.py b/tests/core/test_safedrug.py index 707f164e1..75e278a19 100644 --- a/tests/core/test_safedrug.py +++ b/tests/core/test_safedrug.py @@ -1,7 +1,7 @@ import unittest import torch -from pyhealth.datasets import SampleDataset, get_dataloader +from pyhealth.datasets import create_sample_dataset, get_dataloader from pyhealth.models import SafeDrug @@ -33,7 +33,7 @@ def setUp(self): } self.output_schema = {"drugs": "multilabel"} - self.dataset = SampleDataset( + self.dataset = create_sample_dataset( samples=self.samples, input_schema=self.input_schema, output_schema=self.output_schema, diff --git a/tests/core/test_transformer.py b/tests/core/test_transformer.py index 25a44495b..a5fa6cc6b 100644 --- a/tests/core/test_transformer.py +++ b/tests/core/test_transformer.py @@ -3,7 +3,7 @@ import torch -from pyhealth.datasets import SampleDataset, get_dataloader +from pyhealth.datasets import create_sample_dataset, get_dataloader from pyhealth.models import Transformer from pyhealth.processors.base_processor import FeatureProcessor from pyhealth.interpret.methods import CheferRelevance @@ -42,7 +42,7 @@ def setUp(self): "label": "binary" } - self.dataset = SampleDataset( + self.dataset = create_sample_dataset( samples=self.samples, input_schema=self.input_schema, output_schema=self.output_schema, From be1e74ae0fd65c1c31331f2d3c5a441d8c47c91b Mon Sep 17 00:00:00 2001 From: Yongda Fan Date: Sun, 7 Dec 2025 18:50:25 -0500 Subject: [PATCH 44/82] Fix test --- tests/core/test_legacy_mode_resolution.py | 10 +++++----- tests/core/test_logistic_regression.py | 6 +++--- tests/core/test_micron.py | 4 ++-- tests/core/test_mlp.py | 4 ++-- tests/core/test_molerec.py | 4 ++-- tests/core/test_multi_hot_embeddings.py | 4 ++-- 6 files changed, 16 insertions(+), 16 deletions(-) diff --git a/tests/core/test_legacy_mode_resolution.py b/tests/core/test_legacy_mode_resolution.py index e00644f24..9e411b7b3 100644 --- a/tests/core/test_legacy_mode_resolution.py +++ b/tests/core/test_legacy_mode_resolution.py @@ -2,7 +2,7 @@ import torch import torch.nn as nn -from pyhealth.datasets.sample_dataset import SampleDataset +from pyhealth.datasets.sample_dataset import create_sample_dataset from pyhealth.models.base_model import BaseModel from pyhealth.processors import ( BinaryLabelProcessor, @@ -38,7 +38,7 @@ def _build_dataset(self, output_processor, key="label"): ] input_schema = {"text": "raw"} output_schema = {key: output_processor} - return SampleDataset(samples, input_schema, output_schema) + return create_sample_dataset(samples, input_schema, output_schema) def test_string_schema_sets_mode(self): ds = self._build_dataset("binary") @@ -66,7 +66,7 @@ def test_unregistered_processor_leaves_mode_none(self): def test_multiclass_loss_selection(self): samples = [{"label": i % 3, "text": f"row{i}"} for i in range(6)] - ds = SampleDataset( + ds = create_sample_dataset( samples, {"text": "raw"}, {"label": MultiClassLabelProcessor} ) model = BaseModel(dataset=ds) @@ -78,7 +78,7 @@ def test_multilabel_loss_selection(self): {"label": [0, 2], "text": "row0"}, {"label": [1], "text": "row1"}, ] - ds = SampleDataset(samples, {"text": "raw"}, {"label": MultiLabelProcessor}) + ds = 
create_sample_dataset(samples, {"text": "raw"}, {"label": MultiLabelProcessor}) model = BaseModel(dataset=ds) self.assertEqual(model.mode, "multilabel") self.assertEqual( @@ -90,7 +90,7 @@ def test_regression_loss_selection(self): {"label": 0.5, "text": "r0"}, {"label": 1.2, "text": "r1"}, ] - ds = SampleDataset( + ds = create_sample_dataset( samples, {"text": "raw"}, {"label": RegressionLabelProcessor} ) model = BaseModel(dataset=ds) diff --git a/tests/core/test_logistic_regression.py b/tests/core/test_logistic_regression.py index ff9746368..91ff954bb 100644 --- a/tests/core/test_logistic_regression.py +++ b/tests/core/test_logistic_regression.py @@ -2,7 +2,7 @@ import torch -from pyhealth.datasets import SampleDataset, get_dataloader +from pyhealth.datasets import create_sample_dataset, get_dataloader from pyhealth.models import LogisticRegression @@ -36,7 +36,7 @@ def setUp(self): self.output_schema = {"label": "binary"} # binary classification # Create dataset - self.dataset = SampleDataset( + self.dataset = create_sample_dataset( samples=self.samples, input_schema=self.input_schema, output_schema=self.output_schema, @@ -164,7 +164,7 @@ def test_regression_task(self): } output_schema = {"score": "regression"} - regression_dataset = SampleDataset( + regression_dataset = create_sample_dataset( samples=regression_samples, input_schema=input_schema, output_schema=output_schema, diff --git a/tests/core/test_micron.py b/tests/core/test_micron.py index 938923c6e..59107f907 100644 --- a/tests/core/test_micron.py +++ b/tests/core/test_micron.py @@ -4,7 +4,7 @@ import torch import numpy as np -from pyhealth.datasets import SampleDataset, get_dataloader +from pyhealth.datasets import create_sample_dataset, get_dataloader from pyhealth.models import MICRON from pyhealth.processors.base_processor import FeatureProcessor @@ -47,7 +47,7 @@ def setUp(self): "drugs": "multilabel" } - self.dataset = SampleDataset( + self.dataset = create_sample_dataset( samples=self.samples, input_schema=self.input_schema, output_schema=self.output_schema, diff --git a/tests/core/test_mlp.py b/tests/core/test_mlp.py index 2320b577c..892decb99 100644 --- a/tests/core/test_mlp.py +++ b/tests/core/test_mlp.py @@ -1,7 +1,7 @@ import unittest import torch -from pyhealth.datasets import SampleDataset, get_dataloader +from pyhealth.datasets import create_sample_dataset, get_dataloader from pyhealth.models import MLP @@ -35,7 +35,7 @@ def setUp(self): self.output_schema = {"label": "binary"} # binary classification # Create dataset - self.dataset = SampleDataset( + self.dataset = create_sample_dataset( samples=self.samples, input_schema=self.input_schema, output_schema=self.output_schema, diff --git a/tests/core/test_molerec.py b/tests/core/test_molerec.py index 1ebbb2a41..c55457044 100644 --- a/tests/core/test_molerec.py +++ b/tests/core/test_molerec.py @@ -1,7 +1,7 @@ import unittest import torch -from pyhealth.datasets import SampleDataset, get_dataloader +from pyhealth.datasets import create_sample_dataset, get_dataloader from pyhealth.models import MoleRec @@ -39,7 +39,7 @@ def setUp(self): } self.output_schema = {"drugs": "multilabel"} - self.dataset = SampleDataset( + self.dataset = create_sample_dataset( samples=self.samples, input_schema=self.input_schema, output_schema=self.output_schema, diff --git a/tests/core/test_multi_hot_embeddings.py b/tests/core/test_multi_hot_embeddings.py index fbb4bde20..0f2aee258 100644 --- a/tests/core/test_multi_hot_embeddings.py +++ b/tests/core/test_multi_hot_embeddings.py @@ -2,7 +2,7 
@@ import torch -from pyhealth.datasets import SampleDataset, get_dataloader +from pyhealth.datasets import create_sample_dataset, get_dataloader from pyhealth.models.embedding import EmbeddingModel from pyhealth.models.mlp import MLP @@ -39,7 +39,7 @@ def setUp(self) -> None: self.input_schema = {"ethnicity": "multi_hot"} self.output_schema = {"label": "binary"} - self.dataset = SampleDataset( + self.dataset = create_sample_dataset( samples=self.samples, input_schema=self.input_schema, output_schema=self.output_schema, From 8dc1acd4891de952c7159daebe58bc4522f1e446 Mon Sep 17 00:00:00 2001 From: Yongda Fan Date: Sun, 7 Dec 2025 18:51:12 -0500 Subject: [PATCH 45/82] Fix test --- tests/core/test_adacare.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/core/test_adacare.py b/tests/core/test_adacare.py index 19754cb1d..daf60ecee 100644 --- a/tests/core/test_adacare.py +++ b/tests/core/test_adacare.py @@ -1,7 +1,7 @@ import unittest import torch -from pyhealth.datasets import SampleDataset, get_dataloader +from pyhealth.datasets import create_sample_dataset, get_dataloader from pyhealth.models import AdaCare @@ -51,7 +51,7 @@ def setUp(self): } self.output_schema = {"label": "binary"} - self.dataset = SampleDataset( + self.dataset = create_sample_dataset( samples=self.samples, input_schema=self.input_schema, output_schema=self.output_schema, From b1539ca26cbd70b4982a974a40486c9b2c589d3f Mon Sep 17 00:00:00 2001 From: Yongda Fan Date: Sun, 7 Dec 2025 18:57:08 -0500 Subject: [PATCH 46/82] support set_shuffle for InMemorySampleDataset --- pyhealth/datasets/sample_dataset.py | 15 +++++++++++++-- 1 file changed, 13 insertions(+), 2 deletions(-) diff --git a/pyhealth/datasets/sample_dataset.py b/pyhealth/datasets/sample_dataset.py index 84b0e34c5..91c4528dd 100644 --- a/pyhealth/datasets/sample_dataset.py +++ b/pyhealth/datasets/sample_dataset.py @@ -4,7 +4,7 @@ import tempfile from typing import Any, Dict, Iterable, List, Optional, Tuple, Union, Type, override import inspect - +import random from bisect import bisect_right import litdata from litdata.utilities.train_test_split import deepcopy_dataset @@ -417,6 +417,12 @@ def __init__( self._data = [builder.transform({"sample": pickle.dumps(s)}) for s in samples] + self._shuffle = False + + @override + def set_shuffle(self, shuffle: bool) -> None: + self._shuffle = shuffle + @override def __len__(self) -> int: """Returns the number of samples in the dataset. @@ -445,7 +451,12 @@ def __iter__(self) -> Iterable[Dict[str, Any]]: # type: ignore Returns: An iterator yielding processed sample dictionaries. 
""" - return iter(self._data) + if self._shuffle: + shuffled_data = self._data[:] + random.shuffle(shuffled_data) + return iter(shuffled_data) + else: + return iter(self._data) @override def subset(self, indices: Union[Sequence[int], slice]) -> SampleDataset: From 7e701a106abd0c49f7dbf6d6483275bb57b8b9f2 Mon Sep 17 00:00:00 2001 From: Yongda Fan Date: Sun, 7 Dec 2025 19:09:48 -0500 Subject: [PATCH 47/82] Fix in memory dataset subset --- pyhealth/datasets/sample_dataset.py | 15 +++++---------- 1 file changed, 5 insertions(+), 10 deletions(-) diff --git a/pyhealth/datasets/sample_dataset.py b/pyhealth/datasets/sample_dataset.py index 91c4528dd..3cc15d7ca 100644 --- a/pyhealth/datasets/sample_dataset.py +++ b/pyhealth/datasets/sample_dataset.py @@ -8,6 +8,7 @@ from bisect import bisect_right import litdata from litdata.utilities.train_test_split import deepcopy_dataset +import copy from ..processors import get_processor from ..processors.base_processor import FeatureProcessor @@ -464,16 +465,10 @@ def subset(self, indices: Union[Sequence[int], slice]) -> SampleDataset: samples = self._data[indices] else: samples = [self._data[i] for i in indices] - - return InMemorySampleDataset( - samples=samples, - input_schema=self.input_schema, - output_schema=self.output_schema, - dataset_name=self.dataset_name, - task_name=self.task_name, - input_processors=self.input_processors, - output_processors=self.output_processors, - ) + + new_dataset = copy.deepcopy(self) + new_dataset._data = samples + return new_dataset def create_sample_dataset( From 68a7ea9ba134341741094be2e43e2053ba460355 Mon Sep 17 00:00:00 2001 From: Yongda Fan Date: Sun, 7 Dec 2025 19:09:56 -0500 Subject: [PATCH 48/82] Fix test --- .../predictionset/covariate/covariate_label.py | 4 ++-- tests/core/test_covariate_label.py | 16 ++++++++-------- 2 files changed, 10 insertions(+), 10 deletions(-) diff --git a/pyhealth/calib/predictionset/covariate/covariate_label.py b/pyhealth/calib/predictionset/covariate/covariate_label.py index e3aa0e293..602944b15 100644 --- a/pyhealth/calib/predictionset/covariate/covariate_label.py +++ b/pyhealth/calib/predictionset/covariate/covariate_label.py @@ -14,7 +14,7 @@ import numpy as np import torch -from torch.utils.data import Subset +from torch.utils.data import IterableDataset from pyhealth.calib.base_classes import SetPredictor from pyhealth.calib.calibration.kcal.kde import RBFKernelMean @@ -306,7 +306,7 @@ def __init__( def calibrate( self, - cal_dataset: Subset, + cal_dataset: IterableDataset, cal_embeddings: Optional[np.ndarray] = None, test_embeddings: Optional[np.ndarray] = None, ): diff --git a/tests/core/test_covariate_label.py b/tests/core/test_covariate_label.py index e1c8af02b..b5aa0aaf6 100644 --- a/tests/core/test_covariate_label.py +++ b/tests/core/test_covariate_label.py @@ -2,7 +2,7 @@ import numpy as np import torch -from pyhealth.datasets import SampleDataset, get_dataloader +from pyhealth.datasets import create_sample_dataset, get_dataloader from pyhealth.models import MLP from pyhealth.calib.predictionset.covariate import CovariateLabel, fit_kde from pyhealth.calib.utils import extract_embeddings @@ -67,7 +67,7 @@ def setUp(self): self.output_schema = {"label": "multiclass"} # Create dataset - self.dataset = SampleDataset( + self.dataset = create_sample_dataset( samples=self.samples, input_schema=self.input_schema, output_schema=self.output_schema, @@ -150,7 +150,7 @@ def test_initialization_non_multiclass_raises_error(self): "label": 1, }, ] - binary_dataset = SampleDataset( + 
binary_dataset = create_sample_dataset( samples=binary_samples, input_schema={"conditions": "sequence", "procedures": "tensor"}, output_schema={"label": "binary"}, @@ -182,7 +182,7 @@ def test_calibrate_marginal(self): # Calibrate on first 4 samples cal_indices = [0, 1, 2, 3] - cal_dataset = torch.utils.data.Subset(self.dataset, cal_indices) + cal_dataset = self.dataset.subset(cal_indices) # Extract embeddings cal_embeddings = self._get_embeddings(cal_dataset) @@ -240,7 +240,7 @@ def test_forward_returns_predset(self): # Calibrate cal_indices = [0, 1, 2, 3] - cal_dataset = torch.utils.data.Subset(self.dataset, cal_indices) + cal_dataset = self.dataset.subset(cal_indices) # Extract embeddings cal_embeddings = self._get_embeddings(cal_dataset) @@ -281,7 +281,7 @@ def test_prediction_sets_nonempty(self): # Calibrate on first 4 samples cal_indices = [0, 1, 2, 3] - cal_dataset = torch.utils.data.Subset(self.dataset, cal_indices) + cal_dataset = self.dataset.subset(cal_indices) # Extract embeddings cal_embeddings = self._get_embeddings(cal_dataset) @@ -295,7 +295,7 @@ def test_prediction_sets_nonempty(self): # Test on remaining samples test_indices = [4, 5] - test_dataset = torch.utils.data.Subset(self.dataset, test_indices) + test_dataset = self.dataset.subset(test_indices) test_loader = get_dataloader(test_dataset, batch_size=2, shuffle=False) with torch.no_grad(): @@ -359,7 +359,7 @@ def test_edge_case_empty_class(self): # Use subset that doesn't have all classes cal_indices = [0, 1] # Only classes 0 and 1 - cal_dataset = torch.utils.data.Subset(self.dataset, cal_indices) + cal_dataset = self.dataset.subset(cal_indices) # Extract embeddings cal_embeddings = self._get_embeddings(cal_dataset) From 8b6746951291e5274477f9d563d0a558512a708c Mon Sep 17 00:00:00 2001 From: Yongda Fan Date: Sun, 7 Dec 2025 19:20:11 -0500 Subject: [PATCH 49/82] update sample dataset test --- tests/core/test_sample_dataset.py | 144 ++++++++++++++++++------------ 1 file changed, 85 insertions(+), 59 deletions(-) diff --git a/tests/core/test_sample_dataset.py b/tests/core/test_sample_dataset.py index 842e6a4e7..8dd4629c0 100644 --- a/tests/core/test_sample_dataset.py +++ b/tests/core/test_sample_dataset.py @@ -1,97 +1,123 @@ import unittest import pickle +import random from pyhealth.datasets.sample_dataset import create_sample_dataset -class TestSampleDataset(unittest.TestCase): +class TestSampleDatasetParity(unittest.TestCase): def setUp(self): + # Create a slightly larger dataset to make shuffling more obvious self.samples = [ - {"patient_id": "p1", "record_id": "r1", "feature": "a", "label": 1}, - {"patient_id": "p1", "record_id": "r2", "feature": "b", "label": 0}, - {"patient_id": "p2", "record_id": "r3", "feature": "c", "label": 1}, - {"patient_id": "p3", "record_id": "r4", "feature": "d", "label": 0}, - {"patient_id": "p3", "record_id": "r5", "feature": "e", "label": 1}, + {"patient_id": f"p{i}", "record_id": f"r{i}", "feature": i, "label": i % 2} + for i in range(20) ] self.input_schema = {"feature": "raw"} self.output_schema = {"label": "raw"} - def test_sample_dataset_subset_slice(self): - # Create SampleDataset (disk-based) - dataset = create_sample_dataset( - samples=self.samples, - input_schema=self.input_schema, - output_schema=self.output_schema, - in_memory=False - ) - - # Define a slice object - s = slice(1, 4) # Slice [1:4] -> 1, 2, 3 - - subset = dataset.subset(s) - - self.assertEqual(len(subset), 3) - - # Check content - subset_data = list(subset) - self.assertEqual(subset_data[0]["feature"], "b") - 
self.assertEqual(subset_data[1]["feature"], "c") - self.assertEqual(subset_data[2]["feature"], "d") - - def test_in_memory_sample_dataset_behavior(self): - # Create both datasets + def _get_datasets(self): ds_disk = create_sample_dataset( samples=self.samples, input_schema=self.input_schema, output_schema=self.output_schema, in_memory=False ) - ds_mem = create_sample_dataset( samples=self.samples, input_schema=self.input_schema, output_schema=self.output_schema, in_memory=True ) - - # 1. Test len + return ds_disk, ds_mem + + def test_len(self): + ds_disk, ds_mem = self._get_datasets() + self.assertEqual(len(ds_disk), 20) + self.assertEqual(len(ds_mem), 20) self.assertEqual(len(ds_disk), len(ds_mem)) - self.assertEqual(len(ds_disk), 5) - - # 2. Test iter - iter_disk = list(ds_disk) - iter_mem = list(ds_mem) - - for d, m in zip(iter_disk, iter_mem): - self.assertEqual(d["feature"], m["feature"]) - self.assertEqual(d["label"], m["label"]) - self.assertEqual(d["patient_id"], m["patient_id"]) - self.assertEqual(d["record_id"], m["record_id"]) - # 3. Test getitem - for i in range(len(ds_disk)): - d = ds_disk[i] - m = ds_mem[i] + def test_getitem(self): + ds_disk, ds_mem = self._get_datasets() + for i in range(len(self.samples)): + item_disk = ds_disk[i] + item_mem = ds_mem[i] + self.assertEqual(item_disk["feature"], item_mem["feature"]) + self.assertEqual(item_disk["label"], item_mem["label"]) + self.assertEqual(item_disk["patient_id"], item_mem["patient_id"]) + + def test_iter(self): + ds_disk, ds_mem = self._get_datasets() + list_disk = list(ds_disk) + list_mem = list(ds_mem) + + self.assertEqual(len(list_disk), len(list_mem)) + for d, m in zip(list_disk, list_mem): self.assertEqual(d["feature"], m["feature"]) - self.assertEqual(d["label"], m["label"]) - # 4. Test subset with list - indices = [0, 2, 4] + def test_subset_indices(self): + ds_disk, ds_mem = self._get_datasets() + indices = [0, 5, 10, 15, 19] + sub_disk = ds_disk.subset(indices) sub_mem = ds_mem.subset(indices) self.assertEqual(len(sub_disk), len(sub_mem)) + self.assertEqual(len(sub_disk), 5) + + list_disk = list(sub_disk) + list_mem = list(sub_mem) - for d, m in zip(sub_disk, sub_mem): + for d, m in zip(list_disk, list_mem): self.assertEqual(d["feature"], m["feature"]) - self.assertEqual(d["label"], m["label"]) - # 5. Test subset with slice - s = slice(0, 3) - sub_disk_slice = ds_disk.subset(s) - sub_mem_slice = ds_mem.subset(s) + def test_subset_slice(self): + ds_disk, ds_mem = self._get_datasets() + s = slice(2, 18, 2) - self.assertEqual(len(sub_disk_slice), len(sub_mem_slice)) - for d, m in zip(sub_disk_slice, sub_mem_slice): + sub_disk = ds_disk.subset(s) + sub_mem = ds_mem.subset(s) + + self.assertEqual(len(sub_disk), len(sub_mem)) + + list_disk = list(sub_disk) + list_mem = list(sub_mem) + + for d, m in zip(list_disk, list_mem): self.assertEqual(d["feature"], m["feature"]) + def test_set_shuffle(self): + ds_disk, ds_mem = self._get_datasets() + + # Test shuffle=True + ds_disk.set_shuffle(True) + ds_mem.set_shuffle(True) + + # Iterating should return all elements, but likely in different order than original + # and potentially different order between disk and mem (implementation detail) + # But the set of elements should be identical. 
+ + items_disk = list(ds_disk) + items_mem = list(ds_mem) + + self.assertEqual(len(items_disk), 20) + self.assertEqual(len(items_mem), 20) + + # Check that we have the same set of features + features_disk = sorted([x["feature"] for x in items_disk]) + features_mem = sorted([x["feature"] for x in items_mem]) + features_orig = sorted([x["feature"] for x in self.samples]) + + self.assertEqual(features_disk, features_orig) + self.assertEqual(features_mem, features_orig) + + # Test shuffle=False resets to original order + ds_disk.set_shuffle(False) + ds_mem.set_shuffle(False) + + items_disk_ordered = list(ds_disk) + items_mem_ordered = list(ds_mem) + + for i in range(20): + self.assertEqual(items_disk_ordered[i]["feature"], i) + self.assertEqual(items_mem_ordered[i]["feature"], i) + if __name__ == "__main__": unittest.main() From 29908db9b176d0d25d775a31d0a515634ebce11b Mon Sep 17 00:00:00 2001 From: Yongda Fan Date: Sun, 7 Dec 2025 19:51:37 -0500 Subject: [PATCH 50/82] Add deps --- pyproject.toml | 1 + 1 file changed, 1 insertion(+) diff --git a/pyproject.toml b/pyproject.toml index 49d0f828d..bf45906a1 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -40,6 +40,7 @@ dependencies = [ "pandarallel~=1.6.5", "pydantic~=2.11.7", "litdata~=0.2.58", + "pyarrow~=22.0.0", ] license = "BSD-3-Clause" license-files = ["LICENSE.md"] From 24facac2a92efe04d6f9f40b45bd80bd028b61be Mon Sep 17 00:00:00 2001 From: Yongda Fan Date: Sun, 7 Dec 2025 19:57:38 -0500 Subject: [PATCH 51/82] Fix test --- pyhealth/datasets/base_dataset.py | 6 +++++- pyhealth/tasks/medical_coding.py | 1 + tests/core/test_processor_schemas.py | 7 ++++--- 3 files changed, 10 insertions(+), 4 deletions(-) diff --git a/pyhealth/datasets/base_dataset.py b/pyhealth/datasets/base_dataset.py index af3109009..71ad44b71 100644 --- a/pyhealth/datasets/base_dataset.py +++ b/pyhealth/datasets/base_dataset.py @@ -330,6 +330,9 @@ def load_table(self, table_name: str) -> pl.LazyFrame: logger.info(f"Scanning table: {table_name} from {csv_path}") df = scan_csv_gz_or_csv_tsv(csv_path) + # Convert column names to lowercase before calling preprocess_func + df = df.rename(lambda col: col.lower()) + # Check if there is a preprocessing function for this table preprocess_func = getattr(self, f"preprocess_{table_name}", None) if preprocess_func is not None: @@ -344,6 +347,7 @@ def load_table(self, table_name: str) -> pl.LazyFrame: other_csv_path = clean_path(other_csv_path) logger.info(f"Joining with table: {other_csv_path}") join_df = scan_csv_gz_or_csv_tsv(other_csv_path) + join_df = join_df.rename(lambda col: col.lower()) join_key = join_cfg.on columns = join_cfg.columns how = join_cfg.how @@ -386,7 +390,7 @@ def load_table(self, table_name: str) -> pl.LazyFrame: # Flatten attribute columns with event_type prefix attribute_columns = [ - pl.col(attr).alias(f"{table_name}/{attr}") for attr in attribute_cols + pl.col(attr.lower()).alias(f"{table_name}/{attr}") for attr in attribute_cols ] event_frame = df.select(base_columns + attribute_columns) diff --git a/pyhealth/tasks/medical_coding.py b/pyhealth/tasks/medical_coding.py index 45b8d9cd1..23b3cb2e4 100644 --- a/pyhealth/tasks/medical_coding.py +++ b/pyhealth/tasks/medical_coding.py @@ -37,6 +37,7 @@ def pre_filter(self, df: pl.LazyFrame) -> pl.LazyFrame: df.filter(pl.col("event_type") == "noteevents") .select("patient_id") .unique() + .collect() .to_series() ) ) diff --git a/tests/core/test_processor_schemas.py b/tests/core/test_processor_schemas.py index 54aa3618c..0651297e2 100644 --- 
a/tests/core/test_processor_schemas.py +++ b/tests/core/test_processor_schemas.py @@ -13,7 +13,7 @@ import numpy as np from pyhealth.datasets import MIMIC3Dataset -from pyhealth.datasets.sample_dataset import SampleDataset +from pyhealth.datasets.sample_dataset import create_sample_dataset from pyhealth.processors import TextProcessor, MultiLabelProcessor, TimeseriesProcessor from pyhealth.tasks.medical_coding import MIMIC3ICD9Coding from pyhealth.tasks.base_task import BaseTask @@ -116,6 +116,7 @@ def pre_filter(self, df: pl.LazyFrame) -> pl.LazyFrame: df.filter(pl.col("event_type") == "noteevents") .select("patient_id") .unique() + .collect() .to_series() ) ) @@ -326,7 +327,7 @@ def __call__(self, patient: Patient) -> List[Dict]: } task = TestTimeseriesTask() - sample_dataset = SampleDataset( + sample_dataset = create_sample_dataset( samples=samples, input_schema=task.input_schema, output_schema=task.output_schema, @@ -379,7 +380,7 @@ def __call__(self, patient: Patient) -> List[Dict]: } task = TestTimeseriesTask() - sample_dataset = SampleDataset( + sample_dataset = create_sample_dataset( samples=samples, input_schema=task.input_schema, output_schema=task.output_schema, From 7747529d0cba641cf5f2e62b61b08207d1b3bb00 Mon Sep 17 00:00:00 2001 From: jhnwu3 Date: Sun, 7 Dec 2025 21:14:31 -0600 Subject: [PATCH 52/82] commit for fixing model docs --- pyhealth/models/cnn.py | 8 ++-- pyhealth/models/gamenet.py | 4 +- pyhealth/models/gnn.py | 56 ++++++++++++++++++++++---- pyhealth/models/logistic_regression.py | 8 ++-- pyhealth/models/micron.py | 7 ++-- pyhealth/models/mlp.py | 8 ++-- pyhealth/models/molerec.py | 2 +- pyhealth/models/retain.py | 8 ++-- pyhealth/models/rnn.py | 47 ++++++++++++++++++--- pyhealth/models/stagenet.py | 8 ++-- pyhealth/models/transformer.py | 8 ++-- 11 files changed, 121 insertions(+), 43 deletions(-) diff --git a/pyhealth/models/cnn.py b/pyhealth/models/cnn.py index dc9853d08..9e69de3bd 100644 --- a/pyhealth/models/cnn.py +++ b/pyhealth/models/cnn.py @@ -159,7 +159,7 @@ class CNN(BaseModel): **kwargs: Additional keyword arguments forwarded to :class:`CNNLayer`. Example: - >>> from pyhealth.datasets import SampleDataset + >>> from pyhealth.datasets import create_sample_dataset >>> samples = [ ... { ... "patient_id": "p0", @@ -176,7 +176,7 @@ class CNN(BaseModel): ... "label": 0, ... }, ... ] - >>> dataset = SampleDataset( + >>> dataset = create_sample_dataset( ... samples=samples, ... input_schema={"conditions": "sequence", "labs": "tensor"}, ... output_schema={"label": "binary"}, @@ -326,7 +326,7 @@ def forward(self, **kwargs) -> Dict[str, torch.Tensor]: if __name__ == "__main__": - from pyhealth.datasets import SampleDataset + from pyhealth.datasets import create_sample_dataset from pyhealth.datasets import get_dataloader samples = [ @@ -348,7 +348,7 @@ def forward(self, **kwargs) -> Dict[str, torch.Tensor]: input_schema = {"conditions": "sequence", "labs": "tensor"} output_schema = {"label": "binary"} - dataset = SampleDataset( + dataset = create_sample_dataset( samples=samples, input_schema=input_schema, # type: ignore[arg-type] output_schema=output_schema, # type: ignore[arg-type] diff --git a/pyhealth/models/gamenet.py b/pyhealth/models/gamenet.py index dd5e27314..46afe057f 100644 --- a/pyhealth/models/gamenet.py +++ b/pyhealth/models/gamenet.py @@ -241,14 +241,14 @@ class GAMENet(BaseModel): **kwargs: other parameters for the GAMENet layer. 
Examples: - >>> from pyhealth.datasets import SampleDataset + >>> from pyhealth.datasets import create_sample_dataset >>> from pyhealth.tasks import drug_recommendation_mimic3_fn >>> from pyhealth.models import GAMENet >>> from pyhealth.datasets import split_by_patient, get_dataloader >>> from pyhealth.trainer import Trainer >>> >>> # Load MIMIC-III dataset - >>> dataset = SampleDataset( + >>> dataset = create_sample_dataset( ... samples=[ ... { ... "patient_id": "patient-0", diff --git a/pyhealth/models/gnn.py b/pyhealth/models/gnn.py index c4a0eed00..9ef323cb3 100644 --- a/pyhealth/models/gnn.py +++ b/pyhealth/models/gnn.py @@ -324,9 +324,29 @@ class GCN(BaseModel): num_layers: Number of GCN layers. Defaults to 2. Examples: - >>> from pyhealth.datasets import SampleDataset, get_dataloader - >>> samples = [...] - >>> dataset = SampleDataset(...) + >>> from pyhealth.datasets import create_sample_dataset, get_dataloader + >>> samples = [ + ... { + ... "patient_id": "patient-0", + ... "visit_id": "visit-0", + ... "diagnoses": ["A", "B", "C"], + ... "procedures": ["X", "Y"], + ... "label": 1, + ... }, + ... { + ... "patient_id": "patient-1", + ... "visit_id": "visit-0", + ... "diagnoses": ["D", "E"], + ... "procedures": ["Z"], + ... "label": 0, + ... }, + ... ] + >>> dataset = create_sample_dataset( + ... samples=samples, + ... input_schema={"diagnoses": "sequence", "procedures": "sequence"}, + ... output_schema={"label": "binary"}, + ... dataset_name="test", + ... ) >>> model = GCN(dataset=dataset) >>> loader = get_dataloader(dataset, batch_size=2, shuffle=True) >>> batch = next(iter(loader)) @@ -503,9 +523,29 @@ class GAT(BaseModel): num_layers: Number of GAT layers. Defaults to 2. Examples: - >>> from pyhealth.datasets import SampleDataset, get_dataloader - >>> samples = [...] - >>> dataset = SampleDataset(...) + >>> from pyhealth.datasets import create_sample_dataset, get_dataloader + >>> samples = [ + ... { + ... "patient_id": "patient-0", + ... "visit_id": "visit-0", + ... "diagnoses": ["A", "B", "C"], + ... "procedures": ["X", "Y"], + ... "label": 1, + ... }, + ... { + ... "patient_id": "patient-1", + ... "visit_id": "visit-0", + ... "diagnoses": ["D", "E"], + ... "procedures": ["Z"], + ... "label": 0, + ... }, + ... ] + >>> dataset = create_sample_dataset( + ... samples=samples, + ... input_schema={"diagnoses": "sequence", "procedures": "sequence"}, + ... output_schema={"label": "binary"}, + ... dataset_name="test", + ... ) >>> model = GAT(dataset=dataset) >>> loader = get_dataloader(dataset, batch_size=2, shuffle=True) >>> batch = next(iter(loader)) @@ -666,7 +706,7 @@ def forward(self, **kwargs) -> Dict[str, torch.Tensor]: if __name__ == "__main__": - from pyhealth.datasets import SampleDataset, get_dataloader + from pyhealth.datasets import create_sample_dataset, get_dataloader samples = [ { @@ -688,7 +728,7 @@ def forward(self, **kwargs) -> Dict[str, torch.Tensor]: input_schema = {"diagnoses": "sequence", "procedures": "sequence"} output_schema = {"label": "binary"} - dataset = SampleDataset( + dataset = create_sample_dataset( samples=samples, input_schema=input_schema, output_schema=output_schema, diff --git a/pyhealth/models/logistic_regression.py b/pyhealth/models/logistic_regression.py index 9b99c491e..8155d101f 100644 --- a/pyhealth/models/logistic_regression.py +++ b/pyhealth/models/logistic_regression.py @@ -29,7 +29,7 @@ class LogisticRegression(BaseModel): **kwargs: other parameters (for compatibility). 
Examples: - >>> from pyhealth.datasets import SampleDataset + >>> from pyhealth.datasets import create_sample_dataset >>> samples = [ ... { ... "patient_id": "patient-0", @@ -49,7 +49,7 @@ class LogisticRegression(BaseModel): >>> input_schema = {"conditions": "sequence", ... "procedures": "tensor"} >>> output_schema = {"label": "binary"} - >>> dataset = SampleDataset(samples=samples, + >>> dataset = create_sample_dataset(samples=samples, ... input_schema=input_schema, ... output_schema=output_schema, ... dataset_name="test") @@ -211,7 +211,7 @@ def forward(self, **kwargs) -> Dict[str, torch.Tensor]: if __name__ == "__main__": - from pyhealth.datasets import SampleDataset + from pyhealth.datasets import create_sample_dataset samples = [ { @@ -238,7 +238,7 @@ def forward(self, **kwargs) -> Dict[str, torch.Tensor]: output_schema = {"label": "binary"} # binary classification # dataset - dataset = SampleDataset( + dataset = create_sample_dataset( samples=samples, input_schema=input_schema, output_schema=output_schema, diff --git a/pyhealth/models/micron.py b/pyhealth/models/micron.py index fb9269bf7..ffee2612c 100644 --- a/pyhealth/models/micron.py +++ b/pyhealth/models/micron.py @@ -159,7 +159,7 @@ class MICRON(BaseModel): - output_schema should include 'drugs' as a multilabel/multihot feature Example: - >>> from pyhealth.datasets import SampleDataset + >>> from pyhealth.datasets import create_sample_dataset >>> samples = [ ... { ... "patient_id": "patient-0", @@ -169,10 +169,11 @@ class MICRON(BaseModel): ... "drugs": ["metformin", "lisinopril"] ... } ... ] - >>> dataset = SampleDataset( + >>> dataset = create_sample_dataset( ... samples=samples, ... input_schema={"conditions": "sequence", "procedures": "sequence"}, - ... output_schema={"drugs": "multilabel"} + ... output_schema={"drugs": "multilabel"}, + ... dataset_name="test", ... ) >>> model = MICRON(dataset=dataset) """ diff --git a/pyhealth/models/mlp.py b/pyhealth/models/mlp.py index 413fb6f34..598b45778 100644 --- a/pyhealth/models/mlp.py +++ b/pyhealth/models/mlp.py @@ -55,7 +55,7 @@ class MLP(BaseModel): **kwargs: other parameters for the MLP layer. Examples: - >>> from pyhealth.datasets import SampleDataset + >>> from pyhealth.datasets import create_sample_dataset >>> samples = [ ... { ... "patient_id": "patient-0", @@ -75,7 +75,7 @@ class MLP(BaseModel): >>> input_schema = {"conditions": "sequence", ... "procedures": "timeseries"} >>> output_schema = {"label": "binary"} - >>> dataset = SampleDataset(samples=samples, + >>> dataset = create_sample_dataset(samples=samples, ... input_schema=input_schema, ... output_schema=output_schema, ... dataset_name="test") @@ -364,7 +364,7 @@ def forward(self, **kwargs) -> Dict[str, torch.Tensor]: if __name__ == "__main__": - from pyhealth.datasets import SampleDataset + from pyhealth.datasets import create_sample_dataset samples = [ { @@ -391,7 +391,7 @@ def forward(self, **kwargs) -> Dict[str, torch.Tensor]: output_schema = {"label": "binary"} # binary classification # dataset - dataset = SampleDataset( + dataset = create_sample_dataset( samples=samples, input_schema=input_schema, output_schema=output_schema, diff --git a/pyhealth/models/molerec.py b/pyhealth/models/molerec.py index 4cf6c5a28..68561b548 100644 --- a/pyhealth/models/molerec.py +++ b/pyhealth/models/molerec.py @@ -498,7 +498,7 @@ class MoleRec(BaseModel): ... ] >>> >>> # dataset - >>> dataset = SampleDataset( + >>> dataset = create_sample_dataset( ... samples=samples, ... input_schema={ ... 
"conditions": "nested_sequence", diff --git a/pyhealth/models/retain.py b/pyhealth/models/retain.py index b9aac4309..b272995ec 100644 --- a/pyhealth/models/retain.py +++ b/pyhealth/models/retain.py @@ -132,7 +132,7 @@ class RETAIN(BaseModel): **kwargs: other parameters for the RETAIN layer. Examples: - >>> from pyhealth.datasets import SampleDataset + >>> from pyhealth.datasets import create_sample_dataset >>> samples = [ ... { ... "patient_id": "patient-0", @@ -149,7 +149,7 @@ class RETAIN(BaseModel): ... "label": 0, ... }, ... ] - >>> dataset = SampleDataset( + >>> dataset = create_sample_dataset( ... samples=samples, ... input_schema={ ... "conditions": "nested_sequence", @@ -281,7 +281,7 @@ def forward(self, **kwargs) -> Dict[str, torch.Tensor]: if __name__ == "__main__": - from pyhealth.datasets import SampleDataset + from pyhealth.datasets import create_sample_dataset samples = [ { @@ -303,7 +303,7 @@ def forward(self, **kwargs) -> Dict[str, torch.Tensor]: ] # dataset - dataset = SampleDataset( + dataset = create_sample_dataset( samples=samples, input_schema={ "conditions": "nested_sequence", diff --git a/pyhealth/models/rnn.py b/pyhealth/models/rnn.py index fe6d7aff7..e89d58bd1 100644 --- a/pyhealth/models/rnn.py +++ b/pyhealth/models/rnn.py @@ -137,13 +137,50 @@ class RNN(BaseModel): Args: dataset (SampleDataset): the dataset to train the model. It is used to query certain information such as the set of all tokens. - feature_keys (List[str]): list of keys in samples to use as features, - e.g. ["conditions", "procedures"]. - label_key (str): key in samples to use as label (e.g., "drugs"). - mode (str): one of "binary", "multiclass", or "multilabel". embedding_dim (int): the embedding dimension. Default is 128. hidden_dim (int): the hidden dimension. Default is 128. - **kwargs: other parameters for the RNN layer. + **kwargs: other parameters for the RNN layer (e.g., rnn_type, num_layers, dropout, bidirectional). + + Examples: + >>> from pyhealth.datasets import create_sample_dataset + >>> samples = [ + ... { + ... "patient_id": "patient-0", + ... "visit_id": "visit-0", + ... "conditions": ["cond-33", "cond-86", "cond-80"], + ... "procedures": ["proc-12", "proc-45"], + ... "label": 1, + ... }, + ... { + ... "patient_id": "patient-1", + ... "visit_id": "visit-1", + ... "conditions": ["cond-12", "cond-52"], + ... "procedures": ["proc-23"], + ... "label": 0, + ... }, + ... ] + >>> dataset = create_sample_dataset( + ... samples=samples, + ... input_schema={"conditions": "sequence", "procedures": "sequence"}, + ... output_schema={"label": "binary"}, + ... dataset_name="test" + ... ) + >>> + >>> from pyhealth.datasets import get_dataloader + >>> train_loader = get_dataloader(dataset, batch_size=2, shuffle=True) + >>> + >>> model = RNN(dataset=dataset, embedding_dim=128, hidden_dim=64) + >>> + >>> data_batch = next(iter(train_loader)) + >>> + >>> ret = model(**data_batch) + >>> print(ret) + { + 'loss': tensor(...), + 'y_prob': tensor(...), + 'y_true': tensor(...), + 'logit': tensor(...) + } """ def __init__( diff --git a/pyhealth/models/stagenet.py b/pyhealth/models/stagenet.py index 884e1262a..1892e8f3b 100644 --- a/pyhealth/models/stagenet.py +++ b/pyhealth/models/stagenet.py @@ -285,7 +285,7 @@ class StageNet(BaseModel): **kwargs: other parameters for the StageNet layer. Examples: - >>> from pyhealth.datasets import SampleDataset + >>> from pyhealth.datasets import create_sample_dataset >>> samples = [ ... { ... "patient_id": "patient-0", @@ -316,7 +316,7 @@ class StageNet(BaseModel): ... 
] >>> >>> # dataset - >>> dataset = SampleDataset( + >>> dataset = create_sample_dataset( ... samples=samples, ... input_schema={ ... "codes": "stagenet", @@ -616,7 +616,7 @@ def forward(self, **kwargs) -> Dict[str, torch.Tensor]: if __name__ == "__main__": - from pyhealth.datasets import SampleDataset + from pyhealth.datasets import create_sample_dataset samples = [ { @@ -654,7 +654,7 @@ def forward(self, **kwargs) -> Dict[str, torch.Tensor]: ] # dataset - dataset = SampleDataset( + dataset = create_sample_dataset( samples=samples, input_schema={ "codes": "stagenet", diff --git a/pyhealth/models/transformer.py b/pyhealth/models/transformer.py index d1e87abea..9f0feeb38 100644 --- a/pyhealth/models/transformer.py +++ b/pyhealth/models/transformer.py @@ -353,7 +353,7 @@ class Transformer(BaseModel): num_layers (int): number of transformer blocks per feature stream. Examples: - >>> from pyhealth.datasets import SampleDataset, get_dataloader + >>> from pyhealth.datasets import create_sample_dataset, get_dataloader >>> samples = [ ... { ... "patient_id": "patient-0", @@ -372,7 +372,7 @@ class Transformer(BaseModel): ... ] >>> input_schema = {"diagnoses": "sequence", "procedures": "sequence"} >>> output_schema = {"label": "binary"} - >>> dataset = SampleDataset( + >>> dataset = create_sample_dataset( ... samples, ... input_schema, ... output_schema, @@ -661,7 +661,7 @@ def forward(self, **kwargs) -> Dict[str, torch.Tensor]: if __name__ == "__main__": - from pyhealth.datasets import SampleDataset, get_dataloader + from pyhealth.datasets import create_sample_dataset, get_dataloader samples = [ { @@ -686,7 +686,7 @@ def forward(self, **kwargs) -> Dict[str, torch.Tensor]: } output_schema: Dict[str, Union[str, type[FeatureProcessor]]] = {"label": "binary"} - dataset = SampleDataset( + dataset = create_sample_dataset( samples=samples, input_schema=input_schema, output_schema=output_schema, From ea47052b436dc8c1e88c4614ec7835b6ea582153 Mon Sep 17 00:00:00 2001 From: jhnwu3 Date: Sun, 7 Dec 2025 22:04:11 -0600 Subject: [PATCH 53/82] fix adacare docstrings --- pyhealth/models/adacare.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/pyhealth/models/adacare.py b/pyhealth/models/adacare.py index 730ff36ec..5e4ee806b 100644 --- a/pyhealth/models/adacare.py +++ b/pyhealth/models/adacare.py @@ -289,7 +289,7 @@ class AdaCare(BaseModel): Examples: - >>> from pyhealth.datasets import SampleDataset + >>> from pyhealth.datasets import create_sample_dataset >>> samples = [ ... { ... "patient_id": "patient-0", @@ -321,7 +321,7 @@ class AdaCare(BaseModel): ... "label": 0, ... }, ... ] - >>> dataset = SampleDataset(samples=samples, + >>> dataset = create_sample_dataset(samples=samples, ... input_schema={ ... 'vector': 'nested_sequence_floats', ... 'list_codes': 'sequence', @@ -329,7 +329,8 @@ class AdaCare(BaseModel): ... 'list_vectors': 'nested_sequence_floats', ... 'list_list_vectors': 'deep_nested_sequence_floats' ... }, - ... output_schema={'label': 'binary'} + ... output_schema={'label': 'binary'}, + ... dataset_name='test' ... ) >>> >>> from pyhealth.models import AdaCare From 4543decc763a0ba4b47d0f0a4dbd80c5b48cc683 Mon Sep 17 00:00:00 2001 From: John Wu Date: Mon, 8 Dec 2025 15:31:01 -0600 Subject: [PATCH 54/82] fix for python 3.10 override typing incompatibility, but still struggling with problems with out of memory? 
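
Context for the change below: typing.override was only added in Python 3.12
(PEP 698), so importing it breaks Python 3.10/3.11; the diff below simply
drops the @override decorators and their imports. A minimal sketch of an
alternative that would keep the decorators on older interpreters, assuming a
typing_extensions dependency (which this patch does not add), looks like:

    import sys

    if sys.version_info >= (3, 12):
        from typing import override
    else:
        # Backport of PEP 698's override decorator for Python < 3.12
        from typing_extensions import override

Dropping the decorator entirely is the smaller change, which is what the
diff below does.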
--- pyhealth/datasets/base_dataset.py | 12 +++- pyhealth/datasets/mimic4.py | 73 +++++++++++++++++------ pyhealth/datasets/sample_dataset.py | 10 +--- pyhealth/processors/sequence_processor.py | 4 +- 4 files changed, 69 insertions(+), 30 deletions(-) diff --git a/pyhealth/datasets/base_dataset.py b/pyhealth/datasets/base_dataset.py index 71ad44b71..1f171d46f 100644 --- a/pyhealth/datasets/base_dataset.py +++ b/pyhealth/datasets/base_dataset.py @@ -258,6 +258,11 @@ def cache_dir(self) -> Path: cache_dir.mkdir(parents=True, exist_ok=True) print(f"No cache_dir provided. Using default cache dir: {cache_dir}") self._cache_dir = cache_dir + else: + # Ensure the explicitly provided cache_dir exists + cache_dir = Path(self._cache_dir) + cache_dir.mkdir(parents=True, exist_ok=True) + self._cache_dir = cache_dir return Path(self._cache_dir) @property @@ -390,7 +395,8 @@ def load_table(self, table_name: str) -> pl.LazyFrame: # Flatten attribute columns with event_type prefix attribute_columns = [ - pl.col(attr.lower()).alias(f"{table_name}/{attr}") for attr in attribute_cols + pl.col(attr.lower()).alias(f"{table_name}/{attr}") + for attr in attribute_cols ] event_frame = df.select(base_columns + attribute_columns) @@ -523,6 +529,10 @@ def set_task( if cache_dir is None: cache_dir = self.cache_dir / "tasks" / task.task_name cache_dir.mkdir(parents=True, exist_ok=True) + else: + # Ensure the explicitly provided cache_dir exists + cache_dir = Path(cache_dir) + cache_dir.mkdir(parents=True, exist_ok=True) path = Path(cache_dir) diff --git a/pyhealth/datasets/mimic4.py b/pyhealth/datasets/mimic4.py index e3d2340c1..9b48985e3 100644 --- a/pyhealth/datasets/mimic4.py +++ b/pyhealth/datasets/mimic4.py @@ -1,13 +1,14 @@ import logging import os import warnings -from typing import Dict, List, Optional, override +from typing import Dict, List, Optional import pandas as pd import polars as pl try: import psutil + HAS_PSUTIL = True except ImportError: HAS_PSUTIL = False @@ -39,7 +40,7 @@ class MIMIC4EHRDataset(BaseDataset): tables (List[str]): A list of tables to be included in the dataset. dataset_name (Optional[str]): The name of the dataset. config_path (Optional[str]): The path to the configuration file. 
- """ + """ def __init__( self, @@ -47,10 +48,13 @@ def __init__( tables: List[str], dataset_name: str = "mimic4_ehr", config_path: Optional[str] = None, - **kwargs + cache_dir: Optional[str] = None, + **kwargs, ): if config_path is None: - config_path = os.path.join(os.path.dirname(__file__), "configs", "mimic4_ehr.yaml") + config_path = os.path.join( + os.path.dirname(__file__), "configs", "mimic4_ehr.yaml" + ) logger.info(f"Using default EHR config: {config_path}") log_memory_usage(f"Before initializing {dataset_name}") @@ -61,7 +65,8 @@ def __init__( tables=tables, dataset_name=dataset_name, config_path=config_path, - **kwargs + cache_dir=cache_dir, + **kwargs, ) log_memory_usage(f"After initializing {dataset_name}") @@ -86,10 +91,13 @@ def __init__( tables: List[str], dataset_name: str = "mimic4_note", config_path: Optional[str] = None, - **kwargs + cache_dir: Optional[str] = None, + **kwargs, ): if config_path is None: - config_path = os.path.join(os.path.dirname(__file__), "configs", "mimic4_note.yaml") + config_path = os.path.join( + os.path.dirname(__file__), "configs", "mimic4_note.yaml" + ) logger.info(f"Using default note config: {config_path}") if "discharge" in tables: warnings.warn( @@ -109,7 +117,8 @@ def __init__( tables=tables, dataset_name=dataset_name, config_path=config_path, - **kwargs + cache_dir=cache_dir, + **kwargs, ) log_memory_usage(f"After initializing {dataset_name}") @@ -134,10 +143,13 @@ def __init__( tables: List[str], dataset_name: str = "mimic4_cxr", config_path: Optional[str] = None, - **kwargs + cache_dir: Optional[str] = None, + **kwargs, ): if config_path is None: - config_path = os.path.join(os.path.dirname(__file__), "configs", "mimic4_cxr.yaml") + config_path = os.path.join( + os.path.dirname(__file__), "configs", "mimic4_cxr.yaml" + ) logger.info(f"Using default CXR config: {config_path}") self.prepare_metadata(root) log_memory_usage(f"Before initializing {dataset_name}") @@ -146,12 +158,15 @@ def __init__( tables=tables, dataset_name=dataset_name, config_path=config_path, - **kwargs + cache_dir=cache_dir, + **kwargs, ) log_memory_usage(f"After initializing {dataset_name}") def prepare_metadata(self, root: str) -> None: - metadata = pd.read_csv(os.path.join(root, "mimic-cxr-2.0.0-metadata.csv.gz"), dtype=str) + metadata = pd.read_csv( + os.path.join(root, "mimic-cxr-2.0.0-metadata.csv.gz"), dtype=str + ) def process_studytime(x): # reformat studytime to be 6 digits (e.g. 
123.002 -> 000123 which is 12:30:00) @@ -160,6 +175,7 @@ def process_studytime(x): return f"{int(x):06d}" except Exception: return x + metadata["StudyTime"] = metadata["StudyTime"].apply(process_studytime) def process_image_path(x): @@ -168,10 +184,15 @@ def process_image_path(x): folder = subject_id[:3] study_id = "s" + x["study_id"] dicom_id = x["dicom_id"] - return os.path.join(root, "files", folder, subject_id, study_id, f"{dicom_id}.jpg") + return os.path.join( + root, "files", folder, subject_id, study_id, f"{dicom_id}.jpg" + ) + metadata["image_path"] = metadata.apply(process_image_path, axis=1) - metadata.to_csv(os.path.join(root, "mimic-cxr-2.0.0-metadata-pyhealth.csv"), index=False) + metadata.to_csv( + os.path.join(root, "mimic-cxr-2.0.0-metadata-pyhealth.csv"), index=False + ) return @@ -211,6 +232,7 @@ def __init__( cxr_config_path: Optional[str] = None, dataset_name: str = "mimic4", dev: bool = False, + cache_dir: Optional[str] = None, ): log_memory_usage("Starting MIMIC4Dataset init") @@ -229,6 +251,7 @@ def __init__( dataset_name=dataset_name, config_path=None, dev=dev, + cache_dir=cache_dir, ) # Initialize child datasets @@ -236,37 +259,51 @@ def __init__( # Initialize EHR dataset if root is provided if ehr_root: - logger.info(f"Initializing MIMIC4EHRDataset with tables: {ehr_tables} (dev mode: {dev})") + logger.info( + f"Initializing MIMIC4EHRDataset with tables: {ehr_tables} (dev mode: {dev})" + ) + ehr_cache_dir = None if cache_dir is None else f"{cache_dir}/ehr" self.sub_datasets["ehr"] = MIMIC4EHRDataset( root=ehr_root, tables=ehr_tables, config_path=ehr_config_path, + cache_dir=ehr_cache_dir, + dev=dev, ) log_memory_usage("After EHR dataset initialization") # Initialize Notes dataset if root is provided if note_root is not None and note_tables: - logger.info(f"Initializing MIMIC4NoteDataset with tables: {note_tables} (dev mode: {dev})") + logger.info( + f"Initializing MIMIC4NoteDataset with tables: {note_tables} (dev mode: {dev})" + ) + note_cache_dir = None if cache_dir is None else f"{cache_dir}/note" self.sub_datasets["note"] = MIMIC4NoteDataset( root=note_root, tables=note_tables, config_path=note_config_path, + cache_dir=note_cache_dir, + dev=dev, ) log_memory_usage("After Note dataset initialization") # Initialize CXR dataset if root is provided if cxr_root is not None: - logger.info(f"Initializing MIMIC4CXRDataset with tables: {cxr_tables} (dev mode: {dev})") + logger.info( + f"Initializing MIMIC4CXRDataset with tables: {cxr_tables} (dev mode: {dev})" + ) + cxr_cache_dir = None if cache_dir is None else f"{cache_dir}/cxr" self.sub_datasets["cxr"] = MIMIC4CXRDataset( root=cxr_root, tables=cxr_tables, config_path=cxr_config_path, + cache_dir=cxr_cache_dir, + dev=dev, ) log_memory_usage("After CXR dataset initialization") log_memory_usage("Completed MIMIC4Dataset init") - @override def load_data(self) -> pl.LazyFrame: """ Combines data from all initialized sub-datasets into a unified global event dataframe. 
diff --git a/pyhealth/datasets/sample_dataset.py b/pyhealth/datasets/sample_dataset.py index 3cc15d7ca..e14e02ca9 100644 --- a/pyhealth/datasets/sample_dataset.py +++ b/pyhealth/datasets/sample_dataset.py @@ -2,7 +2,7 @@ from pathlib import Path import pickle import tempfile -from typing import Any, Dict, Iterable, List, Optional, Tuple, Union, Type, override +from typing import Any, Dict, Iterable, List, Optional, Tuple, Union, Type import inspect import random from bisect import bisect_right @@ -279,7 +279,6 @@ def __init__( self.patient_to_index = metadata["patient_to_index"] self.record_to_index = metadata["record_to_index"] - @override def __str__(self) -> str: """Returns a string representation of the dataset. @@ -420,11 +419,9 @@ def __init__( self._shuffle = False - @override def set_shuffle(self, shuffle: bool) -> None: self._shuffle = shuffle - @override def __len__(self) -> int: """Returns the number of samples in the dataset. @@ -433,7 +430,6 @@ def __len__(self) -> int: """ return len(self._data) - @override def __getitem__(self, index: int) -> Dict[str, Any]: # type: ignore """Retrieve a processed sample by index. @@ -445,7 +441,6 @@ def __getitem__(self, index: int) -> Dict[str, Any]: # type: ignore """ return self._data[index] - @override def __iter__(self) -> Iterable[Dict[str, Any]]: # type: ignore """Returns an iterator over all samples in the dataset. @@ -459,13 +454,12 @@ def __iter__(self) -> Iterable[Dict[str, Any]]: # type: ignore else: return iter(self._data) - @override def subset(self, indices: Union[Sequence[int], slice]) -> SampleDataset: if isinstance(indices, slice): samples = self._data[indices] else: samples = [self._data[i] for i in indices] - + new_dataset = copy.deepcopy(self) new_dataset._data = samples return new_dataset diff --git a/pyhealth/processors/sequence_processor.py b/pyhealth/processors/sequence_processor.py index 3f889051a..9bcb0a325 100644 --- a/pyhealth/processors/sequence_processor.py +++ b/pyhealth/processors/sequence_processor.py @@ -1,4 +1,4 @@ -from typing import Any, Dict, List, override +from typing import Any, Dict, List import torch @@ -20,7 +20,6 @@ def __init__(self): self.code_vocab: Dict[Any, int] = {"": 0} self._next_index = 1 - @override def fit(self, samples: List[Dict[str, Any]], field: str) -> None: for sample in samples: for token in sample[field]: @@ -32,7 +31,6 @@ def fit(self, samples: List[Dict[str, Any]], field: str) -> None: self.code_vocab[""] = len(self.code_vocab) - @override def process(self, value: Any) -> torch.Tensor: """Process token value(s) into tensor of indices. 
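
The set_shuffle/subset changes in the sample-dataset patches above give the
in-memory and disk-backed datasets the same surface. A minimal usage sketch
of that surface, with illustrative field names and schema values borrowed
from the tests earlier in this series (not a canonical example):

    from pyhealth.datasets import create_sample_dataset, get_dataloader

    samples = [
        {"patient_id": f"p{i}", "record_id": f"r{i}", "feature": i, "label": i % 2}
        for i in range(20)
    ]
    dataset = create_sample_dataset(
        samples=samples,
        input_schema={"feature": "raw"},
        output_schema={"label": "raw"},
        dataset_name="demo",      # illustrative name
        in_memory=True,           # False builds the litdata-backed on-disk variant
    )
    dataset.set_shuffle(True)     # shuffling happens on iteration
    subset = dataset.subset(slice(0, 10))
    loader = get_dataloader(subset, batch_size=4, shuffle=False)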
From 73fb32150be18d19e6e3976072f8c6d7e91d8c24 Mon Sep 17 00:00:00 2001 From: John Wu Date: Mon, 8 Dec 2025 15:31:30 -0600 Subject: [PATCH 55/82] organize for benchmarking scripts --- examples/benchmark_perf/benchmark_pandas.py | 407 ++++++++++++++++++ .../benchmark_perf/benchmark_workers_1.py | 188 ++++++++ .../benchmark_perf/benchmark_workers_4.py | 188 ++++++++ examples/benchmark_perf/memtest.py | 105 +++++ 4 files changed, 888 insertions(+) create mode 100644 examples/benchmark_perf/benchmark_pandas.py create mode 100644 examples/benchmark_perf/benchmark_workers_1.py create mode 100644 examples/benchmark_perf/benchmark_workers_4.py create mode 100644 examples/benchmark_perf/memtest.py diff --git a/examples/benchmark_perf/benchmark_pandas.py b/examples/benchmark_perf/benchmark_pandas.py new file mode 100644 index 000000000..889fda14d --- /dev/null +++ b/examples/benchmark_perf/benchmark_pandas.py @@ -0,0 +1,407 @@ +""" +Benchmark script for MIMIC-IV mortality prediction using pandas +(analogous to PyHealth task). + +This benchmark mimics the MortalityPredictionStageNetMIMIC4 task: +1. Creates PATIENT-LEVEL samples (not visit-level) +2. Aggregates all admissions per patient +3. Combines ICD codes (diagnoses + procedures) across all visits +4. Extracts lab events in 10-dimensional vectors per timestamp +5. Calculates time intervals between consecutive admissions + +Lab Categories (10 dimensions): +- Sodium, Potassium, Chloride, Bicarbonate, Glucose +- Calcium, Magnesium, Anion Gap, Osmolality, Phosphate +""" + +import time +import os +import threading +from typing import Any, Dict, List, Optional, Tuple + +import pandas as pd +import psutil + + +PEAK_MEM_USAGE = 0 +SELF_PROC = psutil.Process(os.getpid()) + + +def track_mem(): + """Background thread to track peak memory usage.""" + global PEAK_MEM_USAGE + while True: + m = SELF_PROC.memory_info().rss + if m > PEAK_MEM_USAGE: + PEAK_MEM_USAGE = m + time.sleep(0.1) + + +# Lab item organization by category (matches MortalityPredictionStageNetMIMIC4) +LAB_CATEGORIES = { + "Sodium": ["50824", "52455", "50983", "52623"], + "Potassium": ["50822", "52452", "50971", "52610"], + "Chloride": ["50806", "52434", "50902", "52535"], + "Bicarbonate": ["50803", "50804"], + "Glucose": ["50809", "52027", "50931", "52569"], + "Calcium": ["50808", "51624"], + "Magnesium": ["50960"], + "Anion Gap": ["50868", "52500"], + "Osmolality": ["52031", "50964", "51701"], + "Phosphate": ["50970"], +} + +# Ordered list of category names (defines vector dimension order) +LAB_CATEGORY_NAMES = [ + "Sodium", + "Potassium", + "Chloride", + "Bicarbonate", + "Glucose", + "Calcium", + "Magnesium", + "Anion Gap", + "Osmolality", + "Phosphate", +] + +# Flat list of all lab item IDs for filtering +LABITEMS = [item for itemids in LAB_CATEGORIES.values() for item in itemids] + + +def process_patient_mortality( + subject_id: int, + patients_df: pd.DataFrame, + admissions_df: pd.DataFrame, + diagnoses_df: pd.DataFrame, + procedures_df: pd.DataFrame, + labevents_df: pd.DataFrame, +) -> Optional[Dict[str, Any]]: + """Process a single patient for mortality prediction task. + + Creates ONE patient-level sample by aggregating all admissions, + ICD codes, and lab events. 
+ + Args: + subject_id: Patient ID + patients_df: Patient demographics + admissions_df: Admission records + diagnoses_df: Diagnosis ICD codes + procedures_df: Procedure ICD codes + labevents_df: Lab event measurements + + Returns: + Dictionary with patient sample or None if patient doesn't qualify + """ + # Get patient demographics + patient_demo = patients_df[patients_df["subject_id"] == subject_id] + if len(patient_demo) == 0: + return None + + # Skip if under 18 + anchor_age = patient_demo.iloc[0]["anchor_age"] + if anchor_age < 18: + return None + + # Get all admissions for this patient, sorted by time + patient_admissions = admissions_df[ + admissions_df["subject_id"] == subject_id + ].sort_values("admittime") + + if len(patient_admissions) < 1: + return None + + # Initialize aggregated data structures + all_icd_codes = [] # Nested list: [[visit1_codes], [visit2_codes], ...] + all_icd_times = [] # Time from previous admission per visit + all_lab_values = [] # List of 10D vectors + all_lab_times = [] # Time from admission start per measurement + + previous_admission_time = None + final_mortality = 0 + + # Process each admission + for _, admission in patient_admissions.iterrows(): + hadm_id = admission["hadm_id"] + admit_time = admission["admittime"] + discharge_time = admission["dischtime"] + + # Skip if invalid timestamps + if pd.isna(discharge_time) or discharge_time < admit_time: + continue + + # Calculate time from previous admission (hours) + if previous_admission_time is None: + time_from_previous = 0.0 + else: + time_from_previous = ( + admit_time - previous_admission_time + ).total_seconds() / 3600.0 + + previous_admission_time = admit_time + + # Update mortality label if this admission had mortality + if int(admission.get("hospital_expire_flag", 0)) == 1: + final_mortality = 1 + + # Get diagnosis codes for this admission + visit_diagnoses = diagnoses_df[diagnoses_df["hadm_id"] == hadm_id] + diagnoses_codes = visit_diagnoses["icd_code"].dropna().tolist() + + # Get procedure codes for this admission + visit_procedures = procedures_df[procedures_df["hadm_id"] == hadm_id] + procedures_codes = visit_procedures["icd_code"].dropna().tolist() + + # Combine diagnoses and procedures + visit_icd_codes = diagnoses_codes + procedures_codes + + if visit_icd_codes: + all_icd_codes.append(visit_icd_codes) + all_icd_times.append(time_from_previous) + + # Get lab events for this admission + admission_labs = labevents_df[ + (labevents_df["subject_id"] == subject_id) + & (labevents_df["hadm_id"] == hadm_id) + ] + + # Filter to relevant lab items + admission_labs = admission_labs[ + admission_labs["itemid"].astype(str).isin(LABITEMS) + ] + + if len(admission_labs) > 0: + # Parse storetime + admission_labs = admission_labs.copy() + admission_labs["storetime"] = pd.to_datetime(admission_labs["storetime"]) + + # Filter to valid times (before discharge) + admission_labs = admission_labs[ + admission_labs["storetime"] <= discharge_time + ] + + if len(admission_labs) > 0: + # Group by timestamp and create 10D vectors + unique_timestamps = sorted(admission_labs["storetime"].unique()) + + for lab_ts in unique_timestamps: + # Get all labs at this timestamp + ts_labs = admission_labs[admission_labs["storetime"] == lab_ts] + + # Create 10-dimensional vector + lab_vector = [] + for category_name in LAB_CATEGORY_NAMES: + category_itemids = LAB_CATEGORIES[category_name] + + # Find first matching value for this category + category_value = None + for itemid in category_itemids: + matching = 
ts_labs[ts_labs["itemid"].astype(str) == itemid] + if len(matching) > 0: + category_value = matching.iloc[0]["valuenum"] + break + + lab_vector.append(category_value) + + # Calculate time from admission start (hours) + time_from_admission = (lab_ts - admit_time).total_seconds() / 3600.0 + + all_lab_values.append(lab_vector) + all_lab_times.append(time_from_admission) + + # Skip if no lab events (required for this task) + if len(all_lab_values) == 0: + return None + + # Skip if no ICD codes + if len(all_icd_codes) == 0: + return None + + # Create patient-level sample + sample = { + "patient_id": subject_id, + "icd_codes": (all_icd_times, all_icd_codes), + "labs": (all_lab_times, all_lab_values), + "mortality": final_mortality, + "num_visits": len(all_icd_codes), + "num_lab_measurements": len(all_lab_values), + } + + return sample + + +def benchmark_mortality_prediction( + patients_df: pd.DataFrame, + admissions_df: pd.DataFrame, + diagnoses_df: pd.DataFrame, + procedures_df: pd.DataFrame, + labevents_df: pd.DataFrame, + n_patients: Optional[int] = None, +) -> Tuple[List[Dict[str, Any]], float]: + """ + Benchmark MIMIC-IV mortality prediction processing. + + Args: + patients_df: Patient demographics + admissions_df: Admissions dataframe + diagnoses_df: Diagnoses dataframe + procedures_df: Procedures dataframe + labevents_df: Lab events dataframe + n_patients: Number of patients to process (None = all patients) + + Returns: + Tuple of (list of samples, processing time in seconds) + """ + print("=" * 80) + print("BENCHMARK: Pandas Mortality Prediction (StageNet format)") + print("=" * 80) + + # Get patients to process + if n_patients is None: + patients_to_process = patients_df["subject_id"].tolist() + print(f"Processing all {len(patients_to_process)} patients...") + else: + patients_to_process = patients_df["subject_id"].head(n_patients).tolist() + print(f"Processing first {len(patients_to_process)} patients...") + + # Parse datetime columns + admissions_df = admissions_df.copy() + admissions_df["admittime"] = pd.to_datetime(admissions_df["admittime"]) + admissions_df["dischtime"] = pd.to_datetime(admissions_df["dischtime"]) + + # Start processing timer + start_time = time.perf_counter() + + samples = [] + processed_patients = 0 + + for subject_id in patients_to_process: + sample = process_patient_mortality( + subject_id, + patients_df, + admissions_df, + diagnoses_df, + procedures_df, + labevents_df, + ) + + if sample is not None: + samples.append(sample) + + processed_patients += 1 + if processed_patients % 100 == 0: + print(f"Processed {processed_patients} patients...") + + # End processing timer + processing_time = time.perf_counter() - start_time + + print("\nCompleted processing:") + print(f" - Total patients processed: {processed_patients}") + print(f" - Valid samples created: {len(samples)}") + print(f" - Processing time: {processing_time:.2f}s") + print("=" * 80) + + return samples, processing_time + + +def load_mimic_data(data_root: str = "/srv/local/data/MIMIC-IV/2.0/hosp"): + """Load MIMIC-IV tables needed for mortality prediction. 
+ + Args: + data_root: Root directory for MIMIC-IV data + + Returns: + Tuple of dataframes: (patients, admissions, diagnoses, + procedures, labevents) + """ + print("Loading MIMIC-IV data tables...") + load_start = time.perf_counter() + + patients_df = pd.read_csv(f"{data_root}/patients.csv") + admissions_df = pd.read_csv(f"{data_root}/admissions.csv") + diagnoses_df = pd.read_csv(f"{data_root}/diagnoses_icd.csv") + procedures_df = pd.read_csv(f"{data_root}/procedures_icd.csv") + labevents_df = pd.read_csv(f"{data_root}/labevents.csv") + + load_time = time.perf_counter() - load_start + print(f"Data loaded in {load_time:.2f}s") + print(f" - Patients: {len(patients_df):,}") + print(f" - Admissions: {len(admissions_df):,}") + print(f" - Diagnoses: {len(diagnoses_df):,}") + print(f" - Procedures: {len(procedures_df):,}") + print(f" - Lab events: {len(labevents_df):,}") + print() + + return ( + patients_df, + admissions_df, + diagnoses_df, + procedures_df, + labevents_df, + ) + + +def main(): + """Main function to run the benchmark.""" + # Start memory tracking thread + mem_thread = threading.Thread(target=track_mem, daemon=True) + mem_thread.start() + + # Load data + data_root = "/srv/local/data/MIMIC-IV/2.0/hosp" + ( + patients_df, + admissions_df, + diagnoses_df, + procedures_df, + labevents_df, + ) = load_mimic_data(data_root) + + # Run benchmark (process all patients) + samples, processing_time = benchmark_mortality_prediction( + patients_df, + admissions_df, + diagnoses_df, + procedures_df, + labevents_df, + n_patients=None, # Change to a number to limit patients + ) + + # Get peak memory + peak_mem = PEAK_MEM_USAGE + + # Helper function for formatting size + def format_size(size_bytes): + for unit in ["B", "KB", "MB", "GB", "TB"]: + if size_bytes < 1024.0: + return f"{size_bytes:.2f} {unit}" + size_bytes /= 1024.0 + return f"{size_bytes:.2f} PB" + + # Save results + results_file = "benchmark_results_pandas.txt" + with open(results_file, "w") as f: + f.write("BENCHMARK RESULTS: Pandas Mortality Prediction\n") + f.write("=" * 80 + "\n") + f.write(f"Total samples: {len(samples)}\n") + f.write(f"Processing time: {processing_time:.2f}s\n") + f.write(f"Peak memory usage: {format_size(peak_mem)}\n") + f.write("=" * 80 + "\n") + + print(f"\n✓ Results saved to {results_file}") + print(f"Peak memory usage: {format_size(peak_mem)}") + + # Optional: Save samples for inspection + if samples: + print("\nExample sample (first patient):") + first_sample = samples[0] + print(f" Patient ID: {first_sample['patient_id']}") + print(f" Mortality: {first_sample['mortality']}") + print(f" Number of visits: {first_sample['num_visits']}") + print( + f" Number of lab measurements: " f"{first_sample['num_lab_measurements']}" + ) + + +if __name__ == "__main__": + main() diff --git a/examples/benchmark_perf/benchmark_workers_1.py b/examples/benchmark_perf/benchmark_workers_1.py new file mode 100644 index 000000000..4b1198861 --- /dev/null +++ b/examples/benchmark_perf/benchmark_workers_1.py @@ -0,0 +1,188 @@ +"""Benchmark script for MIMIC-IV mortality prediction with num_workers=1. + +This benchmark measures: +1. Time to load base dataset +2. Time to process task with num_workers=1 +3. Total processing time +4. Cache sizes +5. 
Peak memory usage (with optional memory limit) +""" + +import time +import os +import threading +from pathlib import Path +import psutil +from pyhealth.datasets import MIMIC4Dataset +from pyhealth.tasks import MortalityPredictionStageNetMIMIC4 + +try: + import resource + + HAS_RESOURCE = True +except ImportError: + HAS_RESOURCE = False + + +PEAK_MEM_USAGE = 0 +SELF_PROC = psutil.Process(os.getpid()) + + +def track_mem(): + """Background thread to track peak memory usage.""" + global PEAK_MEM_USAGE + while True: + m = SELF_PROC.memory_info().rss + if m > PEAK_MEM_USAGE: + PEAK_MEM_USAGE = m + time.sleep(0.1) + + +def set_memory_limit(max_memory_gb): + """Set hard memory limit for the process. + + Args: + max_memory_gb: Maximum memory in GB (e.g., 8 for 8GB) + + Note: + If limit is exceeded, the process will raise MemoryError. + Only works on Unix-like systems (Linux, macOS). + """ + if not HAS_RESOURCE: + print( + "Warning: resource module not available (Windows?). " + "Memory limit not enforced." + ) + return + + max_memory_bytes = int(max_memory_gb * 1024**3) + try: + resource.setrlimit(resource.RLIMIT_AS, (max_memory_bytes, max_memory_bytes)) + print(f"✓ Memory limit set to {max_memory_gb} GB") + except Exception as e: + print(f"Warning: Failed to set memory limit: {e}") + + +def get_directory_size(path): + """Calculate total size of a directory in bytes.""" + total = 0 + try: + for entry in Path(path).rglob("*"): + if entry.is_file(): + total += entry.stat().st_size + except Exception as e: + print(f"Error calculating size for {path}: {e}") + return total + + +def format_size(size_bytes): + """Format bytes to human-readable size.""" + for unit in ["B", "KB", "MB", "GB", "TB"]: + if size_bytes < 1024.0: + return f"{size_bytes:.2f} {unit}" + size_bytes /= 1024.0 + return f"{size_bytes:.2f} PB" + + +def main(): + """Main benchmark function.""" + # Configuration + dev = True # Set to True for development/testing + enable_memory_limit = False # Set to True to enforce memory limit + max_memory_gb = 32 # Memory limit in GB (if enable_memory_limit=True) + + # Apply memory limit if enabled + if enable_memory_limit: + set_memory_limit(max_memory_gb) + + # Start memory tracking thread + mem_thread = threading.Thread(target=track_mem, daemon=True) + mem_thread.start() + + print("=" * 80) + print(f"BENCHMARK: num_workers=1, dev={dev}") + if enable_memory_limit: + print(f"Memory Limit: {max_memory_gb} GB (ENFORCED)") + else: + print("Memory Limit: None (unrestricted)") + print("=" * 80) + + # Define cache directories based on dev mode + cache_root = "./benchmark_cache/workers_1" + if dev: + cache_root += "_dev" + + # Track total time + total_start = time.time() # STEP 1: Load MIMIC-IV base dataset + print("\n[1/2] Loading MIMIC-IV base dataset...") + dataset_start = time.time() + + base_dataset = MIMIC4Dataset( + ehr_root="/srv/local/data/physionet.org/files/mimiciv/2.2/", + ehr_tables=[ + "patients", + "admissions", + "diagnoses_icd", + "procedures_icd", + "labevents", + ], + dev=dev, + cache_dir=f"{cache_root}/base_dataset", + ) + + dataset_time = time.time() - dataset_start + print(f"✓ Dataset loaded in {dataset_time:.2f} seconds") + + # STEP 2: Apply StageNet mortality prediction task with num_workers=1 + print("\n[2/2] Applying mortality prediction task (num_workers=1)...") + task_start = time.time() + + sample_dataset = base_dataset.set_task( + MortalityPredictionStageNetMIMIC4(), + num_workers=1, + cache_dir=f"{cache_root}/task_samples", + ) + + task_time = time.time() - task_start + print(f"✓ 
Task processing completed in {task_time:.2f} seconds") + + # Measure cache sizes + print("\n[3/3] Measuring cache sizes...") + base_cache_dir = f"{cache_root}/base_dataset" + task_cache_dir = f"{cache_root}/task_samples" + + base_cache_size = get_directory_size(base_cache_dir) + task_cache_size = get_directory_size(task_cache_dir) + total_cache_size = base_cache_size + task_cache_size + + print(f"✓ Base dataset cache: {format_size(base_cache_size)}") + print(f"✓ Task samples cache: {format_size(task_cache_size)}") + print(f"✓ Total cache size: {format_size(total_cache_size)}") + + # Total time and peak memory + total_time = time.time() - total_start + peak_mem = PEAK_MEM_USAGE + + # Print summary + print("\n" + "=" * 80) + print("BENCHMARK RESULTS") + print("=" * 80) + print("Configuration:") + print(" - num_workers: 1") + print(f" - dev mode: {dev}") + print(f" - Total samples: {len(sample_dataset)}") + print("\nTiming:") + print(f" - Dataset loading: {dataset_time:.2f}s") + print(f" - Task processing: {task_time:.2f}s") + print(f" - Total time: {total_time:.2f}s") + print("\nCache Sizes:") + print(f" - Base dataset cache: {format_size(base_cache_size)}") + print(f" - Task samples cache: {format_size(task_cache_size)}") + print(f" - Total cache: {format_size(total_cache_size)}") + print("\nMemory:") + print(f" - Peak memory usage: {format_size(peak_mem)}") + print("=" * 80) + + +if __name__ == "__main__": + main() diff --git a/examples/benchmark_perf/benchmark_workers_4.py b/examples/benchmark_perf/benchmark_workers_4.py new file mode 100644 index 000000000..62e478561 --- /dev/null +++ b/examples/benchmark_perf/benchmark_workers_4.py @@ -0,0 +1,188 @@ +"""Benchmark script for MIMIC-IV mortality prediction with num_workers=4. + +This benchmark measures: +1. Time to load base dataset +2. Time to process task with num_workers=4 +3. Total processing time +4. Cache sizes +5. Peak memory usage (with optional memory limit) +""" + +import time +import os +import threading +from pathlib import Path +import psutil +from pyhealth.datasets import MIMIC4Dataset +from pyhealth.tasks import MortalityPredictionStageNetMIMIC4 + +try: + import resource + + HAS_RESOURCE = True +except ImportError: + HAS_RESOURCE = False + + +PEAK_MEM_USAGE = 0 +SELF_PROC = psutil.Process(os.getpid()) + + +def track_mem(): + """Background thread to track peak memory usage.""" + global PEAK_MEM_USAGE + while True: + m = SELF_PROC.memory_info().rss + if m > PEAK_MEM_USAGE: + PEAK_MEM_USAGE = m + time.sleep(0.1) + + +def set_memory_limit(max_memory_gb): + """Set hard memory limit for the process. + + Args: + max_memory_gb: Maximum memory in GB (e.g., 8 for 8GB) + + Note: + If limit is exceeded, the process will raise MemoryError. + Only works on Unix-like systems (Linux, macOS). + """ + if not HAS_RESOURCE: + print( + "Warning: resource module not available (Windows?). " + "Memory limit not enforced." 
+ ) + return + + max_memory_bytes = int(max_memory_gb * 1024**3) + try: + resource.setrlimit(resource.RLIMIT_AS, (max_memory_bytes, max_memory_bytes)) + print(f"✓ Memory limit set to {max_memory_gb} GB") + except Exception as e: + print(f"Warning: Failed to set memory limit: {e}") + + +def get_directory_size(path): + """Calculate total size of a directory in bytes.""" + total = 0 + try: + for entry in Path(path).rglob("*"): + if entry.is_file(): + total += entry.stat().st_size + except Exception as e: + print(f"Error calculating size for {path}: {e}") + return total + + +def format_size(size_bytes): + """Format bytes to human-readable size.""" + for unit in ["B", "KB", "MB", "GB", "TB"]: + if size_bytes < 1024.0: + return f"{size_bytes:.2f} {unit}" + size_bytes /= 1024.0 + return f"{size_bytes:.2f} PB" + + +def main(): + """Main benchmark function.""" + # Configuration + dev = True # Set to True for development/testing + enable_memory_limit = True # Set to True to enforce memory limit + max_memory_gb = 32 # Memory limit in GB (if enable_memory_limit=True) + + # Apply memory limit if enabled + if enable_memory_limit: + set_memory_limit(max_memory_gb) + + # Start memory tracking thread + mem_thread = threading.Thread(target=track_mem, daemon=True) + mem_thread.start() + + print("=" * 80) + print(f"BENCHMARK: num_workers=4, dev={dev}") + if enable_memory_limit: + print(f"Memory Limit: {max_memory_gb} GB (ENFORCED)") + else: + print("Memory Limit: None (unrestricted)") + print("=" * 80) + + # Define cache directories based on dev mode + cache_root = "./benchmark_cache/workers_4" + if dev: + cache_root += "_dev" + + # Track total time + total_start = time.time() # STEP 1: Load MIMIC-IV base dataset + print("\n[1/2] Loading MIMIC-IV base dataset...") + dataset_start = time.time() + + base_dataset = MIMIC4Dataset( + ehr_root="/srv/local/data/physionet.org/files/mimiciv/2.2/", + ehr_tables=[ + "patients", + "admissions", + "diagnoses_icd", + "procedures_icd", + "labevents", + ], + dev=dev, + cache_dir=f"{cache_root}/base_dataset", + ) + + dataset_time = time.time() - dataset_start + print(f"✓ Dataset loaded in {dataset_time:.2f} seconds") + + # STEP 2: Apply StageNet mortality prediction task with num_workers=4 + print("\n[2/2] Applying mortality prediction task (num_workers=4)...") + task_start = time.time() + + sample_dataset = base_dataset.set_task( + MortalityPredictionStageNetMIMIC4(), + num_workers=4, + cache_dir=f"{cache_root}/task_samples", + ) + + task_time = time.time() - task_start + print(f"✓ Task processing completed in {task_time:.2f} seconds") + + # Measure cache sizes + print("\n[3/3] Measuring cache sizes...") + base_cache_dir = f"{cache_root}/base_dataset" + task_cache_dir = f"{cache_root}/task_samples" + + base_cache_size = get_directory_size(base_cache_dir) + task_cache_size = get_directory_size(task_cache_dir) + total_cache_size = base_cache_size + task_cache_size + + print(f"✓ Base dataset cache: {format_size(base_cache_size)}") + print(f"✓ Task samples cache: {format_size(task_cache_size)}") + print(f"✓ Total cache size: {format_size(total_cache_size)}") + + # Total time and peak memory + total_time = time.time() - total_start + peak_mem = PEAK_MEM_USAGE + + # Print summary + print("\n" + "=" * 80) + print("BENCHMARK RESULTS") + print("=" * 80) + print("Configuration:") + print(" - num_workers: 4") + print(f" - dev mode: {dev}") + print(f" - Total samples: {len(sample_dataset)}") + print("\nTiming:") + print(f" - Dataset loading: {dataset_time:.2f}s") + print(f" - Task 
processing: {task_time:.2f}s") + print(f" - Total time: {total_time:.2f}s") + print("\nCache Sizes:") + print(f" - Base dataset cache: {format_size(base_cache_size)}") + print(f" - Task samples cache: {format_size(task_cache_size)}") + print(f" - Total cache: {format_size(total_cache_size)}") + print("\nMemory:") + print(f" - Peak memory usage: {format_size(peak_mem)}") + print("=" * 80) + + +if __name__ == "__main__": + main() diff --git a/examples/benchmark_perf/memtest.py b/examples/benchmark_perf/memtest.py new file mode 100644 index 000000000..f8b92c71a --- /dev/null +++ b/examples/benchmark_perf/memtest.py @@ -0,0 +1,105 @@ +""" +Example of using StageNet for mortality prediction on MIMIC-IV. + +This example demonstrates: +1. Loading MIMIC-IV data +2. Applying the MortalityPredictionStageNetMIMIC4 task +3. Creating a SampleDataset with StageNet processors +4. Training a StageNet model +""" + +# %% +from pyhealth.datasets import ( + MIMIC4Dataset, + get_dataloader, + split_by_patient, +) +from pyhealth.models import StageNet +from pyhealth.tasks import MortalityPredictionStageNetMIMIC4 +from pyhealth.trainer import Trainer +import torch + +# %% STEP 1: Load MIMIC-IV base dataset +base_dataset = MIMIC4Dataset( + ehr_root="/srv/local/data/physionet.org/files/mimiciv/2.2/", + ehr_tables=[ + "patients", + "admissions", + "diagnoses_icd", + "procedures_icd", + "labevents", + ], + dev=True, +) + +# %% # STEP 2: Apply StageNet mortality prediction task +sample_dataset = base_dataset.set_task( + MortalityPredictionStageNetMIMIC4(), + num_workers=4, +) + +print(f"Total samples: {len(sample_dataset)}") +print(f"Input schema: {sample_dataset.input_schema}") +print(f"Output schema: {sample_dataset.output_schema}") + +# %% Inspect a sample +sample = next(iter(sample_dataset)) +print("\nSample structure:") +print(f" Patient ID: {sample['patient_id']}") +print(f"ICD Codes: {sample['icd_codes']}") +print(f" Labs shape: {len(sample['labs'][0])} timesteps") +print(f" Mortality: {sample['mortality']}") + +# %% STEP 3: Split dataset +train_dataset, val_dataset, test_dataset = split_by_patient( + sample_dataset, [0.8, 0.1, 0.1] +) + +# Create dataloaders +train_loader = get_dataloader(train_dataset, batch_size=256, shuffle=True) +val_loader = get_dataloader(val_dataset, batch_size=256, shuffle=False) +test_loader = get_dataloader(test_dataset, batch_size=256, shuffle=False) + +# %% STEP 4: Initialize StageNet model +model = StageNet( + dataset=sample_dataset, + embedding_dim=128, + chunk_size=128, + levels=3, + dropout=0.3, +) + +num_params = sum(p.numel() for p in model.parameters()) +print(f"\nModel initialized with {num_params} parameters") + +# %% STEP 5: Train the model +trainer = Trainer( + model=model, + device="cpu", # or "cpu" + metrics=["pr_auc", "roc_auc", "accuracy", "f1"], +) + +trainer.train( + train_dataloader=train_loader, + val_dataloader=val_loader, + epochs=5, + monitor="roc_auc", + optimizer_params={"lr": 1e-5}, +) + +# %% STEP 6: Evaluate on test set +results = trainer.evaluate(test_loader) +print("\nTest Results:") +for metric, value in results.items(): + print(f" {metric}: {value:.4f}") + +# %% STEP 7: Inspect model predictions +sample_batch = next(iter(test_loader)) +with torch.no_grad(): + output = model(**sample_batch) + +print("\nSample predictions:") +print(f" Predicted probabilities: {output['y_prob'][:5]}") +print(f" True labels: {output['y_true'][:5]}") + +# %% From fe744dce67f82009a560f422232494dad0497deb Mon Sep 17 00:00:00 2001 From: Yongda Fan Date: Wed, 10 Dec 2025 
14:48:36 -0500 Subject: [PATCH 56/82] add deps --- pyproject.toml | 2 ++ 1 file changed, 2 insertions(+) diff --git a/pyproject.toml b/pyproject.toml index bf45906a1..aa96d0e09 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -39,8 +39,10 @@ dependencies = [ "pandas~=2.3.1", "pandarallel~=1.6.5", "pydantic~=2.11.7", + "dask[complete]~=2025.11.0", "litdata~=0.2.58", "pyarrow~=22.0.0", + "narwhals~=2.13.0", ] license = "BSD-3-Clause" license-files = ["LICENSE.md"] From f6bc340bb76aacfa98f17d80d531ada9e51f4150 Mon Sep 17 00:00:00 2001 From: Yongda Fan Date: Wed, 10 Dec 2025 15:03:19 -0500 Subject: [PATCH 57/82] add _scan_csv_tsv_gz --- pyhealth/datasets/base_dataset.py | 109 +++++++++++++++++++++++------- 1 file changed, 85 insertions(+), 24 deletions(-) diff --git a/pyhealth/datasets/base_dataset.py b/pyhealth/datasets/base_dataset.py index 1f171d46f..51ddacb08 100644 --- a/pyhealth/datasets/base_dataset.py +++ b/pyhealth/datasets/base_dataset.py @@ -14,10 +14,12 @@ import litdata from litdata.streaming.item_loader import ParquetLoader import pyarrow as pa +import pyarrow.csv as pv import pyarrow.parquet as pq import polars as pl import requests from tqdm import tqdm +import dask.dataframe as dd from ..data import Patient from ..tasks import BaseTask @@ -63,32 +65,24 @@ def path_exists(path: str) -> bool: else: return Path(path).exists() - -def scan_csv_gz_or_csv_tsv(path: str) -> pl.LazyFrame: +def _csv_tsv_gz_path(path: str) -> str: """ - Scan a CSV.gz, CSV, TSV.gz, or TSV file and returns a LazyFrame. - It will fall back to the other extension if not found. + Get the path to the file, trying the original path first, then the alternative path + by switching between .csv.gz, .csv, .tsv.gz, and .tsv extensions. Args: - path (str): URL or local path to a .csv, .csv.gz, .tsv, or .tsv.gz file + path (str): Original file path. Returns: - pl.LazyFrame: The LazyFrame for the CSV.gz, CSV, TSV.gz, or TSV file. - """ - - def scan_file(file_path: str) -> pl.LazyFrame: - separator = "\t" if ".tsv" in file_path else "," - return pl.scan_csv( - file_path, - separator=separator, - infer_schema=False, - low_memory=True, - ) + str: The file path that exists. + Raises: + FileNotFoundError: If neither the original nor the alternative path exists. + ValueError: If the path does not have an expected extension. + """ if path_exists(path): - return scan_file(path) + return path - # Try the alternative extension if path.endswith(".csv.gz"): alt_path = path[:-3] # Remove .gz -> try .csv elif path.endswith(".csv"): @@ -98,15 +92,13 @@ def scan_file(file_path: str) -> pl.LazyFrame: elif path.endswith(".tsv"): alt_path = f"{path}.gz" # Add .gz -> try .tsv.gz else: - raise FileNotFoundError(f"Path does not have expected extension: {path}") - + raise ValueError(f"Path does not have expected extension: {path}") + if path_exists(alt_path): - logger.info(f"Original path does not exist. 
Using alternative: {alt_path}") - return scan_file(alt_path) - + return alt_path + raise FileNotFoundError(f"Neither path exists: {path} or {alt_path}") - def _uncollate(x: list[Any]) -> Any: return x[0] if isinstance(x, list) and len(x) == 1 else x @@ -264,6 +256,75 @@ def cache_dir(self) -> Path: cache_dir.mkdir(parents=True, exist_ok=True) self._cache_dir = cache_dir return Path(self._cache_dir) + + @property + def temp_dir(self) -> Path: + return self.cache_dir / "temp" + + def _scan_csv_tsv_gz(self, table_name: str, source_path: str | None = None) -> dd.DataFrame: + """Scans a CSV/TSV file (possibly gzipped) and returns a Dask DataFrame. + + If the cached Parquet file does not exist, it converts the source CSV/TSV file + to Parquet and saves it to the cache. + + Args: + table_name (str): The name of the table. + source_path (str | None): The source CSV/TSV file path. If None, assumes the + Parquet file already exists in the cache. + + Returns: + dd.DataFrame: The Dask DataFrame loaded from the cached Parquet file. + + Raises: + FileNotFoundError: If source_path is None and the cached Parquet file does not exist; + or if neither the original nor the alternative path of source_path exists. + ValueError: If the path does not have an expected extension. + """ + # Ensure the tables cache directory exists + (self.temp_dir / "tables").mkdir(parents=True, exist_ok=True) + ret_path = str(self.temp_dir / "tables" / f"{table_name}.parquet") + + if not path_exists(ret_path): + if source_path is None: + raise FileNotFoundError( + f"Table {table_name} not found in cache and no source_path provided." + ) + + source_path = _csv_tsv_gz_path(source_path) + + # Determine delimiter based on file extension + delimiter = ( + "\t" + if source_path.endswith(".tsv") or source_path.endswith(".tsv.gz") + else "," + ) + + # Always infer schema as string to avoid incorrect type inference + schema_reader = pv.open_csv( + source_path, + read_options=pv.ReadOptions(block_size=1 << 26), # 64 MB + parse_options=pv.ParseOptions(delimiter=delimiter), + ) + schema = pa.schema( + [pa.field(name, pa.string()) for name in schema_reader.schema.names] + ) + + # Convert CSV/TSV to Parquet + csv_reader = pv.open_csv( + source_path, + read_options=pv.ReadOptions(block_size=1 << 26), # 64 MB + parse_options=pv.ParseOptions(delimiter=delimiter), + convert_options=pv.ConvertOptions(column_types=schema), + ) + with pq.ParquetWriter(ret_path, csv_reader.schema) as writer: + for batch in csv_reader: + writer.write_batch(batch) + + return dd.read_parquet( + ret_path, + split_row_groups=True, # type: ignore + blocksize="64MB", + ) @property def global_event_df(self) -> pl.LazyFrame: From 4e236664b762556678364b04cd4d8bec51273851 Mon Sep 17 00:00:00 2001 From: Yongda Fan Date: Wed, 10 Dec 2025 15:15:41 -0500 Subject: [PATCH 58/82] convert load_table to dask --- pyhealth/datasets/base_dataset.py | 78 +++++++++++++++++-------------- 1 file changed, 43 insertions(+), 35 deletions(-) diff --git a/pyhealth/datasets/base_dataset.py b/pyhealth/datasets/base_dataset.py index 51ddacb08..58fa6ebd4 100644 --- a/pyhealth/datasets/base_dataset.py +++ b/pyhealth/datasets/base_dataset.py @@ -4,7 +4,9 @@ from abc import ABC from concurrent.futures import ThreadPoolExecutor, as_completed from pathlib import Path -from typing import Dict, Iterator, List, Optional, Any +from typing import Dict, Iterator, List, Optional, Any, Callable +import functools +import operator from urllib.parse import urlparse, urlunparse import json import uuid @@ -16,6 +18,7 @@ 
import pyarrow as pa import pyarrow.csv as pv import pyarrow.parquet as pq +import pandas as pd import polars as pl import requests from tqdm import tqdm @@ -28,6 +31,8 @@ from .sample_dataset import SampleDataset, SampleBuilder from .utils import _convert_for_cache, _restore_from_cache +# Set logging level for distributed to ERROR to reduce verbosity +logging.getLogger("distributed").setLevel(logging.ERROR) logger = logging.getLogger(__name__) @@ -371,14 +376,14 @@ def load_data(self) -> pl.LazyFrame: frames = [self.load_table(table.lower()) for table in self.tables] return pl.concat(frames, how="diagonal") - def load_table(self, table_name: str) -> pl.LazyFrame: + def load_table(self, table_name: str) -> dd.DataFrame: """Loads a table and processes joins if specified. Args: table_name (str): The name of the table to load. Returns: - pl.LazyFrame: The processed lazy frame for the table. + dd.DataFrame: The processed Dask dataframe for the table. Raises: ValueError: If the table is not found in the config. @@ -394,12 +399,13 @@ def load_table(self, table_name: str) -> pl.LazyFrame: csv_path = clean_path(csv_path) logger.info(f"Scanning table: {table_name} from {csv_path}") - df = scan_csv_gz_or_csv_tsv(csv_path) + df = self._scan_csv_tsv_gz(table_name, csv_path) # Convert column names to lowercase before calling preprocess_func - df = df.rename(lambda col: col.lower()) + df = df.rename(columns=str.lower) # Check if there is a preprocessing function for this table + preprocess_func: Optional[Callable[[dd.DataFrame], dd.DataFrame]] preprocess_func = getattr(self, f"preprocess_{table_name}", None) if preprocess_func is not None: logger.info( @@ -412,13 +418,13 @@ def load_table(self, table_name: str) -> pl.LazyFrame: other_csv_path = f"{self.root}/{join_cfg.file_path}" other_csv_path = clean_path(other_csv_path) logger.info(f"Joining with table: {other_csv_path}") - join_df = scan_csv_gz_or_csv_tsv(other_csv_path) - join_df = join_df.rename(lambda col: col.lower()) + join_df = self._scan_csv_tsv_gz(table_name, other_csv_path) + join_df = join_df.rename(columns=str.lower) join_key = join_cfg.on columns = join_cfg.columns how = join_cfg.how - df = df.join(join_df.select([join_key] + columns), on=join_key, how=how) # type: ignore + df: dd.DataFrame = df.merge(join_df[[join_key] + columns], on=join_key, how=how) patient_id_col = table_cfg.patient_id timestamp_col = table_cfg.timestamp @@ -429,38 +435,40 @@ def load_table(self, table_name: str) -> pl.LazyFrame: if timestamp_col: if isinstance(timestamp_col, list): # Concatenate all timestamp parts in order with no separator - combined_timestamp = pl.concat_str( - [pl.col(col) for col in timestamp_col] - ).str.strptime(pl.Datetime, format=timestamp_format, strict=True) - timestamp_expr = combined_timestamp + timestamp_series: dd.Series = functools.reduce( + operator.add, (df[col].astype(str) for col in timestamp_col) + ) else: # Single timestamp column - timestamp_expr = pl.col(timestamp_col).str.strptime( - pl.Datetime, format=timestamp_format, strict=True - ) + timestamp_series: dd.Series = df[timestamp_col].astype(str) + timestamp_series: dd.Series = dd.to_datetime( + timestamp_series, + format=timestamp_format, + errors="raise", + ) + df: dd.DataFrame = df.assign( + timestamp=timestamp_series.astype("datetime64[ms]") + ) else: - timestamp_expr = pl.lit(None, dtype=pl.Datetime) + df: dd.DataFrame = df.assign(timestamp=pd.NaT) + # If patient_id_col is None, use row index as patient_id - patient_id_expr = ( - pl.col(patient_id_col).cast(pl.Utf8) - 
if patient_id_col - else pl.int_range(0, pl.count()).cast(pl.Utf8) - ) - base_columns = [ - patient_id_expr.alias("patient_id"), - pl.lit(table_name).cast(pl.Utf8).alias("event_type"), - # ms should be sufficient for most cases - timestamp_expr.cast(pl.Datetime(time_unit="ms")).alias("timestamp"), - ] - - # Flatten attribute columns with event_type prefix - attribute_columns = [ - pl.col(attr.lower()).alias(f"{table_name}/{attr}") - for attr in attribute_cols - ] - - event_frame = df.select(base_columns + attribute_columns) + if patient_id_col: + df: dd.DataFrame = df.assign(patient_id=df[patient_id_col].astype(str)) + else: + df: dd.DataFrame = df.reset_index(drop=True) + df: dd.DataFrame = df.assign(patient_id=df.index.astype(str)) + + + df: dd.DataFrame = df.assign(event_type=table_name) + + rename_attr = {attr.lower(): f"{table_name}/{attr}" for attr in attribute_cols} + df: dd.DataFrame = df.rename(columns=rename_attr) + + attr_cols = [rename_attr[attr.lower()] for attr in attribute_cols] + final_cols = ["patient_id", "event_type", "timestamp"] + attr_cols + event_frame = df[final_cols] return event_frame From ecec90da0b6d260cd8ea9f78fcd9d719148b64a3 Mon Sep 17 00:00:00 2001 From: Yongda Fan Date: Wed, 10 Dec 2025 15:16:14 -0500 Subject: [PATCH 59/82] convert load_data to dask --- pyhealth/datasets/base_dataset.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/pyhealth/datasets/base_dataset.py b/pyhealth/datasets/base_dataset.py index 58fa6ebd4..9c4ff7086 100644 --- a/pyhealth/datasets/base_dataset.py +++ b/pyhealth/datasets/base_dataset.py @@ -367,14 +367,14 @@ def global_event_df(self) -> pl.LazyFrame: "patient_id" ) # Guarantee sorted read, see sink_parquet above - def load_data(self) -> pl.LazyFrame: + def load_data(self) -> dd.DataFrame: """Loads data from the specified tables. Returns: - pl.LazyFrame: A concatenated lazy frame of all tables. + dd.DataFrame: A concatenated lazy frame of all tables. """ frames = [self.load_table(table.lower()) for table in self.tables] - return pl.concat(frames, how="diagonal") + return dd.concat(frames, axis=0, join="outer") def load_table(self, table_name: str) -> dd.DataFrame: """Loads a table and processes joins if specified. From ff0b5023ab06634ec0491521a652232ef0a543f4 Mon Sep 17 00:00:00 2001 From: Yongda Fan Date: Wed, 10 Dec 2025 15:52:45 -0500 Subject: [PATCH 60/82] convert global_event_df to dask --- pyhealth/datasets/base_dataset.py | 48 +++++++++++++++++-------------- 1 file changed, 27 insertions(+), 21 deletions(-) diff --git a/pyhealth/datasets/base_dataset.py b/pyhealth/datasets/base_dataset.py index 9c4ff7086..ee4b44d94 100644 --- a/pyhealth/datasets/base_dataset.py +++ b/pyhealth/datasets/base_dataset.py @@ -23,13 +23,13 @@ import requests from tqdm import tqdm import dask.dataframe as dd +from dask.distributed import Client, LocalCluster, progress from ..data import Patient from ..tasks import BaseTask from ..processors.base_processor import FeatureProcessor from .configs import load_yaml_config from .sample_dataset import SampleDataset, SampleBuilder -from .utils import _convert_for_cache, _restore_from_cache # Set logging level for distributed to ERROR to reduce verbosity logging.getLogger("distributed").setLevel(logging.ERROR) @@ -339,26 +339,32 @@ def global_event_df(self) -> pl.LazyFrame: Path: The path to the cached event dataframe. 
""" if self._global_event_df is None: - path = self.cache_dir / "global_event_df.parquet" - if not path.exists(): - df = self.load_data() - if self.dev: - logger.info("Dev mode enabled: limiting to 1000 patients") - limited_patients = ( - df.select(pl.col("patient_id").shuffle(seed=0)) - .unique() - .limit(1000) - ) - df = df.join(limited_patients, on="patient_id", how="inner") - - logger.info(f"Caching event dataframe to {path}...") - df.sort("patient_id").sink_parquet( - path, - compression="lz4", # use lz4 compression for faster read/write - row_group_size=8_192, - maintain_order=True, # Important for sorted writes - ) - self._global_event_df = path + ret_path = self.cache_dir / "global_event_df.parquet" + if not ret_path.exists(): + with LocalCluster( + n_workers=4, # TODO: make this configurable + threads_per_worker=1, + memory_limit="8GB", # TODO: make this configurable + ) as cluster: + with Client(cluster) as client: + df: dd.DataFrame = self.load_data() + if self.dev: + logger.info("Dev mode enabled: limiting to 1000 patients") + patients = ( + df["patient_id"].unique().head(1000).tolist() + ) + filter = df["patient_id"].isin(patients) + df = df[filter] + + logger.info(f"Caching event dataframe to {ret_path}...") + collection = df.sort_values("patient_id").to_parquet( + ret_path, + write_index=False, + compute=False, + ) + handle = client.compute(collection) + progress(handle) + self._global_event_df = ret_path return pl.scan_parquet( self._global_event_df, From 383026808de2c1c324507739e0dfda81b46b903e Mon Sep 17 00:00:00 2001 From: Yongda Fan Date: Wed, 10 Dec 2025 15:58:49 -0500 Subject: [PATCH 61/82] Fix base dataset test --- tests/core/test_base_dataset.py | 60 +++++++++++++++++++-------------- 1 file changed, 34 insertions(+), 26 deletions(-) diff --git a/tests/core/test_base_dataset.py b/tests/core/test_base_dataset.py index 2efd49d46..7e69e7c84 100644 --- a/tests/core/test_base_dataset.py +++ b/tests/core/test_base_dataset.py @@ -3,30 +3,35 @@ from unittest.mock import patch import polars as pl +import pandas as pd +import dask.dataframe as dd from pyhealth.datasets.base_dataset import BaseDataset -class InMemoryDataset(BaseDataset): +class MockDataset(BaseDataset): """Dataset that bypasses file loading for tests.""" - def __init__(self, data: pl.DataFrame, **kwargs): + def __init__(self, data: dd.DataFrame, **kwargs): self._data = data super().__init__(**kwargs) - def load_data(self) -> pl.LazyFrame: - return self._data.lazy() + def load_data(self) -> dd.DataFrame: + return self._data class TestBaseDataset(unittest.TestCase): - def _single_row_data(self) -> pl.DataFrame: - return pl.DataFrame( - { - "patient_id": ["1"], - "event_type": ["test"], - "timestamp": [None], - "test/value": [0], - } + def _single_row_data(self) -> dd.DataFrame: + return dd.from_pandas( + pd.DataFrame( + { + "patient_id": ["1"], + "event_type": ["test"], + "timestamp": [None], + "test/value": [0], + } + ), + npartitions=1, ) def test_cache_dir_varies_with_core_identifiers(self): @@ -41,31 +46,31 @@ def test_cache_dir_varies_with_core_identifiers(self): return_value=cache_root, ): datasets = [ - InMemoryDataset( + MockDataset( data=self._single_row_data(), root="/data/root_a", **base_kwargs, ), - InMemoryDataset( + MockDataset( data=self._single_row_data(), root="/data/root_b", # different root **base_kwargs, ), - InMemoryDataset( + MockDataset( data=self._single_row_data(), root="/data/root_a", tables=["table_b"], # different tables dataset_name="CacheDataset", dev=False, ), - InMemoryDataset( + 
MockDataset( data=self._single_row_data(), root="/data/root_a", tables=["table_a"], dataset_name="OtherDataset", # different dataset name dev=False, ), - InMemoryDataset( + MockDataset( data=self._single_row_data(), root="/data/root_a", tables=["table_a"], @@ -82,21 +87,24 @@ def test_cache_dir_varies_with_core_identifiers(self): ) def test_event_df_cache_is_physically_sorted(self): - unsorted_data = pl.DataFrame( - { - "patient_id": ["3", "1", "2", "1"], - "event_type": ["test"] * 4, - "timestamp": [None] * 4, - "test/value": [10, 20, 30, 40], - } + unsorted_data = dd.from_pandas( + pd.DataFrame( + { + "patient_id": ["3", "1", "2", "1"], + "event_type": ["test"] * 4, + "timestamp": [None] * 4, + "test/value": [10, 20, 30, 40], + } + ), + npartitions=1, ) - original_order = unsorted_data["patient_id"].to_list() + original_order = unsorted_data["patient_id"].compute().tolist() with tempfile.TemporaryDirectory() as cache_root, patch( "pyhealth.datasets.base_dataset.platformdirs.user_cache_dir", return_value=cache_root, ): - dataset = InMemoryDataset( + dataset = MockDataset( data=unsorted_data, root="/data/root_sort", tables=["table_a"], From 73aa3a8a97e9ee42b103d3427a6564b4f26eb988 Mon Sep 17 00:00:00 2001 From: Yongda Fan Date: Wed, 10 Dec 2025 17:13:55 -0500 Subject: [PATCH 62/82] Fix bug --- pyhealth/datasets/base_dataset.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/pyhealth/datasets/base_dataset.py b/pyhealth/datasets/base_dataset.py index ee4b44d94..bce3d94f2 100644 --- a/pyhealth/datasets/base_dataset.py +++ b/pyhealth/datasets/base_dataset.py @@ -420,11 +420,11 @@ def load_table(self, table_name: str) -> dd.DataFrame: df = preprocess_func(df) # Handle joins - for join_cfg in table_cfg.join: + for i, join_cfg in enumerate(table_cfg.join): other_csv_path = f"{self.root}/{join_cfg.file_path}" other_csv_path = clean_path(other_csv_path) logger.info(f"Joining with table: {other_csv_path}") - join_df = self._scan_csv_tsv_gz(table_name, other_csv_path) + join_df = self._scan_csv_tsv_gz(f"{table_name}_join_{i}", other_csv_path) join_df = join_df.rename(columns=str.lower) join_key = join_cfg.on columns = join_cfg.columns From 668f4dec5799bd0c1e505f7106552865002287d4 Mon Sep 17 00:00:00 2001 From: Yongda Fan Date: Wed, 10 Dec 2025 18:12:20 -0500 Subject: [PATCH 63/82] Fixup --- pyhealth/datasets/base_dataset.py | 8 +++----- 1 file changed, 3 insertions(+), 5 deletions(-) diff --git a/pyhealth/datasets/base_dataset.py b/pyhealth/datasets/base_dataset.py index bce3d94f2..10a56ba4c 100644 --- a/pyhealth/datasets/base_dataset.py +++ b/pyhealth/datasets/base_dataset.py @@ -342,9 +342,9 @@ def global_event_df(self) -> pl.LazyFrame: ret_path = self.cache_dir / "global_event_df.parquet" if not ret_path.exists(): with LocalCluster( - n_workers=4, # TODO: make this configurable + n_workers=4, threads_per_worker=1, - memory_limit="8GB", # TODO: make this configurable + processes=False, ) as cluster: with Client(cluster) as client: df: dd.DataFrame = self.load_data() @@ -452,9 +452,7 @@ def load_table(self, table_name: str) -> dd.DataFrame: format=timestamp_format, errors="raise", ) - df: dd.DataFrame = df.assign( - timestamp=timestamp_series.astype("datetime64[ms]") - ) + df: dd.DataFrame = df.assign(timestamp=timestamp_series) else: df: dd.DataFrame = df.assign(timestamp=pd.NaT) From ebcabc5c1375fdb167e0e020cf94d9f9ae37e985 Mon Sep 17 00:00:00 2001 From: Yongda Fan Date: Wed, 10 Dec 2025 18:42:22 -0500 Subject: [PATCH 64/82] Fixup --- pyhealth/datasets/base_dataset.py | 4 
+--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/pyhealth/datasets/base_dataset.py b/pyhealth/datasets/base_dataset.py index 10a56ba4c..63317842a 100644 --- a/pyhealth/datasets/base_dataset.py +++ b/pyhealth/datasets/base_dataset.py @@ -369,9 +369,7 @@ def global_event_df(self) -> pl.LazyFrame: return pl.scan_parquet( self._global_event_df, low_memory=True, - ).set_sorted( - "patient_id" - ) # Guarantee sorted read, see sink_parquet above + ) def load_data(self) -> dd.DataFrame: """Loads data from the specified tables. From d8e23836e714e00222a9cf5b4663da9dcac2c89d Mon Sep 17 00:00:00 2001 From: Yongda Fan Date: Wed, 10 Dec 2025 18:47:03 -0500 Subject: [PATCH 65/82] fixup --- pyhealth/data/data.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/pyhealth/data/data.py b/pyhealth/data/data.py index 2a6d3a45c..58c99e56c 100644 --- a/pyhealth/data/data.py +++ b/pyhealth/data/data.py @@ -152,9 +152,9 @@ def _filter_by_time_range_fast(self, df: pl.DataFrame, start: Optional[datetime] start_idx = 0 end_idx = len(ts_col) if start is not None: - start_idx = np.searchsorted(ts_col, start, side="left") + start_idx = np.searchsorted(ts_col, np.datetime64(start), side="left") if end is not None: - end_idx = np.searchsorted(ts_col, end, side="right") + end_idx = np.searchsorted(ts_col, np.datetime64(end), side="right") return df.slice(start_idx, end_idx - start_idx) def _filter_by_event_type_regular(self, df: pl.DataFrame, event_type: Optional[str]) -> pl.DataFrame: From aa64e99c5f0c7e51e4ca4ab9284e32b3b36a5839 Mon Sep 17 00:00:00 2001 From: Yongda Fan Date: Wed, 10 Dec 2025 19:18:30 -0500 Subject: [PATCH 66/82] main guard --- examples/memtest.py | 189 ++++++++++++++++++++++---------------------- 1 file changed, 95 insertions(+), 94 deletions(-) diff --git a/examples/memtest.py b/examples/memtest.py index 6eb55ede0..415992737 100644 --- a/examples/memtest.py +++ b/examples/memtest.py @@ -9,97 +9,98 @@ """ # %% -from pyhealth.datasets import ( - MIMIC4Dataset, - get_dataloader, - split_by_patient, -) -from pyhealth.models import StageNet -from pyhealth.tasks import MortalityPredictionStageNetMIMIC4 -from pyhealth.trainer import Trainer -import torch - -# %% STEP 1: Load MIMIC-IV base dataset -base_dataset = MIMIC4Dataset( - ehr_root="/home/logic/physionet.org/files/mimiciv/3.1/", - ehr_tables=[ - "patients", - "admissions", - "diagnoses_icd", - "procedures_icd", - "labevents", - ], - dev=True, -) - -# %% # STEP 2: Apply StageNet mortality prediction task -sample_dataset = base_dataset.set_task( - MortalityPredictionStageNetMIMIC4(), - num_workers=4, -) - -print(f"Total samples: {len(sample_dataset)}") -print(f"Input schema: {sample_dataset.input_schema}") -print(f"Output schema: {sample_dataset.output_schema}") - -# %% Inspect a sample -sample = next(iter(sample_dataset)) -print("\nSample structure:") -print(f" Patient ID: {sample['patient_id']}") -print(f"ICD Codes: {sample['icd_codes']}") -print(f" Labs shape: {len(sample['labs'][0])} timesteps") -print(f" Mortality: {sample['mortality']}") - -# %% STEP 3: Split dataset -train_dataset, val_dataset, test_dataset = split_by_patient( - sample_dataset, [0.8, 0.1, 0.1] -) - -# Create dataloaders -train_loader = get_dataloader(train_dataset, batch_size=256, shuffle=True) -val_loader = get_dataloader(val_dataset, batch_size=256, shuffle=False) -test_loader = get_dataloader(test_dataset, batch_size=256, shuffle=False) - -# %% STEP 4: Initialize StageNet model -model = StageNet( - dataset=sample_dataset, - embedding_dim=128, 
- chunk_size=128, - levels=3, - dropout=0.3, -) - -num_params = sum(p.numel() for p in model.parameters()) -print(f"\nModel initialized with {num_params} parameters") - -# %% STEP 5: Train the model -trainer = Trainer( - model=model, - device="cpu", # or "cpu" - metrics=["pr_auc", "roc_auc", "accuracy", "f1"], -) - -trainer.train( - train_dataloader=train_loader, - val_dataloader=val_loader, - epochs=5, - monitor="roc_auc", - optimizer_params={"lr": 1e-5}, -) - -# %% STEP 6: Evaluate on test set -results = trainer.evaluate(test_loader) -print("\nTest Results:") -for metric, value in results.items(): - print(f" {metric}: {value:.4f}") - -# %% STEP 7: Inspect model predictions -sample_batch = next(iter(test_loader)) -with torch.no_grad(): - output = model(**sample_batch) - -print("\nSample predictions:") -print(f" Predicted probabilities: {output['y_prob'][:5]}") -print(f" True labels: {output['y_true'][:5]}") - -# %% +if __name__ == "__main__": + from pyhealth.datasets import ( + MIMIC4Dataset, + get_dataloader, + split_by_patient, + ) + from pyhealth.models import StageNet + from pyhealth.tasks import MortalityPredictionStageNetMIMIC4 + from pyhealth.trainer import Trainer + import torch + + # %% STEP 1: Load MIMIC-IV base dataset + base_dataset = MIMIC4Dataset( + ehr_root="/home/logic/physionet.org/files/mimiciv/3.1/", + ehr_tables=[ + "patients", + "admissions", + "diagnoses_icd", + "procedures_icd", + "labevents", + ], + dev=True, + ) + + # %% # STEP 2: Apply StageNet mortality prediction task + sample_dataset = base_dataset.set_task( + MortalityPredictionStageNetMIMIC4(), + num_workers=4, + ) + + print(f"Total samples: {len(sample_dataset)}") + print(f"Input schema: {sample_dataset.input_schema}") + print(f"Output schema: {sample_dataset.output_schema}") + + # %% Inspect a sample + sample = next(iter(sample_dataset)) + print("\nSample structure:") + print(f" Patient ID: {sample['patient_id']}") + print(f"ICD Codes: {sample['icd_codes']}") + print(f" Labs shape: {len(sample['labs'][0])} timesteps") + print(f" Mortality: {sample['mortality']}") + + # %% STEP 3: Split dataset + train_dataset, val_dataset, test_dataset = split_by_patient( + sample_dataset, [0.8, 0.1, 0.1] + ) + + # Create dataloaders + train_loader = get_dataloader(train_dataset, batch_size=256, shuffle=True) + val_loader = get_dataloader(val_dataset, batch_size=256, shuffle=False) + test_loader = get_dataloader(test_dataset, batch_size=256, shuffle=False) + + # %% STEP 4: Initialize StageNet model + model = StageNet( + dataset=sample_dataset, + embedding_dim=128, + chunk_size=128, + levels=3, + dropout=0.3, + ) + + num_params = sum(p.numel() for p in model.parameters()) + print(f"\nModel initialized with {num_params} parameters") + + # %% STEP 5: Train the model + trainer = Trainer( + model=model, + device="cpu", # or "cpu" + metrics=["pr_auc", "roc_auc", "accuracy", "f1"], + ) + + trainer.train( + train_dataloader=train_loader, + val_dataloader=val_loader, + epochs=5, + monitor="roc_auc", + optimizer_params={"lr": 1e-5}, + ) + + # %% STEP 6: Evaluate on test set + results = trainer.evaluate(test_loader) + print("\nTest Results:") + for metric, value in results.items(): + print(f" {metric}: {value:.4f}") + + # %% STEP 7: Inspect model predictions + sample_batch = next(iter(test_loader)) + with torch.no_grad(): + output = model(**sample_batch) + + print("\nSample predictions:") + print(f" Predicted probabilities: {output['y_prob'][:5]}") + print(f" True labels: {output['y_true'][:5]}") + + # %% From 
763f358e6d083e716978d3eeaad65ebcaf815765 Mon Sep 17 00:00:00 2001
From: Yongda Fan
Date: Wed, 10 Dec 2025 19:18:48 -0500
Subject: [PATCH 67/82] fix incorrect null value handling

---
 pyhealth/tasks/mortality_prediction_stagenet_mimic4.py | 3 ++-
 1 file changed, 2 insertions(+), 1 deletion(-)

diff --git a/pyhealth/tasks/mortality_prediction_stagenet_mimic4.py b/pyhealth/tasks/mortality_prediction_stagenet_mimic4.py
index 91e1f94cd..e7061a6e6 100644
--- a/pyhealth/tasks/mortality_prediction_stagenet_mimic4.py
+++ b/pyhealth/tasks/mortality_prediction_stagenet_mimic4.py
@@ -226,7 +226,8 @@ def __call__(self, patient: Any) -> List[Dict[str, Any]]:
         labevents_df = labevents_df.select(
             pl.col("timestamp"),
             pl.col("labevents/itemid"),
-            pl.col("labevents/valuenum").cast(pl.Float64),
+            # There are potential empty strings in valuenum, which should be cast to nulls
+            pl.col("labevents/valuenum").replace("", None).cast(pl.Float64),
         )
 
         # Group by timestamp and aggregate into 10D vectors

From bd92ba7494909618ae957531bdc14eefa3226982 Mon Sep 17 00:00:00 2001
From: Yongda Fan
Date: Wed, 10 Dec 2025 19:41:10 -0500
Subject: [PATCH 68/82] change back to ms to mimic old pyhealth behaviour

---
 pyhealth/data/data.py             | 4 ++--
 pyhealth/datasets/base_dataset.py | 2 +-
 2 files changed, 3 insertions(+), 3 deletions(-)

diff --git a/pyhealth/data/data.py b/pyhealth/data/data.py
index 58c99e56c..14b1b526c 100644
--- a/pyhealth/data/data.py
+++ b/pyhealth/data/data.py
@@ -152,9 +152,9 @@ def _filter_by_time_range_fast(self, df: pl.DataFrame, start: Optional[datetime]
         start_idx = 0
         end_idx = len(ts_col)
         if start is not None:
-            start_idx = np.searchsorted(ts_col, np.datetime64(start), side="left")
+            start_idx = np.searchsorted(ts_col, np.datetime64(start, "ms"), side="left")
         if end is not None:
-            end_idx = np.searchsorted(ts_col, np.datetime64(end), side="right")
+            end_idx = np.searchsorted(ts_col, np.datetime64(end, "ms"), side="right")
         return df.slice(start_idx, end_idx - start_idx)
 
     def _filter_by_event_type_regular(self, df: pl.DataFrame, event_type: Optional[str]) -> pl.DataFrame:
diff --git a/pyhealth/datasets/base_dataset.py b/pyhealth/datasets/base_dataset.py
index 63317842a..8e48113a8 100644
--- a/pyhealth/datasets/base_dataset.py
+++ b/pyhealth/datasets/base_dataset.py
@@ -450,7 +450,7 @@ def load_table(self, table_name: str) -> dd.DataFrame:
                 format=timestamp_format,
                 errors="raise",
             )
-            df: dd.DataFrame = df.assign(timestamp=timestamp_series)
+            df: dd.DataFrame = df.assign(timestamp=timestamp_series.astype("datetime64[ms]"))
         else:
             df: dd.DataFrame = df.assign(timestamp=pd.NaT)
 

From 466d95e6be4f553659dee361e74d6c46d722ab26 Mon Sep 17 00:00:00 2001
From: Yongda Fan
Date: Wed, 10 Dec 2025 20:11:40 -0500
Subject: [PATCH 69/82] add TODO

---
 pyhealth/datasets/base_dataset.py | 3 +++
 1 file changed, 3 insertions(+)

diff --git a/pyhealth/datasets/base_dataset.py b/pyhealth/datasets/base_dataset.py
index 8e48113a8..61f3401c8 100644
--- a/pyhealth/datasets/base_dataset.py
+++ b/pyhealth/datasets/base_dataset.py
@@ -341,6 +341,9 @@ def global_event_df(self) -> pl.LazyFrame:
         if self._global_event_df is None:
            ret_path = self.cache_dir / "global_event_df.parquet"
            if not ret_path.exists():
+                # TODO: auto select processes=True/False based on if it's in jupyter notebook
+                # The processes=True will crash in jupyter notebook.
+                # TODO: make the n_workers configurable
                 with LocalCluster(
                     n_workers=4,
                     threads_per_worker=1,

From 87b171f7733077a481b14ff99aa3bdeb18a79de7 Mon Sep 17 00:00:00 2001
From: Yongda Fan
Date: Fri, 12 Dec 2025 20:46:27 -0500
Subject: [PATCH 70/82] main guard check

---
 pyhealth/datasets/base_dataset.py | 17 ++++++++++++++++-
 1 file changed, 16 insertions(+), 1 deletion(-)

diff --git a/pyhealth/datasets/base_dataset.py b/pyhealth/datasets/base_dataset.py
index 61f3401c8..f9b5d8f96 100644
--- a/pyhealth/datasets/base_dataset.py
+++ b/pyhealth/datasets/base_dataset.py
@@ -12,6 +12,7 @@
 import uuid
 import platformdirs
 import tempfile
+import multiprocessing
 
 import litdata
 from litdata.streaming.item_loader import ParquetLoader
@@ -338,6 +339,13 @@ def global_event_df(self) -> pl.LazyFrame:
         Returns:
             Path: The path to the cached event dataframe.
         """
+        if not multiprocessing.current_process().name == "MainProcess":
+            logger.warning(
+                "global_event_df property accessed from a non-main process. This may lead to unexpected behavior.\n"
+                + "Consider using a __name__ == '__main__' guard when using multiprocessing."
+            )
+            return None  # type: ignore
+
         if self._global_event_df is None:
             ret_path = self.cache_dir / "global_event_df.parquet"
             if not ret_path.exists():
@@ -347,7 +355,7 @@
                 with LocalCluster(
                     n_workers=4,
                     threads_per_worker=1,
-                    processes=False,
+                    # processes=False,
                 ) as cluster:
                     with Client(cluster) as client:
                         df: dd.DataFrame = self.load_data()
@@ -589,6 +597,13 @@ def set_task(
         Raises:
             AssertionError: If no default task is found and task is None.
         """
+        if not multiprocessing.current_process().name == "MainProcess":
+            logger.warning(
+                "set_task method accessed from a non-main process. This may lead to unexpected behavior.\n"
+                + "Consider using a __name__ == '__main__' guard when using multiprocessing."
+            )
+            return None  # type: ignore
+
         if task is None:
             assert self.default_task is not None, "No default tasks found"
             task = self.default_task

From e77e08b58b9f59e7c9dd767e6024ff7cd532b82e Mon Sep 17 00:00:00 2001
From: Yongda Fan
Date: Fri, 12 Dec 2025 20:49:56 -0500
Subject: [PATCH 71/82] fix nullable issue?
--- pyhealth/datasets/base_dataset.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/pyhealth/datasets/base_dataset.py b/pyhealth/datasets/base_dataset.py index f9b5d8f96..27c38a3db 100644 --- a/pyhealth/datasets/base_dataset.py +++ b/pyhealth/datasets/base_dataset.py @@ -326,11 +326,12 @@ def _scan_csv_tsv_gz(self, table_name: str, source_path: str | None = None) -> d for batch in csv_reader: writer.write_batch(batch) - return dd.read_parquet( + df: dd.DataFrame = dd.read_parquet( ret_path, split_row_groups=True, # type: ignore blocksize="64MB", ) + return df.replace("", pd.NA) # Replace empty strings with NaN @property def global_event_df(self) -> pl.LazyFrame: From 49b3b64a5cc4be129f2592ca4924a42e9ea1ef76 Mon Sep 17 00:00:00 2001 From: Yongda Fan Date: Fri, 12 Dec 2025 20:53:23 -0500 Subject: [PATCH 72/82] revert change --- pyhealth/tasks/mortality_prediction_stagenet_mimic4.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/pyhealth/tasks/mortality_prediction_stagenet_mimic4.py b/pyhealth/tasks/mortality_prediction_stagenet_mimic4.py index e7061a6e6..91e1f94cd 100644 --- a/pyhealth/tasks/mortality_prediction_stagenet_mimic4.py +++ b/pyhealth/tasks/mortality_prediction_stagenet_mimic4.py @@ -226,8 +226,7 @@ def __call__(self, patient: Any) -> List[Dict[str, Any]]: labevents_df = labevents_df.select( pl.col("timestamp"), pl.col("labevents/itemid"), - # There are potential empty strings in valuenum, which should be cast to nulls - pl.col("labevents/valuenum").replace("", None).cast(pl.Float64), + pl.col("labevents/valuenum").cast(pl.Float64), ) # Group by timestamp and aggregate into 10D vectors From 77483072aee18fce1071e5c034159c77d95c7034 Mon Sep 17 00:00:00 2001 From: Yongda Fan Date: Sun, 14 Dec 2025 14:38:00 -0500 Subject: [PATCH 73/82] update API layers --- pyhealth/datasets/base_dataset.py | 5 +++-- pyhealth/datasets/bmd_hs.py | 2 +- pyhealth/datasets/mimic3.py | 2 +- pyhealth/datasets/omop.py | 8 ++++---- 4 files changed, 9 insertions(+), 8 deletions(-) diff --git a/pyhealth/datasets/base_dataset.py b/pyhealth/datasets/base_dataset.py index 27c38a3db..157e220e0 100644 --- a/pyhealth/datasets/base_dataset.py +++ b/pyhealth/datasets/base_dataset.py @@ -25,6 +25,7 @@ from tqdm import tqdm import dask.dataframe as dd from dask.distributed import Client, LocalCluster, progress +import narwhals as nw from ..data import Patient from ..tasks import BaseTask @@ -421,13 +422,13 @@ def load_table(self, table_name: str) -> dd.DataFrame: df = df.rename(columns=str.lower) # Check if there is a preprocessing function for this table - preprocess_func: Optional[Callable[[dd.DataFrame], dd.DataFrame]] + preprocess_func: Optional[Callable[[nw.LazyFrame], nw.LazyFrame]] preprocess_func = getattr(self, f"preprocess_{table_name}", None) if preprocess_func is not None: logger.info( f"Preprocessing table: {table_name} with {preprocess_func.__name__}" ) - df = preprocess_func(df) + df = preprocess_func(nw.from_native(df)).to_native() # type: ignore # Handle joins for i, join_cfg in enumerate(table_cfg.join): diff --git a/pyhealth/datasets/bmd_hs.py b/pyhealth/datasets/bmd_hs.py index 183c23766..30c706132 100644 --- a/pyhealth/datasets/bmd_hs.py +++ b/pyhealth/datasets/bmd_hs.py @@ -1,7 +1,7 @@ import logging from pathlib import Path from typing import Optional -import polars as pl +import narwhals as pl from pyhealth.tasks.base_task import BaseTask from pyhealth.tasks.bmd_hs_disease_classification import BMDHSDiseaseClassification diff --git 
a/pyhealth/datasets/mimic3.py b/pyhealth/datasets/mimic3.py index eac1f33b8..22ca79d5c 100644 --- a/pyhealth/datasets/mimic3.py +++ b/pyhealth/datasets/mimic3.py @@ -3,7 +3,7 @@ from pathlib import Path from typing import List, Optional -import polars as pl +import narwhals as pl from .base_dataset import BaseDataset diff --git a/pyhealth/datasets/omop.py b/pyhealth/datasets/omop.py index 675210c14..ef688a035 100644 --- a/pyhealth/datasets/omop.py +++ b/pyhealth/datasets/omop.py @@ -2,7 +2,7 @@ from pathlib import Path from typing import List, Optional -import polars as pl +import narwhals as pl from .base_dataset import BaseDataset @@ -155,15 +155,15 @@ def preprocess_person(self, df: pl.LazyFrame) -> pl.LazyFrame: df = df.with_columns( [ ( - pl.col("year_of_birth").cast(pl.Utf8) + pl.col("year_of_birth").cast(pl.String) + "-" + pl.when(pl.col("month_of_birth").is_null()) .then(pl.lit("01")) - .otherwise(pl.col("month_of_birth").cast(pl.Utf8).str.zfill(2)) + .otherwise(pl.col("month_of_birth").cast(pl.String).str.zfill(2)) + "-" + pl.when(pl.col("day_of_birth").is_null()) .then(pl.lit("01")) - .otherwise(pl.col("day_of_birth").cast(pl.Utf8).str.zfill(2)) + .otherwise(pl.col("day_of_birth").cast(pl.String).str.zfill(2)) + " 00:00:00" ).alias("birth_datetime") ] From f6a582b0f726179b74746735bdedf0aff3041e06 Mon Sep 17 00:00:00 2001 From: Yongda Fan Date: Sun, 14 Dec 2025 14:52:05 -0500 Subject: [PATCH 74/82] Fix cache test --- tests/core/test_caching.py | 34 ++++++++++++++++++++-------------- 1 file changed, 20 insertions(+), 14 deletions(-) diff --git a/tests/core/test_caching.py b/tests/core/test_caching.py index 2ae9bfe3e..5bdab5640 100644 --- a/tests/core/test_caching.py +++ b/tests/core/test_caching.py @@ -4,6 +4,7 @@ from pathlib import Path from unittest.mock import patch import polars as pl +import dask.dataframe as dd import torch from tests.base import BaseTestCase @@ -51,20 +52,25 @@ def __init__(self, cache_dir: str | Path | None = None): dev=False, ) - def load_data(self) -> pl.LazyFrame: - return pl.LazyFrame( - { - "patient_id": ["1", "2", "1", "2"], - "event_type": ["test", "test", "test", "test"], - "timestamp": [None, None, None, None], - "test/test_attribute": [ - "pat_1_attr_1", - "pat_2_attr_1", - "pat_1_attr_2", - "pat_2_attr_2", - ], - "test/test_label": [0, 1, 1, 0], - } + def load_data(self) -> dd.DataFrame: + import pandas as pd + + return dd.from_pandas( + pd.DataFrame( + { + "patient_id": ["1", "2", "1", "2"], + "event_type": ["test", "test", "test", "test"], + "timestamp": [None, None, None, None], + "test/test_attribute": [ + "pat_1_attr_1", + "pat_2_attr_1", + "pat_1_attr_2", + "pat_2_attr_2", + ], + "test/test_label": [0, 1, 1, 0], + } + ), + npartitions=1, ) From 7c7e5e43ec677daf8fe3b721d20445699fe804cf Mon Sep 17 00:00:00 2001 From: Yongda Fan Date: Sun, 14 Dec 2025 15:26:38 -0500 Subject: [PATCH 75/82] support remote url --- pyhealth/datasets/base_dataset.py | 11 +++++++++++ 1 file changed, 11 insertions(+) diff --git a/pyhealth/datasets/base_dataset.py b/pyhealth/datasets/base_dataset.py index 157e220e0..d4f08a7e8 100644 --- a/pyhealth/datasets/base_dataset.py +++ b/pyhealth/datasets/base_dataset.py @@ -8,6 +8,7 @@ import functools import operator from urllib.parse import urlparse, urlunparse +from urllib.request import urlretrieve import json import uuid import platformdirs @@ -299,6 +300,16 @@ def _scan_csv_tsv_gz(self, table_name: str, source_path: str | None = None) -> d source_path = _csv_tsv_gz_path(source_path) + if is_url(source_path): + 
local_filename = os.path.basename(source_path) + download_dir = self.temp_dir / "downloads" + download_dir.mkdir(parents=True, exist_ok=True) + local_path = download_dir / local_filename + if not local_path.exists(): + logger.info(f"Downloading {source_path} to {local_path}") + urlretrieve(source_path, local_path) + source_path = str(local_path) + # Determine delimiter based on file extension delimiter = ( "\t" From cfa87be3bd9ab91a80c6572b309b597b684dbd67 Mon Sep 17 00:00:00 2001 From: Yongda Fan Date: Sun, 14 Dec 2025 15:26:43 -0500 Subject: [PATCH 76/82] fix test --- tests/core/test_processor_schemas.py | 1 + 1 file changed, 1 insertion(+) diff --git a/tests/core/test_processor_schemas.py b/tests/core/test_processor_schemas.py index 0651297e2..a46820884 100644 --- a/tests/core/test_processor_schemas.py +++ b/tests/core/test_processor_schemas.py @@ -45,6 +45,7 @@ def pre_filter(self, df: pl.LazyFrame) -> pl.LazyFrame: df.filter(pl.col("event_type") == "noteevents") .select("patient_id") .unique() + .collect() .to_series() ) ) From 264a02655525478ddad892db3dd45b4d9047b0d2 Mon Sep 17 00:00:00 2001 From: Yongda Fan Date: Sun, 14 Dec 2025 15:31:02 -0500 Subject: [PATCH 77/82] fix incorrect type in nw --- pyhealth/datasets/omop.py | 16 ++++++++++------ 1 file changed, 10 insertions(+), 6 deletions(-) diff --git a/pyhealth/datasets/omop.py b/pyhealth/datasets/omop.py index ef688a035..bb9b0f1ed 100644 --- a/pyhealth/datasets/omop.py +++ b/pyhealth/datasets/omop.py @@ -157,13 +157,17 @@ def preprocess_person(self, df: pl.LazyFrame) -> pl.LazyFrame: ( pl.col("year_of_birth").cast(pl.String) + "-" - + pl.when(pl.col("month_of_birth").is_null()) - .then(pl.lit("01")) - .otherwise(pl.col("month_of_birth").cast(pl.String).str.zfill(2)) + + ( + pl.when(pl.col("month_of_birth").is_null()) + .then(pl.lit("01")) + .otherwise(pl.col("month_of_birth").cast(pl.String).str.zfill(2)) + ).cast(pl.String) + "-" - + pl.when(pl.col("day_of_birth").is_null()) - .then(pl.lit("01")) - .otherwise(pl.col("day_of_birth").cast(pl.String).str.zfill(2)) + + ( + pl.when(pl.col("day_of_birth").is_null()) + .then(pl.lit("01")) + .otherwise(pl.col("day_of_birth").cast(pl.String).str.zfill(2)) + ).cast(pl.String) + " 00:00:00" ).alias("birth_datetime") ] From 045d2ad9f72019517a7d4aa4934c578320f55088 Mon Sep 17 00:00:00 2001 From: Yongda Fan Date: Mon, 15 Dec 2025 12:31:02 -0500 Subject: [PATCH 78/82] less worker --- pyhealth/datasets/base_dataset.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/pyhealth/datasets/base_dataset.py b/pyhealth/datasets/base_dataset.py index d4f08a7e8..f6ba72a85 100644 --- a/pyhealth/datasets/base_dataset.py +++ b/pyhealth/datasets/base_dataset.py @@ -366,9 +366,9 @@ def global_event_df(self) -> pl.LazyFrame: # The processes=True will crash in jupyter notebook. 
                 # TODO: make the n_workers configurable
                 with LocalCluster(
-                    n_workers=4,
+                    n_workers=1,
                     threads_per_worker=1,
-                    # processes=False,
+                    processes=False,
                 ) as cluster:
                     with Client(cluster) as client:
                         df: dd.DataFrame = self.load_data()

From 5b000b18ef5df85a756ff47784c3dfccb8704e58 Mon Sep 17 00:00:00 2001
From: Yongda Fan
Date: Mon, 15 Dec 2025 12:35:21 -0500
Subject: [PATCH 79/82] non-dev for memtest

---
 examples/memtest.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/examples/memtest.py b/examples/memtest.py
index 415992737..056b21f8f 100644
--- a/examples/memtest.py
+++ b/examples/memtest.py
@@ -30,7 +30,7 @@
             "procedures_icd",
             "labevents",
         ],
-        dev=True,
+        dev=False,
    )
 
     # %% # STEP 2: Apply StageNet mortality prediction task

From 44c185f9f8fbc25499058b1463693dc3579b3909 Mon Sep 17 00:00:00 2001
From: Yongda Fan
Date: Tue, 16 Dec 2025 00:20:41 -0500
Subject: [PATCH 80/82] fix incorrect behaviour on notebook & make sure dask exceptions are thrown at dask

---
 pyhealth/datasets/base_dataset.py | 1 +
 1 file changed, 1 insertion(+)

diff --git a/pyhealth/datasets/base_dataset.py b/pyhealth/datasets/base_dataset.py
index f6ba72a85..27c1e46d5 100644
--- a/pyhealth/datasets/base_dataset.py
+++ b/pyhealth/datasets/base_dataset.py
@@ -388,6 +388,7 @@ def global_event_df(self) -> pl.LazyFrame:
                         )
                         handle = client.compute(collection)
                         progress(handle)
+                        handle.result()  # type: ignore
             self._global_event_df = ret_path
 
         return pl.scan_parquet(

From 54db875dcb65077765a2e8683a1e6e0bd87201b4 Mon Sep 17 00:00:00 2001
From: John Wu
Date: Wed, 17 Dec 2025 13:45:57 -0600
Subject: [PATCH 81/82] update on installation details and recommended settings for use with PyHealth due to the new change in backend

---
 docs/install.rst | 40 ++++++++++++++++++++++++++++++++++++++++
 1 file changed, 40 insertions(+)

diff --git a/docs/install.rst b/docs/install.rst
index 1498280bb..6cabd6ddb 100644
--- a/docs/install.rst
+++ b/docs/install.rst
@@ -1,6 +1,10 @@
 Installation
 ============
 
+**Python Version Recommendation**
+
+We recommend using **Python 3.12** for optimal parallel processing and memory management performance. While PyHealth supports Python 3.8+, Python 3.12 provides significant improvements in these areas.
+
 **Recommended Installation (Alpha Version)**
 
 We recommend installing the latest alpha version from PyPi, which offers significant improvements in performance:
@@ -67,4 +71,40 @@ For example, if you use NVIDIA RTX A6000 as your GPU for training, you should in
 conda install pytorch torchvision torchaudio cudatoolkit=11.3 -c pytorch.
 
 
+----
+
+**Platform-Specific Notes**
+
+**Windows Subsystem for Linux (WSL)**
+
+When using PyHealth on WSL, you **must** disable swap memory due to a bug in how Dask interacts with WSL's memory management. This prevents performance issues and potential crashes.
+
+**Method 1: Using WSL Settings App (Windows 11)**
+
+1. Open the WSL Settings app in Windows
+2. Navigate to Memory and Processor settings
+3. Set Swap size to 0 MB
+4. Apply changes and restart WSL
+
+**Method 2: Manual Configuration**
+
+1. Open PowerShell as Administrator
+2. Create or edit `%UserProfile%\.wslconfig` file
+3. Add the following configuration:
+
+.. code-block:: ini
+
+   [wsl2]
+   swap=0
+
+4. Restart WSL by running in PowerShell: ``wsl --shutdown``
+
+**Other Platforms**
+
+PyHealth should work without additional configuration on:
+
+- Linux (native)
+- macOS
+- Windows (with proper Python installation)
+
 ----
\ No newline at end of file

From 9e994bd7dc2954e79bbd02da96c57a5bb8d33de0 Mon Sep 17 00:00:00 2001
From: John Wu
Date: Wed, 17 Dec 2025 15:58:22 -0600
Subject: [PATCH 82/82] additional clarifications here

---
 docs/install.rst | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/docs/install.rst b/docs/install.rst
index 6cabd6ddb..087db7892 100644
--- a/docs/install.rst
+++ b/docs/install.rst
@@ -77,7 +77,7 @@ For example, if you use NVIDIA RTX A6000 as your GPU for training, you should in
 
 **Windows Subsystem for Linux (WSL)**
 
-When using PyHealth on WSL, you **must** disable swap memory due to a bug in how Dask interacts with WSL's memory management. This prevents performance issues and potential crashes.
+When using PyHealth on WSL, you **may need to** disable swap memory due to a bug in how Dask interacts with WSL's memory management when memory runs out. This prevents performance issues and potential crashes.
 
 **Method 1: Using WSL Settings App (Windows 11)**