Mem 9 - Memory Efficiency Update #739
Conversation
Pull request overview
This PR introduces a significant architectural refactoring to improve memory efficiency by migrating from in-memory dataset processing to streaming/disk-backed processing using litdata, dask, and narwhals. The changes modernize the data pipeline while maintaining the public API surface through a compatibility layer.
Key Changes:
- Replaced the in-memory `SampleDataset` with a streaming `litdata.StreamingDataset` architecture
- Migrated data processing from `polars` LazyFrames to `dask` DataFrames for better memory management
- Introduced `SampleBuilder` for processor fitting and a `create_sample_dataset` helper function (see the usage sketch below)
- Updated all processors to accept `Iterable` instead of `List` for samples
- Refactored `BaseDataset` to use caching and lazy evaluation throughout
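As a rough illustration of the new entry point, here is a hypothetical usage sketch: `create_sample_dataset` and the package path come from the change list above, but the keyword arguments and schema format are assumptions, not the confirmed signature.

```python
# Hypothetical usage sketch; create_sample_dataset's exact signature is not
# shown in this PR, so the keyword arguments below are assumptions.
from pyhealth.datasets import create_sample_dataset

samples = [
    {"patient_id": "p0", "conditions": ["c1", "c2"], "label": 1},
    {"patient_id": "p1", "conditions": ["c3"], "label": 0},
]

# Old style: SampleDataset(samples, ...) kept everything in memory.
# New style: a helper builds a streaming/disk-backed dataset instead.
dataset = create_sample_dataset(
    samples,
    input_schema={"conditions": "sequence"},  # assumed schema format
    output_schema={"label": "binary"},        # assumed schema format
)

for sample in dataset:  # samples are streamed rather than indexed from RAM
    print(sample)
```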
Reviewed changes
Copilot reviewed 66 out of 66 changed files in this pull request and generated 19 comments.
| File | Description |
|---|---|
| `pyhealth/datasets/sample_dataset.py` | Complete rewrite: introduces `SampleBuilder`, streaming `SampleDataset`, `InMemorySampleDataset`, and `create_sample_dataset` helper |
| `pyhealth/datasets/base_dataset.py` | Major refactoring: added streaming parquet writer, dask-based data loading, improved caching with content-addressed directories |
| `pyhealth/datasets/utils.py` | Updated `get_dataloader` to work with streaming datasets by calling `set_shuffle` (sketch below) |
| `pyhealth/datasets/splitter.py` | Updated all split functions to use `dataset.subset()` instead of `torch.utils.data.Subset` |
| `pyhealth/processors/*.py` | Changed `fit` method signatures from `List[Dict]` to `Iterable[Dict[str, Any]]` |
| `pyhealth/models/*.py` | Updated documentation examples to use `create_sample_dataset` instead of `SampleDataset` |
| `tests/core/*.py` | Updated all tests to use the `create_sample_dataset` API and adapted to streaming dataset behavior |
| `pyproject.toml` | Added new dependencies: dask, litdata, pyarrow, narwhals; updated polars version |
| `examples/memtest.py` | Refactored into a structured example with a `__main__` guard |
| `examples/benchmark_perf/*.py` | Added new benchmarking scripts for performance testing with different worker configurations |
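The `get_dataloader` row above implies roughly the following shape; a minimal sketch, assuming the streaming dataset exposes `set_shuffle` as described (the actual implementation in `pyhealth/datasets/utils.py` is not reproduced here):

```python
# Sketch only: shuffle is forwarded to the streaming dataset via set_shuffle,
# since iterable-style datasets manage their own sample order and DataLoader's
# shuffle flag does not apply to them.
from torch.utils.data import DataLoader

def get_dataloader(dataset, batch_size: int, shuffle: bool = False) -> DataLoader:
    dataset.set_shuffle(shuffle)
    return DataLoader(dataset, batch_size=batch_size)
```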
```python
def __len__(self) -> int:
    return len(self.samples)

def __iter__(self) -> Iterable[Dict[str, Any]]:  # type: ignore
    """Returns an iterator over all samples in the dataset.

    Returns:
        An iterator yielding processed sample dictionaries.
    """
    if self._shuffle:
        shuffled_data = self._data[:]
        random.shuffle(shuffled_data)
        return iter(shuffled_data)
    else:
        return iter(self._data)
```
Copilot AI commented on Dec 17, 2025
The InMemorySampleDataset.__iter__ method creates a copy of self._data and shuffles it in place when shuffle is enabled. However, the shuffled copy is only used for the current iteration. If the dataloader iterates multiple times (multiple epochs), it will need to call __iter__ again, but since set_shuffle is called only once before dataloader creation (in get_dataloader), subsequent iterations will use the same shuffled order. This doesn't provide per-epoch shuffling like PyTorch's standard DataLoader shuffle behavior.
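For contrast, a map-style PyTorch `DataLoader` with `shuffle=True` draws a fresh permutation at the start of every epoch, which is the baseline behavior the comment refers to:

```python
# Standard per-epoch shuffling with a map-style dataset: each pass over the
# loader re-samples the order.
import torch
from torch.utils.data import DataLoader, TensorDataset

ds = TensorDataset(torch.arange(8))
loader = DataLoader(ds, batch_size=4, shuffle=True)

for epoch in range(2):
    order = [x.item() for batch in loader for x in batch[0]]
    print(epoch, order)  # order typically differs between epochs
```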
…PyHealth due to new changes in the backend
This pull request introduces new benchmarking scripts for MIMIC-IV mortality prediction and refactors example scripts for clarity and maintainability. It also includes some library improvements and minor bug fixes. The main changes are grouped as follows:
1. New Benchmarking Scripts
- Added two scripts, `benchmark_workers_1.py` and `benchmark_workers_4.py`, to the `examples/benchmark_perf` directory. These scripts measure dataset loading time, task processing time, cache sizes, and peak memory usage for MIMIC-IV mortality prediction with different numbers of workers (see the measurement sketch below). They also provide optional memory-limit enforcement and detailed reporting. [1] [2]
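The benchmark scripts themselves are not reproduced on this page; a minimal measurement skeleton of the kind they describe might look like this, with `load_stage` and `task_stage` as hypothetical stand-ins for dataset loading and task processing:

```python
# Hypothetical skeleton: time each stage and record peak memory. Note that
# tracemalloc tracks Python-level allocations; the real scripts may measure
# process RSS instead.
import time
import tracemalloc

def load_stage():
    return list(range(1_000_000))  # stand-in for dataset loading

def task_stage(data):
    return [x * 2 for x in data]   # stand-in for task processing

tracemalloc.start()

t0 = time.perf_counter()
data = load_stage()
load_time = time.perf_counter() - t0

t0 = time.perf_counter()
_ = task_stage(data)
task_time = time.perf_counter() - t0

_, peak = tracemalloc.get_traced_memory()
print(f"load: {load_time:.2f}s  task: {task_time:.2f}s  peak: {peak / 1e6:.1f} MB")
```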
2. Example Script Refactoring and Addition

- Refactored `examples/memtest.py` into a more structured and readable example of using StageNet on MIMIC-IV, covering data loading, task application, dataset splitting, model training, evaluation, and prediction inspection. The script is now wrapped in a `__main__` block (see the skeleton below) and prints detailed output for each step.
- Added `examples/benchmark_perf/memtest.py` with similar functionality to the refactored `memtest.py`, demonstrating the end-to-end StageNet pipeline on MIMIC-IV.
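The `__main__` guard matters for scripts that spin up `DataLoader` workers: on spawn-based platforms, worker processes re-import the script, and unguarded module-level code would re-execute. A skeleton of the assumed structure:

```python
# Skeleton only; the actual refactored memtest.py is not reproduced here.
def main() -> None:
    # load data, apply task, split, train, evaluate, inspect predictions
    print("running the StageNet-on-MIMIC-IV pipeline...")

if __name__ == "__main__":
    # Prevents re-execution when DataLoader workers re-import this module.
    main()
```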
3. Library Improvements and API Changes

- Switched from `polars` to `narwhals` in several dataset modules (`bmd_hs.py`, `mimic3.py`) to improve compatibility or performance (a minimal illustration follows this list). [1] [2]
- Expanded the `SampleDataset` import in `pyhealth/datasets/__init__.py` to also include `SampleBuilder` and `create_sample_dataset`, making these utilities available at the package level.
- Updated the `MIMIC4EHRDataset` constructor to accept a `cache_dir` parameter, allowing cache-directory customization when instantiating the dataset. [1] [2]
- Updated `mimic4.py` to include a type-ignore comment for compatibility with static analysis tools. [1] [2]
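A minimal illustration of the `polars`-to-`narwhals` pattern (the shape is assumed; the actual module changes are not shown on this page):

```python
# narwhals exposes one expression API that dispatches to polars, pandas,
# dask, and other supported backends.
import narwhals as nw

@nw.narwhalify
def adults_only(df):
    # Receives whatever frame type the caller passes and returns the same type.
    return df.filter(nw.col("age") > 18)
```

With `@nw.narwhalify`, the same function accepts a polars, pandas, or dask frame and returns a frame of the same type, which is what makes the swap attractive for backend flexibility.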
4. Bug Fixes

- Fixed `_filter_by_time_range_fast` in `pyhealth/data/data.py` by ensuring that time comparisons use `np.datetime64` for accurate filtering (see the sketch below).
- Updated `covariate_label.py` so that the `cal_dataset` parameter in the `calibrate` method now expects an `IterableDataset` instead of a `Subset`, improving flexibility. [1] [2]
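A sketch of the kind of comparison the `np.datetime64` fix implies (assumed shape; the actual `_filter_by_time_range_fast` body is not shown here):

```python
# Coercing both sides to np.datetime64 keeps the elementwise comparison exact;
# mixed str/datetime comparisons can silently misfilter.
import numpy as np

times = np.array(["2020-01-01T00:00", "2021-06-15T12:00"], dtype="datetime64[s]")
start = np.datetime64("2020-06-01")  # datetime64, not str or datetime.datetime
mask = times >= start
print(times[mask])  # -> ['2021-06-15T12:00:00']
```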