diff --git a/docs/install.rst b/docs/install.rst index 1498280bb..087db7892 100644 --- a/docs/install.rst +++ b/docs/install.rst @@ -1,6 +1,10 @@ Installation ============ +**Python Version Recommendation** + +We recommend using **Python 3.12** for optimal parallel processing and memory management performance. While PyHealth supports Python 3.8+, Python 3.12 provides significant improvements in these areas. + **Recommended Installation (Alpha Version)** We recommend installing the latest alpha version from PyPi, which offers significant improvements in performance: @@ -67,4 +71,40 @@ For example, if you use NVIDIA RTX A6000 as your GPU for training, you should in conda install pytorch torchvision torchaudio cudatoolkit=11.3 -c pytorch. +---- + +**Platform-Specific Notes** + +**Windows Subsystem for Linux (WSL)** + +When using PyHealth on WSL, you **may need to** disable swap memory due to a bug in how Dask interacts with WSL's memory management when memory runs out. This prevents performance issues and potential crashes. + +**Method 1: Using WSL Settings App (Windows 11)** + +1. Open the WSL Settings app in Windows +2. Navigate to Memory and Processor settings +3. Set Swap size to 0 MB +4. Apply changes and restart WSL + +**Method 2: Manual Configuration** + +1. Open PowerShell as Administrator +2. Create or edit `%UserProfile%\.wslconfig` file +3. Add the following configuration: + +.. code-block:: ini + + [wsl2] + swap=0 + +4. Restart WSL by running in PowerShell: ``wsl --shutdown`` + +**Other Platforms** + +PyHealth should work without additional configuration on: + +- Linux (native) +- macOS +- Windows (with proper Python installation) + ---- \ No newline at end of file diff --git a/examples/benchmark_perf/benchmark_pandas.py b/examples/benchmark_perf/benchmark_pandas.py new file mode 100644 index 000000000..889fda14d --- /dev/null +++ b/examples/benchmark_perf/benchmark_pandas.py @@ -0,0 +1,407 @@ +""" +Benchmark script for MIMIC-IV mortality prediction using pandas +(analogous to PyHealth task). + +This benchmark mimics the MortalityPredictionStageNetMIMIC4 task: +1. Creates PATIENT-LEVEL samples (not visit-level) +2. Aggregates all admissions per patient +3. Combines ICD codes (diagnoses + procedures) across all visits +4. Extracts lab events in 10-dimensional vectors per timestamp +5. 
Calculates time intervals between consecutive admissions + +Lab Categories (10 dimensions): +- Sodium, Potassium, Chloride, Bicarbonate, Glucose +- Calcium, Magnesium, Anion Gap, Osmolality, Phosphate +""" + +import time +import os +import threading +from typing import Any, Dict, List, Optional, Tuple + +import pandas as pd +import psutil + + +PEAK_MEM_USAGE = 0 +SELF_PROC = psutil.Process(os.getpid()) + + +def track_mem(): + """Background thread to track peak memory usage.""" + global PEAK_MEM_USAGE + while True: + m = SELF_PROC.memory_info().rss + if m > PEAK_MEM_USAGE: + PEAK_MEM_USAGE = m + time.sleep(0.1) + + +# Lab item organization by category (matches MortalityPredictionStageNetMIMIC4) +LAB_CATEGORIES = { + "Sodium": ["50824", "52455", "50983", "52623"], + "Potassium": ["50822", "52452", "50971", "52610"], + "Chloride": ["50806", "52434", "50902", "52535"], + "Bicarbonate": ["50803", "50804"], + "Glucose": ["50809", "52027", "50931", "52569"], + "Calcium": ["50808", "51624"], + "Magnesium": ["50960"], + "Anion Gap": ["50868", "52500"], + "Osmolality": ["52031", "50964", "51701"], + "Phosphate": ["50970"], +} + +# Ordered list of category names (defines vector dimension order) +LAB_CATEGORY_NAMES = [ + "Sodium", + "Potassium", + "Chloride", + "Bicarbonate", + "Glucose", + "Calcium", + "Magnesium", + "Anion Gap", + "Osmolality", + "Phosphate", +] + +# Flat list of all lab item IDs for filtering +LABITEMS = [item for itemids in LAB_CATEGORIES.values() for item in itemids] + + +def process_patient_mortality( + subject_id: int, + patients_df: pd.DataFrame, + admissions_df: pd.DataFrame, + diagnoses_df: pd.DataFrame, + procedures_df: pd.DataFrame, + labevents_df: pd.DataFrame, +) -> Optional[Dict[str, Any]]: + """Process a single patient for mortality prediction task. + + Creates ONE patient-level sample by aggregating all admissions, + ICD codes, and lab events. + + Args: + subject_id: Patient ID + patients_df: Patient demographics + admissions_df: Admission records + diagnoses_df: Diagnosis ICD codes + procedures_df: Procedure ICD codes + labevents_df: Lab event measurements + + Returns: + Dictionary with patient sample or None if patient doesn't qualify + """ + # Get patient demographics + patient_demo = patients_df[patients_df["subject_id"] == subject_id] + if len(patient_demo) == 0: + return None + + # Skip if under 18 + anchor_age = patient_demo.iloc[0]["anchor_age"] + if anchor_age < 18: + return None + + # Get all admissions for this patient, sorted by time + patient_admissions = admissions_df[ + admissions_df["subject_id"] == subject_id + ].sort_values("admittime") + + if len(patient_admissions) < 1: + return None + + # Initialize aggregated data structures + all_icd_codes = [] # Nested list: [[visit1_codes], [visit2_codes], ...] 
+ all_icd_times = [] # Time from previous admission per visit + all_lab_values = [] # List of 10D vectors + all_lab_times = [] # Time from admission start per measurement + + previous_admission_time = None + final_mortality = 0 + + # Process each admission + for _, admission in patient_admissions.iterrows(): + hadm_id = admission["hadm_id"] + admit_time = admission["admittime"] + discharge_time = admission["dischtime"] + + # Skip if invalid timestamps + if pd.isna(discharge_time) or discharge_time < admit_time: + continue + + # Calculate time from previous admission (hours) + if previous_admission_time is None: + time_from_previous = 0.0 + else: + time_from_previous = ( + admit_time - previous_admission_time + ).total_seconds() / 3600.0 + + previous_admission_time = admit_time + + # Update mortality label if this admission had mortality + if int(admission.get("hospital_expire_flag", 0)) == 1: + final_mortality = 1 + + # Get diagnosis codes for this admission + visit_diagnoses = diagnoses_df[diagnoses_df["hadm_id"] == hadm_id] + diagnoses_codes = visit_diagnoses["icd_code"].dropna().tolist() + + # Get procedure codes for this admission + visit_procedures = procedures_df[procedures_df["hadm_id"] == hadm_id] + procedures_codes = visit_procedures["icd_code"].dropna().tolist() + + # Combine diagnoses and procedures + visit_icd_codes = diagnoses_codes + procedures_codes + + if visit_icd_codes: + all_icd_codes.append(visit_icd_codes) + all_icd_times.append(time_from_previous) + + # Get lab events for this admission + admission_labs = labevents_df[ + (labevents_df["subject_id"] == subject_id) + & (labevents_df["hadm_id"] == hadm_id) + ] + + # Filter to relevant lab items + admission_labs = admission_labs[ + admission_labs["itemid"].astype(str).isin(LABITEMS) + ] + + if len(admission_labs) > 0: + # Parse storetime + admission_labs = admission_labs.copy() + admission_labs["storetime"] = pd.to_datetime(admission_labs["storetime"]) + + # Filter to valid times (before discharge) + admission_labs = admission_labs[ + admission_labs["storetime"] <= discharge_time + ] + + if len(admission_labs) > 0: + # Group by timestamp and create 10D vectors + unique_timestamps = sorted(admission_labs["storetime"].unique()) + + for lab_ts in unique_timestamps: + # Get all labs at this timestamp + ts_labs = admission_labs[admission_labs["storetime"] == lab_ts] + + # Create 10-dimensional vector + lab_vector = [] + for category_name in LAB_CATEGORY_NAMES: + category_itemids = LAB_CATEGORIES[category_name] + + # Find first matching value for this category + category_value = None + for itemid in category_itemids: + matching = ts_labs[ts_labs["itemid"].astype(str) == itemid] + if len(matching) > 0: + category_value = matching.iloc[0]["valuenum"] + break + + lab_vector.append(category_value) + + # Calculate time from admission start (hours) + time_from_admission = (lab_ts - admit_time).total_seconds() / 3600.0 + + all_lab_values.append(lab_vector) + all_lab_times.append(time_from_admission) + + # Skip if no lab events (required for this task) + if len(all_lab_values) == 0: + return None + + # Skip if no ICD codes + if len(all_icd_codes) == 0: + return None + + # Create patient-level sample + sample = { + "patient_id": subject_id, + "icd_codes": (all_icd_times, all_icd_codes), + "labs": (all_lab_times, all_lab_values), + "mortality": final_mortality, + "num_visits": len(all_icd_codes), + "num_lab_measurements": len(all_lab_values), + } + + return sample + + +def benchmark_mortality_prediction( + patients_df: pd.DataFrame, + 
admissions_df: pd.DataFrame, + diagnoses_df: pd.DataFrame, + procedures_df: pd.DataFrame, + labevents_df: pd.DataFrame, + n_patients: Optional[int] = None, +) -> Tuple[List[Dict[str, Any]], float]: + """ + Benchmark MIMIC-IV mortality prediction processing. + + Args: + patients_df: Patient demographics + admissions_df: Admissions dataframe + diagnoses_df: Diagnoses dataframe + procedures_df: Procedures dataframe + labevents_df: Lab events dataframe + n_patients: Number of patients to process (None = all patients) + + Returns: + Tuple of (list of samples, processing time in seconds) + """ + print("=" * 80) + print("BENCHMARK: Pandas Mortality Prediction (StageNet format)") + print("=" * 80) + + # Get patients to process + if n_patients is None: + patients_to_process = patients_df["subject_id"].tolist() + print(f"Processing all {len(patients_to_process)} patients...") + else: + patients_to_process = patients_df["subject_id"].head(n_patients).tolist() + print(f"Processing first {len(patients_to_process)} patients...") + + # Parse datetime columns + admissions_df = admissions_df.copy() + admissions_df["admittime"] = pd.to_datetime(admissions_df["admittime"]) + admissions_df["dischtime"] = pd.to_datetime(admissions_df["dischtime"]) + + # Start processing timer + start_time = time.perf_counter() + + samples = [] + processed_patients = 0 + + for subject_id in patients_to_process: + sample = process_patient_mortality( + subject_id, + patients_df, + admissions_df, + diagnoses_df, + procedures_df, + labevents_df, + ) + + if sample is not None: + samples.append(sample) + + processed_patients += 1 + if processed_patients % 100 == 0: + print(f"Processed {processed_patients} patients...") + + # End processing timer + processing_time = time.perf_counter() - start_time + + print("\nCompleted processing:") + print(f" - Total patients processed: {processed_patients}") + print(f" - Valid samples created: {len(samples)}") + print(f" - Processing time: {processing_time:.2f}s") + print("=" * 80) + + return samples, processing_time + + +def load_mimic_data(data_root: str = "/srv/local/data/MIMIC-IV/2.0/hosp"): + """Load MIMIC-IV tables needed for mortality prediction. 
+ + Args: + data_root: Root directory for MIMIC-IV data + + Returns: + Tuple of dataframes: (patients, admissions, diagnoses, + procedures, labevents) + """ + print("Loading MIMIC-IV data tables...") + load_start = time.perf_counter() + + patients_df = pd.read_csv(f"{data_root}/patients.csv") + admissions_df = pd.read_csv(f"{data_root}/admissions.csv") + diagnoses_df = pd.read_csv(f"{data_root}/diagnoses_icd.csv") + procedures_df = pd.read_csv(f"{data_root}/procedures_icd.csv") + labevents_df = pd.read_csv(f"{data_root}/labevents.csv") + + load_time = time.perf_counter() - load_start + print(f"Data loaded in {load_time:.2f}s") + print(f" - Patients: {len(patients_df):,}") + print(f" - Admissions: {len(admissions_df):,}") + print(f" - Diagnoses: {len(diagnoses_df):,}") + print(f" - Procedures: {len(procedures_df):,}") + print(f" - Lab events: {len(labevents_df):,}") + print() + + return ( + patients_df, + admissions_df, + diagnoses_df, + procedures_df, + labevents_df, + ) + + +def main(): + """Main function to run the benchmark.""" + # Start memory tracking thread + mem_thread = threading.Thread(target=track_mem, daemon=True) + mem_thread.start() + + # Load data + data_root = "/srv/local/data/MIMIC-IV/2.0/hosp" + ( + patients_df, + admissions_df, + diagnoses_df, + procedures_df, + labevents_df, + ) = load_mimic_data(data_root) + + # Run benchmark (process all patients) + samples, processing_time = benchmark_mortality_prediction( + patients_df, + admissions_df, + diagnoses_df, + procedures_df, + labevents_df, + n_patients=None, # Change to a number to limit patients + ) + + # Get peak memory + peak_mem = PEAK_MEM_USAGE + + # Helper function for formatting size + def format_size(size_bytes): + for unit in ["B", "KB", "MB", "GB", "TB"]: + if size_bytes < 1024.0: + return f"{size_bytes:.2f} {unit}" + size_bytes /= 1024.0 + return f"{size_bytes:.2f} PB" + + # Save results + results_file = "benchmark_results_pandas.txt" + with open(results_file, "w") as f: + f.write("BENCHMARK RESULTS: Pandas Mortality Prediction\n") + f.write("=" * 80 + "\n") + f.write(f"Total samples: {len(samples)}\n") + f.write(f"Processing time: {processing_time:.2f}s\n") + f.write(f"Peak memory usage: {format_size(peak_mem)}\n") + f.write("=" * 80 + "\n") + + print(f"\n✓ Results saved to {results_file}") + print(f"Peak memory usage: {format_size(peak_mem)}") + + # Optional: Save samples for inspection + if samples: + print("\nExample sample (first patient):") + first_sample = samples[0] + print(f" Patient ID: {first_sample['patient_id']}") + print(f" Mortality: {first_sample['mortality']}") + print(f" Number of visits: {first_sample['num_visits']}") + print( + f" Number of lab measurements: " f"{first_sample['num_lab_measurements']}" + ) + + +if __name__ == "__main__": + main() diff --git a/examples/benchmark_perf/benchmark_workers_1.py b/examples/benchmark_perf/benchmark_workers_1.py new file mode 100644 index 000000000..4b1198861 --- /dev/null +++ b/examples/benchmark_perf/benchmark_workers_1.py @@ -0,0 +1,188 @@ +"""Benchmark script for MIMIC-IV mortality prediction with num_workers=1. + +This benchmark measures: +1. Time to load base dataset +2. Time to process task with num_workers=1 +3. Total processing time +4. Cache sizes +5. 
Peak memory usage (with optional memory limit) +""" + +import time +import os +import threading +from pathlib import Path +import psutil +from pyhealth.datasets import MIMIC4Dataset +from pyhealth.tasks import MortalityPredictionStageNetMIMIC4 + +try: + import resource + + HAS_RESOURCE = True +except ImportError: + HAS_RESOURCE = False + + +PEAK_MEM_USAGE = 0 +SELF_PROC = psutil.Process(os.getpid()) + + +def track_mem(): + """Background thread to track peak memory usage.""" + global PEAK_MEM_USAGE + while True: + m = SELF_PROC.memory_info().rss + if m > PEAK_MEM_USAGE: + PEAK_MEM_USAGE = m + time.sleep(0.1) + + +def set_memory_limit(max_memory_gb): + """Set hard memory limit for the process. + + Args: + max_memory_gb: Maximum memory in GB (e.g., 8 for 8GB) + + Note: + If limit is exceeded, the process will raise MemoryError. + Only works on Unix-like systems (Linux, macOS). + """ + if not HAS_RESOURCE: + print( + "Warning: resource module not available (Windows?). " + "Memory limit not enforced." + ) + return + + max_memory_bytes = int(max_memory_gb * 1024**3) + try: + resource.setrlimit(resource.RLIMIT_AS, (max_memory_bytes, max_memory_bytes)) + print(f"✓ Memory limit set to {max_memory_gb} GB") + except Exception as e: + print(f"Warning: Failed to set memory limit: {e}") + + +def get_directory_size(path): + """Calculate total size of a directory in bytes.""" + total = 0 + try: + for entry in Path(path).rglob("*"): + if entry.is_file(): + total += entry.stat().st_size + except Exception as e: + print(f"Error calculating size for {path}: {e}") + return total + + +def format_size(size_bytes): + """Format bytes to human-readable size.""" + for unit in ["B", "KB", "MB", "GB", "TB"]: + if size_bytes < 1024.0: + return f"{size_bytes:.2f} {unit}" + size_bytes /= 1024.0 + return f"{size_bytes:.2f} PB" + + +def main(): + """Main benchmark function.""" + # Configuration + dev = True # Set to True for development/testing + enable_memory_limit = False # Set to True to enforce memory limit + max_memory_gb = 32 # Memory limit in GB (if enable_memory_limit=True) + + # Apply memory limit if enabled + if enable_memory_limit: + set_memory_limit(max_memory_gb) + + # Start memory tracking thread + mem_thread = threading.Thread(target=track_mem, daemon=True) + mem_thread.start() + + print("=" * 80) + print(f"BENCHMARK: num_workers=1, dev={dev}") + if enable_memory_limit: + print(f"Memory Limit: {max_memory_gb} GB (ENFORCED)") + else: + print("Memory Limit: None (unrestricted)") + print("=" * 80) + + # Define cache directories based on dev mode + cache_root = "./benchmark_cache/workers_1" + if dev: + cache_root += "_dev" + + # Track total time + total_start = time.time() # STEP 1: Load MIMIC-IV base dataset + print("\n[1/2] Loading MIMIC-IV base dataset...") + dataset_start = time.time() + + base_dataset = MIMIC4Dataset( + ehr_root="/srv/local/data/physionet.org/files/mimiciv/2.2/", + ehr_tables=[ + "patients", + "admissions", + "diagnoses_icd", + "procedures_icd", + "labevents", + ], + dev=dev, + cache_dir=f"{cache_root}/base_dataset", + ) + + dataset_time = time.time() - dataset_start + print(f"✓ Dataset loaded in {dataset_time:.2f} seconds") + + # STEP 2: Apply StageNet mortality prediction task with num_workers=1 + print("\n[2/2] Applying mortality prediction task (num_workers=1)...") + task_start = time.time() + + sample_dataset = base_dataset.set_task( + MortalityPredictionStageNetMIMIC4(), + num_workers=1, + cache_dir=f"{cache_root}/task_samples", + ) + + task_time = time.time() - task_start + print(f"✓ 
Task processing completed in {task_time:.2f} seconds") + + # Measure cache sizes + print("\n[3/3] Measuring cache sizes...") + base_cache_dir = f"{cache_root}/base_dataset" + task_cache_dir = f"{cache_root}/task_samples" + + base_cache_size = get_directory_size(base_cache_dir) + task_cache_size = get_directory_size(task_cache_dir) + total_cache_size = base_cache_size + task_cache_size + + print(f"✓ Base dataset cache: {format_size(base_cache_size)}") + print(f"✓ Task samples cache: {format_size(task_cache_size)}") + print(f"✓ Total cache size: {format_size(total_cache_size)}") + + # Total time and peak memory + total_time = time.time() - total_start + peak_mem = PEAK_MEM_USAGE + + # Print summary + print("\n" + "=" * 80) + print("BENCHMARK RESULTS") + print("=" * 80) + print("Configuration:") + print(" - num_workers: 1") + print(f" - dev mode: {dev}") + print(f" - Total samples: {len(sample_dataset)}") + print("\nTiming:") + print(f" - Dataset loading: {dataset_time:.2f}s") + print(f" - Task processing: {task_time:.2f}s") + print(f" - Total time: {total_time:.2f}s") + print("\nCache Sizes:") + print(f" - Base dataset cache: {format_size(base_cache_size)}") + print(f" - Task samples cache: {format_size(task_cache_size)}") + print(f" - Total cache: {format_size(total_cache_size)}") + print("\nMemory:") + print(f" - Peak memory usage: {format_size(peak_mem)}") + print("=" * 80) + + +if __name__ == "__main__": + main() diff --git a/examples/benchmark_perf/benchmark_workers_4.py b/examples/benchmark_perf/benchmark_workers_4.py new file mode 100644 index 000000000..62e478561 --- /dev/null +++ b/examples/benchmark_perf/benchmark_workers_4.py @@ -0,0 +1,188 @@ +"""Benchmark script for MIMIC-IV mortality prediction with num_workers=4. + +This benchmark measures: +1. Time to load base dataset +2. Time to process task with num_workers=4 +3. Total processing time +4. Cache sizes +5. Peak memory usage (with optional memory limit) +""" + +import time +import os +import threading +from pathlib import Path +import psutil +from pyhealth.datasets import MIMIC4Dataset +from pyhealth.tasks import MortalityPredictionStageNetMIMIC4 + +try: + import resource + + HAS_RESOURCE = True +except ImportError: + HAS_RESOURCE = False + + +PEAK_MEM_USAGE = 0 +SELF_PROC = psutil.Process(os.getpid()) + + +def track_mem(): + """Background thread to track peak memory usage.""" + global PEAK_MEM_USAGE + while True: + m = SELF_PROC.memory_info().rss + if m > PEAK_MEM_USAGE: + PEAK_MEM_USAGE = m + time.sleep(0.1) + + +def set_memory_limit(max_memory_gb): + """Set hard memory limit for the process. + + Args: + max_memory_gb: Maximum memory in GB (e.g., 8 for 8GB) + + Note: + If limit is exceeded, the process will raise MemoryError. + Only works on Unix-like systems (Linux, macOS). + """ + if not HAS_RESOURCE: + print( + "Warning: resource module not available (Windows?). " + "Memory limit not enforced." 
+ ) + return + + max_memory_bytes = int(max_memory_gb * 1024**3) + try: + resource.setrlimit(resource.RLIMIT_AS, (max_memory_bytes, max_memory_bytes)) + print(f"✓ Memory limit set to {max_memory_gb} GB") + except Exception as e: + print(f"Warning: Failed to set memory limit: {e}") + + +def get_directory_size(path): + """Calculate total size of a directory in bytes.""" + total = 0 + try: + for entry in Path(path).rglob("*"): + if entry.is_file(): + total += entry.stat().st_size + except Exception as e: + print(f"Error calculating size for {path}: {e}") + return total + + +def format_size(size_bytes): + """Format bytes to human-readable size.""" + for unit in ["B", "KB", "MB", "GB", "TB"]: + if size_bytes < 1024.0: + return f"{size_bytes:.2f} {unit}" + size_bytes /= 1024.0 + return f"{size_bytes:.2f} PB" + + +def main(): + """Main benchmark function.""" + # Configuration + dev = True # Set to True for development/testing + enable_memory_limit = True # Set to True to enforce memory limit + max_memory_gb = 32 # Memory limit in GB (if enable_memory_limit=True) + + # Apply memory limit if enabled + if enable_memory_limit: + set_memory_limit(max_memory_gb) + + # Start memory tracking thread + mem_thread = threading.Thread(target=track_mem, daemon=True) + mem_thread.start() + + print("=" * 80) + print(f"BENCHMARK: num_workers=4, dev={dev}") + if enable_memory_limit: + print(f"Memory Limit: {max_memory_gb} GB (ENFORCED)") + else: + print("Memory Limit: None (unrestricted)") + print("=" * 80) + + # Define cache directories based on dev mode + cache_root = "./benchmark_cache/workers_4" + if dev: + cache_root += "_dev" + + # Track total time + total_start = time.time() # STEP 1: Load MIMIC-IV base dataset + print("\n[1/2] Loading MIMIC-IV base dataset...") + dataset_start = time.time() + + base_dataset = MIMIC4Dataset( + ehr_root="/srv/local/data/physionet.org/files/mimiciv/2.2/", + ehr_tables=[ + "patients", + "admissions", + "diagnoses_icd", + "procedures_icd", + "labevents", + ], + dev=dev, + cache_dir=f"{cache_root}/base_dataset", + ) + + dataset_time = time.time() - dataset_start + print(f"✓ Dataset loaded in {dataset_time:.2f} seconds") + + # STEP 2: Apply StageNet mortality prediction task with num_workers=4 + print("\n[2/2] Applying mortality prediction task (num_workers=4)...") + task_start = time.time() + + sample_dataset = base_dataset.set_task( + MortalityPredictionStageNetMIMIC4(), + num_workers=4, + cache_dir=f"{cache_root}/task_samples", + ) + + task_time = time.time() - task_start + print(f"✓ Task processing completed in {task_time:.2f} seconds") + + # Measure cache sizes + print("\n[3/3] Measuring cache sizes...") + base_cache_dir = f"{cache_root}/base_dataset" + task_cache_dir = f"{cache_root}/task_samples" + + base_cache_size = get_directory_size(base_cache_dir) + task_cache_size = get_directory_size(task_cache_dir) + total_cache_size = base_cache_size + task_cache_size + + print(f"✓ Base dataset cache: {format_size(base_cache_size)}") + print(f"✓ Task samples cache: {format_size(task_cache_size)}") + print(f"✓ Total cache size: {format_size(total_cache_size)}") + + # Total time and peak memory + total_time = time.time() - total_start + peak_mem = PEAK_MEM_USAGE + + # Print summary + print("\n" + "=" * 80) + print("BENCHMARK RESULTS") + print("=" * 80) + print("Configuration:") + print(" - num_workers: 4") + print(f" - dev mode: {dev}") + print(f" - Total samples: {len(sample_dataset)}") + print("\nTiming:") + print(f" - Dataset loading: {dataset_time:.2f}s") + print(f" - Task 
processing: {task_time:.2f}s") + print(f" - Total time: {total_time:.2f}s") + print("\nCache Sizes:") + print(f" - Base dataset cache: {format_size(base_cache_size)}") + print(f" - Task samples cache: {format_size(task_cache_size)}") + print(f" - Total cache: {format_size(total_cache_size)}") + print("\nMemory:") + print(f" - Peak memory usage: {format_size(peak_mem)}") + print("=" * 80) + + +if __name__ == "__main__": + main() diff --git a/examples/benchmark_perf/memtest.py b/examples/benchmark_perf/memtest.py new file mode 100644 index 000000000..f8b92c71a --- /dev/null +++ b/examples/benchmark_perf/memtest.py @@ -0,0 +1,105 @@ +""" +Example of using StageNet for mortality prediction on MIMIC-IV. + +This example demonstrates: +1. Loading MIMIC-IV data +2. Applying the MortalityPredictionStageNetMIMIC4 task +3. Creating a SampleDataset with StageNet processors +4. Training a StageNet model +""" + +# %% +from pyhealth.datasets import ( + MIMIC4Dataset, + get_dataloader, + split_by_patient, +) +from pyhealth.models import StageNet +from pyhealth.tasks import MortalityPredictionStageNetMIMIC4 +from pyhealth.trainer import Trainer +import torch + +# %% STEP 1: Load MIMIC-IV base dataset +base_dataset = MIMIC4Dataset( + ehr_root="/srv/local/data/physionet.org/files/mimiciv/2.2/", + ehr_tables=[ + "patients", + "admissions", + "diagnoses_icd", + "procedures_icd", + "labevents", + ], + dev=True, +) + +# %% # STEP 2: Apply StageNet mortality prediction task +sample_dataset = base_dataset.set_task( + MortalityPredictionStageNetMIMIC4(), + num_workers=4, +) + +print(f"Total samples: {len(sample_dataset)}") +print(f"Input schema: {sample_dataset.input_schema}") +print(f"Output schema: {sample_dataset.output_schema}") + +# %% Inspect a sample +sample = next(iter(sample_dataset)) +print("\nSample structure:") +print(f" Patient ID: {sample['patient_id']}") +print(f"ICD Codes: {sample['icd_codes']}") +print(f" Labs shape: {len(sample['labs'][0])} timesteps") +print(f" Mortality: {sample['mortality']}") + +# %% STEP 3: Split dataset +train_dataset, val_dataset, test_dataset = split_by_patient( + sample_dataset, [0.8, 0.1, 0.1] +) + +# Create dataloaders +train_loader = get_dataloader(train_dataset, batch_size=256, shuffle=True) +val_loader = get_dataloader(val_dataset, batch_size=256, shuffle=False) +test_loader = get_dataloader(test_dataset, batch_size=256, shuffle=False) + +# %% STEP 4: Initialize StageNet model +model = StageNet( + dataset=sample_dataset, + embedding_dim=128, + chunk_size=128, + levels=3, + dropout=0.3, +) + +num_params = sum(p.numel() for p in model.parameters()) +print(f"\nModel initialized with {num_params} parameters") + +# %% STEP 5: Train the model +trainer = Trainer( + model=model, + device="cpu", # or "cpu" + metrics=["pr_auc", "roc_auc", "accuracy", "f1"], +) + +trainer.train( + train_dataloader=train_loader, + val_dataloader=val_loader, + epochs=5, + monitor="roc_auc", + optimizer_params={"lr": 1e-5}, +) + +# %% STEP 6: Evaluate on test set +results = trainer.evaluate(test_loader) +print("\nTest Results:") +for metric, value in results.items(): + print(f" {metric}: {value:.4f}") + +# %% STEP 7: Inspect model predictions +sample_batch = next(iter(test_loader)) +with torch.no_grad(): + output = model(**sample_batch) + +print("\nSample predictions:") +print(f" Predicted probabilities: {output['y_prob'][:5]}") +print(f" True labels: {output['y_true'][:5]}") + +# %% diff --git a/examples/memtest.py b/examples/memtest.py index 8a63090e8..056b21f8f 100644 --- a/examples/memtest.py 
+++ b/examples/memtest.py @@ -1,32 +1,106 @@ +""" +Example of using StageNet for mortality prediction on MIMIC-IV. + +This example demonstrates: +1. Loading MIMIC-IV data +2. Applying the MortalityPredictionStageNetMIMIC4 task +3. Creating a SampleDataset with StageNet processors +4. Training a StageNet model +""" + # %% -import psutil, os, time, threading -PEAK_MEM_USAGE = 0 -SELF_PROC = psutil.Process(os.getpid()) +if __name__ == "__main__": + from pyhealth.datasets import ( + MIMIC4Dataset, + get_dataloader, + split_by_patient, + ) + from pyhealth.models import StageNet + from pyhealth.tasks import MortalityPredictionStageNetMIMIC4 + from pyhealth.trainer import Trainer + import torch -def track_mem(): - global PEAK_MEM_USAGE - while True: - m = SELF_PROC.memory_info().rss - if m > PEAK_MEM_USAGE: - PEAK_MEM_USAGE = m - time.sleep(0.1) + # %% STEP 1: Load MIMIC-IV base dataset + base_dataset = MIMIC4Dataset( + ehr_root="/home/logic/physionet.org/files/mimiciv/3.1/", + ehr_tables=[ + "patients", + "admissions", + "diagnoses_icd", + "procedures_icd", + "labevents", + ], + dev=False, + ) -threading.Thread(target=track_mem, daemon=True).start() -print(f"[MEM] start={PEAK_MEM_USAGE / (1024**3)} GB") + # %% # STEP 2: Apply StageNet mortality prediction task + sample_dataset = base_dataset.set_task( + MortalityPredictionStageNetMIMIC4(), + num_workers=4, + ) -# %% -from pyhealth.datasets import MIMIC4Dataset -DATASET_DIR = "/home/logic/physionet.org/files/mimiciv/3.1" -dataset = MIMIC4Dataset( - ehr_root=DATASET_DIR, - ehr_tables=[ - "patients", - "admissions", - "diagnoses_icd", - "procedures_icd", - "prescriptions", - "labevents", - ], -) -print(f"[MEM] __init__={PEAK_MEM_USAGE / (1024**3):.3f} GB") -# %% + print(f"Total samples: {len(sample_dataset)}") + print(f"Input schema: {sample_dataset.input_schema}") + print(f"Output schema: {sample_dataset.output_schema}") + + # %% Inspect a sample + sample = next(iter(sample_dataset)) + print("\nSample structure:") + print(f" Patient ID: {sample['patient_id']}") + print(f"ICD Codes: {sample['icd_codes']}") + print(f" Labs shape: {len(sample['labs'][0])} timesteps") + print(f" Mortality: {sample['mortality']}") + + # %% STEP 3: Split dataset + train_dataset, val_dataset, test_dataset = split_by_patient( + sample_dataset, [0.8, 0.1, 0.1] + ) + + # Create dataloaders + train_loader = get_dataloader(train_dataset, batch_size=256, shuffle=True) + val_loader = get_dataloader(val_dataset, batch_size=256, shuffle=False) + test_loader = get_dataloader(test_dataset, batch_size=256, shuffle=False) + + # %% STEP 4: Initialize StageNet model + model = StageNet( + dataset=sample_dataset, + embedding_dim=128, + chunk_size=128, + levels=3, + dropout=0.3, + ) + + num_params = sum(p.numel() for p in model.parameters()) + print(f"\nModel initialized with {num_params} parameters") + + # %% STEP 5: Train the model + trainer = Trainer( + model=model, + device="cpu", # or "cpu" + metrics=["pr_auc", "roc_auc", "accuracy", "f1"], + ) + + trainer.train( + train_dataloader=train_loader, + val_dataloader=val_loader, + epochs=5, + monitor="roc_auc", + optimizer_params={"lr": 1e-5}, + ) + + # %% STEP 6: Evaluate on test set + results = trainer.evaluate(test_loader) + print("\nTest Results:") + for metric, value in results.items(): + print(f" {metric}: {value:.4f}") + + # %% STEP 7: Inspect model predictions + sample_batch = next(iter(test_loader)) + with torch.no_grad(): + output = model(**sample_batch) + + print("\nSample predictions:") + print(f" Predicted probabilities: 
{output['y_prob'][:5]}") + print(f" True labels: {output['y_true'][:5]}") + + # %% diff --git a/pyhealth/calib/predictionset/covariate/covariate_label.py b/pyhealth/calib/predictionset/covariate/covariate_label.py index e3aa0e293..602944b15 100644 --- a/pyhealth/calib/predictionset/covariate/covariate_label.py +++ b/pyhealth/calib/predictionset/covariate/covariate_label.py @@ -14,7 +14,7 @@ import numpy as np import torch -from torch.utils.data import Subset +from torch.utils.data import IterableDataset from pyhealth.calib.base_classes import SetPredictor from pyhealth.calib.calibration.kcal.kde import RBFKernelMean @@ -306,7 +306,7 @@ def __init__( def calibrate( self, - cal_dataset: Subset, + cal_dataset: IterableDataset, cal_embeddings: Optional[np.ndarray] = None, test_embeddings: Optional[np.ndarray] = None, ): diff --git a/pyhealth/data/data.py b/pyhealth/data/data.py index 2a6d3a45c..14b1b526c 100644 --- a/pyhealth/data/data.py +++ b/pyhealth/data/data.py @@ -152,9 +152,9 @@ def _filter_by_time_range_fast(self, df: pl.DataFrame, start: Optional[datetime] start_idx = 0 end_idx = len(ts_col) if start is not None: - start_idx = np.searchsorted(ts_col, start, side="left") + start_idx = np.searchsorted(ts_col, np.datetime64(start, "ms"), side="left") if end is not None: - end_idx = np.searchsorted(ts_col, end, side="right") + end_idx = np.searchsorted(ts_col, np.datetime64(end, "ms"), side="right") return df.slice(start_idx, end_idx - start_idx) def _filter_by_event_type_regular(self, df: pl.DataFrame, event_type: Optional[str]) -> pl.DataFrame: diff --git a/pyhealth/datasets/__init__.py b/pyhealth/datasets/__init__.py index 7d6a65f16..5176fdb42 100644 --- a/pyhealth/datasets/__init__.py +++ b/pyhealth/datasets/__init__.py @@ -61,7 +61,7 @@ def __init__(self, *args, **kwargs): from .mimic4 import MIMIC4CXRDataset, MIMIC4Dataset, MIMIC4EHRDataset, MIMIC4NoteDataset from .mimicextract import MIMICExtractDataset from .omop import OMOPDataset -from .sample_dataset import SampleDataset +from .sample_dataset import SampleBuilder, SampleDataset, create_sample_dataset from .shhs import SHHSDataset from .sleepedf import SleepEDFDataset from .bmd_hs import BMDHSDataset diff --git a/pyhealth/datasets/base_dataset.py b/pyhealth/datasets/base_dataset.py index 5da5ef0db..27c1e46d5 100644 --- a/pyhealth/datasets/base_dataset.py +++ b/pyhealth/datasets/base_dataset.py @@ -4,20 +4,38 @@ from abc import ABC from concurrent.futures import ThreadPoolExecutor, as_completed from pathlib import Path -from typing import Dict, Iterator, List, Optional +from typing import Dict, Iterator, List, Optional, Any, Callable +import functools +import operator from urllib.parse import urlparse, urlunparse - +from urllib.request import urlretrieve +import json +import uuid +import platformdirs +import tempfile +import multiprocessing + +import litdata +from litdata.streaming.item_loader import ParquetLoader +import pyarrow as pa +import pyarrow.csv as pv +import pyarrow.parquet as pq +import pandas as pd import polars as pl import requests from tqdm import tqdm +import dask.dataframe as dd +from dask.distributed import Client, LocalCluster, progress +import narwhals as nw from ..data import Patient from ..tasks import BaseTask from ..processors.base_processor import FeatureProcessor from .configs import load_yaml_config -from .sample_dataset import SampleDataset -from .utils import _convert_for_cache, _restore_from_cache +from .sample_dataset import SampleDataset, SampleBuilder +# Set logging level for distributed to ERROR 
to reduce verbosity +logging.getLogger("distributed").setLevel(logging.ERROR) logger = logging.getLogger(__name__) @@ -55,27 +73,24 @@ def path_exists(path: str) -> bool: else: return Path(path).exists() - -def scan_csv_gz_or_csv_tsv(path: str) -> pl.LazyFrame: +def _csv_tsv_gz_path(path: str) -> str: """ - Scan a CSV.gz, CSV, TSV.gz, or TSV file and returns a LazyFrame. - It will fall back to the other extension if not found. + Get the path to the file, trying the original path first, then the alternative path + by switching between .csv.gz, .csv, .tsv.gz, and .tsv extensions. Args: - path (str): URL or local path to a .csv, .csv.gz, .tsv, or .tsv.gz file + path (str): Original file path. Returns: - pl.LazyFrame: The LazyFrame for the CSV.gz, CSV, TSV.gz, or TSV file. - """ - - def scan_file(file_path: str) -> pl.LazyFrame: - separator = "\t" if ".tsv" in file_path else "," - return pl.scan_csv(file_path, separator=separator, infer_schema=False) + str: The file path that exists. + Raises: + FileNotFoundError: If neither the original nor the alternative path exists. + ValueError: If the path does not have an expected extension. + """ if path_exists(path): - return scan_file(path) + return path - # Try the alternative extension if path.endswith(".csv.gz"): alt_path = path[:-3] # Remove .gz -> try .csv elif path.endswith(".csv"): @@ -85,14 +100,93 @@ def scan_file(file_path: str) -> pl.LazyFrame: elif path.endswith(".tsv"): alt_path = f"{path}.gz" # Add .gz -> try .tsv.gz else: - raise FileNotFoundError(f"Path does not have expected extension: {path}") - + raise ValueError(f"Path does not have expected extension: {path}") + if path_exists(alt_path): - logger.info(f"Original path does not exist. Using alternative: {alt_path}") - return scan_file(alt_path) - + return alt_path + raise FileNotFoundError(f"Neither path exists: {path} or {alt_path}") +def _uncollate(x: list[Any]) -> Any: + return x[0] if isinstance(x, list) and len(x) == 1 else x + + +class _ParquetWriter: + """ + Stream-write rows into a Parquet file in chunked (row-group) fashion. + + Usage: + writer = StreamingParquetWriter(Path("out.parquet"), schema, chunk_size=10000) + writer.append({"id": 1, "val": 3.14}) + writer.append({"id": 2, "val": 1.23}) + writer.close() + """ + + def __init__(self, path: Path | str, schema: pa.Schema, chunk_size: int = 8_192): + """ + Args: + path: output Parquet file path + schema: pyarrow.Schema (required) + chunk_size: flush buffer every N rows + """ + self.path = Path(path) + self.schema = schema + self.chunk_size = chunk_size + + if self.schema is None: + raise ValueError( + "schema must be provided — no automatic inference allowed." 
+ ) + + self._writer: pq.ParquetWriter | None = None + self._buffer: list[dict] = [] + self._closed = False + + # -------------------------------------------------------------- + # Public API + # -------------------------------------------------------------- + def append(self, row: dict) -> None: + """Append a single row (a Python dict).""" + if self._closed: + raise RuntimeError("Cannot append to a closed StreamingParquetWriter") + + self._buffer.append(row) + if len(self._buffer) >= self.chunk_size: + self.flush() + + def flush(self) -> None: + """Flush buffered rows into a Parquet row-group.""" + if not self._buffer: + return + + # Convert list[dict] → Arrow RecordBatch + batch = pa.RecordBatch.from_pylist(self._buffer, schema=self.schema) + + # Lazy-initialize writer + if self._writer is None: + self._writer = pq.ParquetWriter(self.path, self.schema) + + self._writer.write_batch(batch) + self._buffer.clear() + + def close(self) -> None: + """Flush and close the Parquet writer.""" + if self._closed: + return + self.flush() + if self._writer is not None: + self._writer.close() + self._closed = True + + # -------------------------------------------------------------- + # Context manager support + # -------------------------------------------------------------- + def __enter__(self): + return self + + def __exit__(self, exc_type, exc, tb): + self.close() + class BaseDataset(ABC): """Abstract base class for all PyHealth datasets. @@ -112,6 +206,7 @@ def __init__( tables: List[str], dataset_name: Optional[str] = None, config_path: Optional[str] = None, + cache_dir: str | Path | None = None, dev: bool = False, ): """Initializes the BaseDataset. @@ -129,123 +224,236 @@ def __init__( self.root = root self.tables = tables self.dataset_name = dataset_name or self.__class__.__name__ - self.config = load_yaml_config(config_path) self.dev = dev + self.config = load_yaml_config(config_path) if config_path else None logger.info( f"Initializing {self.dataset_name} dataset from {self.root} (dev mode: {self.dev})" ) - self.global_event_df = self.load_data() - # Cached attributes - self._collected_global_event_df = None + self._cache_dir = cache_dir + self._global_event_df = None self._unique_patient_ids = None @property - def collected_global_event_df(self) -> pl.DataFrame: - """Collects and returns the global event data frame. + def cache_dir(self) -> Path: + """Returns the cache directory path. + Returns: + Path: The cache directory path. + """ + if self._cache_dir is None: + id_str = json.dumps( + { + "root": self.root, + "tables": sorted(self.tables), + "dataset_name": self.dataset_name, + "dev": self.dev, + }, + sort_keys=True, + ) + cache_dir = Path(platformdirs.user_cache_dir(appname="pyhealth")) / str( + uuid.uuid5(uuid.NAMESPACE_DNS, id_str) + ) + cache_dir.mkdir(parents=True, exist_ok=True) + print(f"No cache_dir provided. Using default cache dir: {cache_dir}") + self._cache_dir = cache_dir + else: + # Ensure the explicitly provided cache_dir exists + cache_dir = Path(self._cache_dir) + cache_dir.mkdir(parents=True, exist_ok=True) + self._cache_dir = cache_dir + return Path(self._cache_dir) + + @property + def temp_dir(self) -> Path: + return self.cache_dir / "temp" + + def _scan_csv_tsv_gz(self, table_name: str, source_path: str | None = None) -> dd.DataFrame: + """Scans a CSV/TSV file (possibly gzipped) and returns a Dask DataFrame. + + If the cached Parquet file does not exist, it converts the source CSV/TSV file + to Parquet and saves it to the cache. 
+ + Args: + table_name (str): The name of the table. + source_path (str | None): The source CSV/TSV file path. If None, assumes the + Parquet file already exists in the cache. Returns: - pl.DataFrame: The collected global event data frame. + dd.DataFrame: The Dask DataFrame loaded from the cached Parquet file. + + Raises: + FileNotFoundError: If source_path is None and the cached Parquet file does not exist; + or if neither the original nor the alternative path of source_path exists. + ValueError: If the path does not have an expected extension. """ - if self._collected_global_event_df is None: - logger.info("Collecting global event dataframe...") + # Ensure the tables cache directory exists + (self.temp_dir / "tables").mkdir(parents=True, exist_ok=True) + ret_path = str(self.temp_dir / "tables" / f"{table_name}.parquet") + + if not path_exists(ret_path): + if source_path is None: + raise FileNotFoundError( + f"Table {table_name} not found in cache and no source_path provided." + ) - # Collect the dataframe - with dev mode limiting if applicable - df = self.global_event_df - # TODO: dev doesn't seem to improve the speed / memory usage - if self.dev: - # Limit the number of patients in dev mode - logger.info("Dev mode enabled: limiting to 1000 patients") - limited_patients = df.select(pl.col("patient_id")).unique().limit(1000) - df = df.join(limited_patients, on="patient_id", how="inner") - - self._collected_global_event_df = df.collect() - - # Profile the Polars collect() operation (commented out by default) - # self._collected_global_event_df, profile = df.profile() - # profile = profile.with_columns([ - # (pl.col("end") - pl.col("start")).alias("duration"), - # ]) - # profile = profile.with_columns([ - # (pl.col("duration") / profile["duration"].sum() * 100).alias("percentage") - # ]) - # profile = profile.sort("duration", descending=True) - # with pl.Config() as cfg: - # cfg.set_tbl_rows(-1) - # cfg.set_fmt_str_lengths(200) - # print(profile) + source_path = _csv_tsv_gz_path(source_path) + + if is_url(source_path): + local_filename = os.path.basename(source_path) + download_dir = self.temp_dir / "downloads" + download_dir.mkdir(parents=True, exist_ok=True) + local_path = download_dir / local_filename + if not local_path.exists(): + logger.info(f"Downloading {source_path} to {local_path}") + urlretrieve(source_path, local_path) + source_path = str(local_path) + + # Determine delimiter based on file extension + delimiter = ( + "\t" + if source_path.endswith(".tsv") or source_path.endswith(".tsv.gz") + else "," + ) - logger.info( - f"Collected dataframe with shape: {self._collected_global_event_df.shape}" + # Always infer schema as string to avoid incorrect type inference + schema_reader = pv.open_csv( + source_path, + read_options=pv.ReadOptions(block_size=1 << 26), # 64 MB + parse_options=pv.ParseOptions(delimiter=delimiter), ) + schema = pa.schema( + [pa.field(name, pa.string()) for name in schema_reader.schema.names] + ) + + # Convert CSV/TSV to Parquet + csv_reader = pv.open_csv( + source_path, + read_options=pv.ReadOptions(block_size=1 << 26), # 64 MB + parse_options=pv.ParseOptions(delimiter=delimiter), + convert_options=pv.ConvertOptions(column_types=schema), + ) + with pq.ParquetWriter(ret_path, csv_reader.schema) as writer: + for batch in csv_reader: + writer.write_batch(batch) + + df: dd.DataFrame = dd.read_parquet( + ret_path, + split_row_groups=True, # type: ignore + blocksize="64MB", + ) + return df.replace("", pd.NA) # Replace empty strings with NaN + + @property + def 
global_event_df(self) -> pl.LazyFrame: + """Returns the path to the cached event dataframe. - return self._collected_global_event_df + Returns: + Path: The path to the cached event dataframe. + """ + if not multiprocessing.current_process().name == "MainProcess": + logger.warning( + "global_event_df property accessed from a non-main process. This may lead to unexpected behavior.\n" + + "Consider use __name__ == '__main__' guard when using multiprocessing." + ) + return None # type: ignore + + if self._global_event_df is None: + ret_path = self.cache_dir / "global_event_df.parquet" + if not ret_path.exists(): + # TODO: auto select processes=True/False based on if it's in jupyter notebook + # The processes=True will crash in jupyter notebook. + # TODO: make the n_workers configurable + with LocalCluster( + n_workers=1, + threads_per_worker=1, + processes=False, + ) as cluster: + with Client(cluster) as client: + df: dd.DataFrame = self.load_data() + if self.dev: + logger.info("Dev mode enabled: limiting to 1000 patients") + patients = ( + df["patient_id"].unique().head(1000).tolist() + ) + filter = df["patient_id"].isin(patients) + df = df[filter] + + logger.info(f"Caching event dataframe to {ret_path}...") + collection = df.sort_values("patient_id").to_parquet( + ret_path, + write_index=False, + compute=False, + ) + handle = client.compute(collection) + progress(handle) + handle.result() # type: ignore + self._global_event_df = ret_path + + return pl.scan_parquet( + self._global_event_df, + low_memory=True, + ) - def load_data(self) -> pl.LazyFrame: + def load_data(self) -> dd.DataFrame: """Loads data from the specified tables. Returns: - pl.LazyFrame: A concatenated lazy frame of all tables. + dd.DataFrame: A concatenated lazy frame of all tables. """ frames = [self.load_table(table.lower()) for table in self.tables] - return pl.concat(frames, how="diagonal") + return dd.concat(frames, axis=0, join="outer") - def load_table(self, table_name: str) -> pl.LazyFrame: + def load_table(self, table_name: str) -> dd.DataFrame: """Loads a table and processes joins if specified. Args: table_name (str): The name of the table to load. Returns: - pl.LazyFrame: The processed lazy frame for the table. + dd.DataFrame: The processed Dask dataframe for the table. Raises: ValueError: If the table is not found in the config. FileNotFoundError: If the CSV file for the table or join is not found. 
""" + assert self.config is not None, "Config must be provided to load tables" + if table_name not in self.config.tables: raise ValueError(f"Table {table_name} not found in config") - def _to_lower(col_name: str) -> str: - lower_name = col_name.lower() - if lower_name != col_name: - logger.warning( - "Renaming column %s to lowercase %s", col_name, lower_name - ) - return lower_name - table_cfg = self.config.tables[table_name] csv_path = f"{self.root}/{table_cfg.file_path}" csv_path = clean_path(csv_path) logger.info(f"Scanning table: {table_name} from {csv_path}") - df = scan_csv_gz_or_csv_tsv(csv_path) + df = self._scan_csv_tsv_gz(table_name, csv_path) # Convert column names to lowercase before calling preprocess_func - df = df.rename(_to_lower) + df = df.rename(columns=str.lower) # Check if there is a preprocessing function for this table + preprocess_func: Optional[Callable[[nw.LazyFrame], nw.LazyFrame]] preprocess_func = getattr(self, f"preprocess_{table_name}", None) if preprocess_func is not None: logger.info( f"Preprocessing table: {table_name} with {preprocess_func.__name__}" ) - df = preprocess_func(df) + df = preprocess_func(nw.from_native(df)).to_native() # type: ignore # Handle joins - for join_cfg in table_cfg.join: + for i, join_cfg in enumerate(table_cfg.join): other_csv_path = f"{self.root}/{join_cfg.file_path}" other_csv_path = clean_path(other_csv_path) logger.info(f"Joining with table: {other_csv_path}") - join_df = scan_csv_gz_or_csv_tsv(other_csv_path) - join_df = join_df.rename(_to_lower) + join_df = self._scan_csv_tsv_gz(f"{table_name}_join_{i}", other_csv_path) + join_df = join_df.rename(columns=str.lower) join_key = join_cfg.on columns = join_cfg.columns how = join_cfg.how - df = df.join(join_df.select([join_key] + columns), on=join_key, how=how) + df: dd.DataFrame = df.merge(join_df[[join_key] + columns], on=join_key, how=how) patient_id_col = table_cfg.patient_id timestamp_col = table_cfg.timestamp @@ -256,38 +464,38 @@ def _to_lower(col_name: str) -> str: if timestamp_col: if isinstance(timestamp_col, list): # Concatenate all timestamp parts in order with no separator - combined_timestamp = pl.concat_str( - [pl.col(col) for col in timestamp_col] - ).str.strptime(pl.Datetime, format=timestamp_format, strict=True) - timestamp_expr = combined_timestamp + timestamp_series: dd.Series = functools.reduce( + operator.add, (df[col].astype(str) for col in timestamp_col) + ) else: # Single timestamp column - timestamp_expr = pl.col(timestamp_col).str.strptime( - pl.Datetime, format=timestamp_format, strict=True - ) + timestamp_series: dd.Series = df[timestamp_col].astype(str) + timestamp_series: dd.Series = dd.to_datetime( + timestamp_series, + format=timestamp_format, + errors="raise", + ) + df: dd.DataFrame = df.assign(timestamp=timestamp_series.astype("datetime64[ms]")) else: - timestamp_expr = pl.lit(None, dtype=pl.Datetime) + df: dd.DataFrame = df.assign(timestamp=pd.NaT) + # If patient_id_col is None, use row index as patient_id - patient_id_expr = ( - pl.col(patient_id_col).cast(pl.Utf8) - if patient_id_col - else pl.int_range(0, pl.count()).cast(pl.Utf8) - ) - base_columns = [ - patient_id_expr.alias("patient_id"), - pl.lit(table_name).cast(pl.Utf8).alias("event_type"), - # ms should be sufficient for most cases - timestamp_expr.cast(pl.Datetime(time_unit="ms")).alias("timestamp"), - ] + if patient_id_col: + df: dd.DataFrame = df.assign(patient_id=df[patient_id_col].astype(str)) + else: + df: dd.DataFrame = df.reset_index(drop=True) + df: dd.DataFrame = 
df.assign(patient_id=df.index.astype(str)) + - # Flatten attribute columns with event_type prefix - attribute_columns = [ - pl.col(attr.lower()).alias(f"{table_name}/{attr}") - for attr in attribute_cols - ] + df: dd.DataFrame = df.assign(event_type=table_name) - event_frame = df.select(base_columns + attribute_columns) + rename_attr = {attr.lower(): f"{table_name}/{attr}" for attr in attribute_cols} + df: dd.DataFrame = df.rename(columns=rename_attr) + + attr_cols = [rename_attr[attr.lower()] for attr in attribute_cols] + final_cols = ["patient_id", "event_type", "timestamp"] + attr_cols + event_frame = df[final_cols] return event_frame @@ -300,8 +508,9 @@ def unique_patient_ids(self) -> List[str]: """ if self._unique_patient_ids is None: self._unique_patient_ids = ( - self.collected_global_event_df.select("patient_id") + self.global_event_df.select("patient_id") .unique() + .collect(engine="streaming") .to_series() .to_list() ) @@ -323,8 +532,11 @@ def get_patient(self, patient_id: str) -> Patient: assert ( patient_id in self.unique_patient_ids ), f"Patient {patient_id} not found in dataset" - df = self.collected_global_event_df.filter(pl.col("patient_id") == patient_id) - return Patient(patient_id=patient_id, data_source=df) + + data_source = self.global_event_df.filter( + pl.col("patient_id") == patient_id + ).collect(engine="streaming") + return Patient(patient_id=patient_id, data_source=data_source) def iter_patients(self, df: Optional[pl.LazyFrame] = None) -> Iterator[Patient]: """Yields Patient objects for each unique patient in the dataset. @@ -333,20 +545,30 @@ def iter_patients(self, df: Optional[pl.LazyFrame] = None) -> Iterator[Patient]: Iterator[Patient]: An iterator over Patient objects. """ if df is None: - df = self.collected_global_event_df - grouped = df.group_by("patient_id") + df = self.global_event_df + patient_ids = ( + df.select("patient_id") + .unique(maintain_order=True) + .collect(engine="streaming") + .to_series() + ) - for patient_id, patient_df in grouped: - patient_id = patient_id[0] + for patient_id in patient_ids: + patient_df = df.filter(pl.col("patient_id") == patient_id).collect( + engine="streaming" + ) yield Patient(patient_id=patient_id, data_source=patient_df) def stats(self) -> None: """Prints statistics about the dataset.""" - df = self.collected_global_event_df + stats = self.global_event_df.select( + pl.len().alias("n_events"), + pl.col("patient_id").n_unique().alias("n_patients"), + ).collect(engine="streaming") print(f"Dataset: {self.dataset_name}") print(f"Dev mode: {self.dev}") - print(f"Number of patients: {df['patient_id'].n_unique()}") - print(f"Number of events: {df.height}") + print(f"Number of patients: {stats['n_patients'][0]}") + print(f"Number of events: {stats['n_events'][0]}") @property def default_task(self) -> Optional[BaseTask]: @@ -361,7 +583,7 @@ def set_task( self, task: Optional[BaseTask] = None, num_workers: int = 1, - cache_dir: Optional[str] = None, + cache_dir: str | Path | None = None, cache_format: str = "parquet", input_processors: Optional[Dict[str, FeatureProcessor]] = None, output_processors: Optional[Dict[str, FeatureProcessor]] = None, @@ -375,8 +597,7 @@ def set_task( multi-threading may not speed up the task function. cache_dir (Optional[str]): Directory to cache processed samples. Default is None (no caching). - cache_format (str): Format for caching ('parquet' or 'pickle'). - Default is 'parquet'. + cache_format (str): Deprecated. Only "parquet" is supported now. 
input_processors (Optional[Dict[str, FeatureProcessor]]): Pre-fitted input processors. If provided, these will be used instead of creating new ones from task's input_schema. Defaults to None. @@ -390,110 +611,96 @@ def set_task( Raises: AssertionError: If no default task is found and task is None. """ + if not multiprocessing.current_process().name == "MainProcess": + logger.warning( + "set_task method accessed from a non-main process. This may lead to unexpected behavior.\n" + + "Consider use __name__ == '__main__' guard when using multiprocessing." + ) + return None # type: ignore + if task is None: assert self.default_task is not None, "No default tasks found" task = self.default_task + if cache_format != "parquet": + logger.warning("Only 'parquet' cache_format is supported now. ") + logger.info( f"Setting task {task.task_name} for {self.dataset_name} base dataset..." ) - # Check for cached data if cache_dir is provided - samples = None - if cache_dir is not None: - cache_filename = f"{task.task_name}.{cache_format}" - cache_path = Path(cache_dir) / cache_filename - if cache_path.exists(): - logger.info(f"Loading cached samples from {cache_path}") - try: - if cache_format == "parquet": - # Load samples from parquet file - cached_df = pl.read_parquet(cache_path) - samples = [ - _restore_from_cache(row) for row in cached_df.to_dicts() - ] - elif cache_format == "pickle": - # Load samples from pickle file - with open(cache_path, "rb") as f: - samples = pickle.load(f) - else: - msg = f"Unsupported cache format: {cache_format}" - raise ValueError(msg) - logger.info(f"Loaded {len(samples)} cached samples") - except Exception as e: - logger.warning( - "Failed to load cached data: %s. Regenerating...", - e, + if cache_dir is None: + cache_dir = self.cache_dir / "tasks" / task.task_name + cache_dir.mkdir(parents=True, exist_ok=True) + else: + # Ensure the explicitly provided cache_dir exists + cache_dir = Path(cache_dir) + cache_dir.mkdir(parents=True, exist_ok=True) + + path = Path(cache_dir) + + # Check if index.json exists to verify cache integrity, this + # is the standard file for litdata.StreamingDataset + if not (path / "index.json").exists(): + global_event_df = task.pre_filter(self.global_event_df) + schema = pa.schema([("sample", pa.binary())]) + with tempfile.TemporaryDirectory() as tmp_dir: + # Create Parquet file with samples + logger.info(f"Applying task transformations on data...") + with _ParquetWriter(f"{tmp_dir}/samples.parquet", schema) as writer: + # TODO: this can be further optimized. 
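+                    # Note (added for clarity): a full streaming scan of the cached Parquet
+                    # file is issued for every patient_id in the loop below; since the cache
+                    # is written sorted by patient_id, a single ordered pass that groups rows
+                    # by patient_id could avoid the repeated scans (hence the TODO above).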
+ patient_ids = ( + global_event_df.select("patient_id") + .unique() + .collect(engine="streaming") + .to_series() ) - samples = None - - # Generate samples if not loaded from cache - if samples is None: - logger.info(f"Generating samples with {num_workers} worker(s)...") - filtered_global_event_df = task.pre_filter(self.collected_global_event_df) - samples = [] - - if num_workers == 1: - # single-threading (by default) - for patient in tqdm( - self.iter_patients(filtered_global_event_df), - total=filtered_global_event_df["patient_id"].n_unique(), - desc=(f"Generating samples for {task.task_name} " "with 1 worker"), - smoothing=0, - ): - samples.extend(task(patient)) - else: - # multi-threading (not recommended) - logger.info( - f"Generating samples for {task.task_name} with " - f"{num_workers} workers" + for patient_id in tqdm(patient_ids): + patient_df = global_event_df.filter( + pl.col("patient_id") == patient_id + ).collect(engine="streaming") + patient = Patient(patient_id=patient_id, data_source=patient_df) + for sample in task(patient): + writer.append({"sample": pickle.dumps(sample)}) + litdata.index_parquet_dataset(tmp_dir) + + # Build processors and fit on the dataset + logger.info(f"Fitting processors on the dataset...") + dataset = litdata.StreamingDataset( + tmp_dir, + item_loader=ParquetLoader(), + transform=lambda x: pickle.loads(x["sample"]), ) - patients = list(self.iter_patients(filtered_global_event_df)) - with ThreadPoolExecutor(max_workers=num_workers) as executor: - futures = [executor.submit(task, patient) for patient in patients] - for future in tqdm( - as_completed(futures), - total=len(futures), - desc=( - f"Collecting samples for {task.task_name} " - f"from {num_workers} workers" - ), - ): - samples.extend(future.result()) - - # Cache the samples if cache_dir is provided - if cache_dir is not None: - cache_path = Path(cache_dir) / cache_filename - cache_path.parent.mkdir(parents=True, exist_ok=True) - logger.info(f"Caching samples to {cache_path}") - try: - if cache_format == "parquet": - # Save samples as parquet file - samples_for_cache = [ - _convert_for_cache(sample) for sample in samples - ] - samples_df = pl.DataFrame(samples_for_cache) - samples_df.write_parquet(cache_path) - elif cache_format == "pickle": - # Save samples as pickle file - with open(cache_path, "wb") as f: - pickle.dump(samples, f) - else: - msg = f"Unsupported cache format: {cache_format}" - raise ValueError(msg) - logger.info(f"Successfully cached {len(samples)} samples") - except Exception as e: - logger.warning(f"Failed to cache samples: {e}") - - sample_dataset = SampleDataset( - samples, - input_schema=task.input_schema, - output_schema=task.output_schema, + builder = SampleBuilder( + input_schema=task.input_schema, # type: ignore + output_schema=task.output_schema, # type: ignore + input_processors=input_processors, + output_processors=output_processors, + ) + builder.fit(dataset) + builder.save(str(path / "schema.pkl")) + + # Apply processors and save final samples to cache_dir + logger.info(f"Processing samples and saving to {path}...") + dataset = litdata.StreamingDataset( + tmp_dir, + item_loader=ParquetLoader(), + ) + litdata.optimize( + fn=builder.transform, + inputs=litdata.StreamingDataLoader( + dataset, + batch_size=1, + collate_fn=_uncollate, + ), + output_dir=str(path), + chunk_bytes="64MB", + num_workers=num_workers, + ) + logger.info(f"Cached processed samples to {path}") + + return SampleDataset( + path=str(path), dataset_name=self.dataset_name, - task_name=task, - 
input_processors=input_processors, - output_processors=output_processors, + task_name=task.task_name, ) - - logger.info(f"Generated {len(samples)} samples for task {task.task_name}") - return sample_dataset diff --git a/pyhealth/datasets/bmd_hs.py b/pyhealth/datasets/bmd_hs.py index 183c23766..30c706132 100644 --- a/pyhealth/datasets/bmd_hs.py +++ b/pyhealth/datasets/bmd_hs.py @@ -1,7 +1,7 @@ import logging from pathlib import Path from typing import Optional -import polars as pl +import narwhals as pl from pyhealth.tasks.base_task import BaseTask from pyhealth.tasks.bmd_hs_disease_classification import BMDHSDiseaseClassification diff --git a/pyhealth/datasets/mimic3.py b/pyhealth/datasets/mimic3.py index eac1f33b8..22ca79d5c 100644 --- a/pyhealth/datasets/mimic3.py +++ b/pyhealth/datasets/mimic3.py @@ -3,7 +3,7 @@ from pathlib import Path from typing import List, Optional -import polars as pl +import narwhals as pl from .base_dataset import BaseDataset diff --git a/pyhealth/datasets/mimic4.py b/pyhealth/datasets/mimic4.py index 05321dedb..9b48985e3 100644 --- a/pyhealth/datasets/mimic4.py +++ b/pyhealth/datasets/mimic4.py @@ -8,6 +8,7 @@ try: import psutil + HAS_PSUTIL = True except ImportError: HAS_PSUTIL = False @@ -20,7 +21,7 @@ def log_memory_usage(tag=""): """Log current memory usage if psutil is available.""" if HAS_PSUTIL: - process = psutil.Process(os.getpid()) + process = psutil.Process(os.getpid()) # type: ignore mem_info = process.memory_info() logger.info(f"Memory usage {tag}: {mem_info.rss / (1024 * 1024):.1f} MB") else: @@ -39,7 +40,7 @@ class MIMIC4EHRDataset(BaseDataset): tables (List[str]): A list of tables to be included in the dataset. dataset_name (Optional[str]): The name of the dataset. config_path (Optional[str]): The path to the configuration file. 
- """ + """ def __init__( self, @@ -47,10 +48,13 @@ def __init__( tables: List[str], dataset_name: str = "mimic4_ehr", config_path: Optional[str] = None, - **kwargs + cache_dir: Optional[str] = None, + **kwargs, ): if config_path is None: - config_path = os.path.join(os.path.dirname(__file__), "configs", "mimic4_ehr.yaml") + config_path = os.path.join( + os.path.dirname(__file__), "configs", "mimic4_ehr.yaml" + ) logger.info(f"Using default EHR config: {config_path}") log_memory_usage(f"Before initializing {dataset_name}") @@ -61,7 +65,8 @@ def __init__( tables=tables, dataset_name=dataset_name, config_path=config_path, - **kwargs + cache_dir=cache_dir, + **kwargs, ) log_memory_usage(f"After initializing {dataset_name}") @@ -86,10 +91,13 @@ def __init__( tables: List[str], dataset_name: str = "mimic4_note", config_path: Optional[str] = None, - **kwargs + cache_dir: Optional[str] = None, + **kwargs, ): if config_path is None: - config_path = os.path.join(os.path.dirname(__file__), "configs", "mimic4_note.yaml") + config_path = os.path.join( + os.path.dirname(__file__), "configs", "mimic4_note.yaml" + ) logger.info(f"Using default note config: {config_path}") if "discharge" in tables: warnings.warn( @@ -109,7 +117,8 @@ def __init__( tables=tables, dataset_name=dataset_name, config_path=config_path, - **kwargs + cache_dir=cache_dir, + **kwargs, ) log_memory_usage(f"After initializing {dataset_name}") @@ -134,10 +143,13 @@ def __init__( tables: List[str], dataset_name: str = "mimic4_cxr", config_path: Optional[str] = None, - **kwargs + cache_dir: Optional[str] = None, + **kwargs, ): if config_path is None: - config_path = os.path.join(os.path.dirname(__file__), "configs", "mimic4_cxr.yaml") + config_path = os.path.join( + os.path.dirname(__file__), "configs", "mimic4_cxr.yaml" + ) logger.info(f"Using default CXR config: {config_path}") self.prepare_metadata(root) log_memory_usage(f"Before initializing {dataset_name}") @@ -146,12 +158,15 @@ def __init__( tables=tables, dataset_name=dataset_name, config_path=config_path, - **kwargs + cache_dir=cache_dir, + **kwargs, ) log_memory_usage(f"After initializing {dataset_name}") def prepare_metadata(self, root: str) -> None: - metadata = pd.read_csv(os.path.join(root, "mimic-cxr-2.0.0-metadata.csv.gz"), dtype=str) + metadata = pd.read_csv( + os.path.join(root, "mimic-cxr-2.0.0-metadata.csv.gz"), dtype=str + ) def process_studytime(x): # reformat studytime to be 6 digits (e.g. 
123.002 -> 000123 which is 12:30:00) @@ -160,6 +175,7 @@ def process_studytime(x): return f"{int(x):06d}" except Exception: return x + metadata["StudyTime"] = metadata["StudyTime"].apply(process_studytime) def process_image_path(x): @@ -168,10 +184,15 @@ def process_image_path(x): folder = subject_id[:3] study_id = "s" + x["study_id"] dicom_id = x["dicom_id"] - return os.path.join(root, "files", folder, subject_id, study_id, f"{dicom_id}.jpg") + return os.path.join( + root, "files", folder, subject_id, study_id, f"{dicom_id}.jpg" + ) + metadata["image_path"] = metadata.apply(process_image_path, axis=1) - metadata.to_csv(os.path.join(root, "mimic-cxr-2.0.0-metadata-pyhealth.csv"), index=False) + metadata.to_csv( + os.path.join(root, "mimic-cxr-2.0.0-metadata-pyhealth.csv"), index=False + ) return @@ -211,19 +232,10 @@ def __init__( cxr_config_path: Optional[str] = None, dataset_name: str = "mimic4", dev: bool = False, + cache_dir: Optional[str] = None, ): log_memory_usage("Starting MIMIC4Dataset init") - # Initialize child datasets - self.dataset_name = dataset_name - self.sub_datasets = {} - self.root = None - self.tables = None - self.config = None - # Dev flag is only used in the MIMIC4Dataset class - # to ensure the same set of patients are used for all sub-datasets. - self.dev = dev - # We need at least one root directory if not any([ehr_root, note_root, cxr_root]): raise ValueError("At least one root directory must be provided") @@ -233,48 +245,66 @@ def __init__( note_tables = note_tables or [] cxr_tables = cxr_tables or [] + super().__init__( + root=f"{ehr_root}|{note_root}|{cxr_root}", + tables=ehr_tables + note_tables + cxr_tables, + dataset_name=dataset_name, + config_path=None, + dev=dev, + cache_dir=cache_dir, + ) + + # Initialize child datasets + self.sub_datasets: dict[str, BaseDataset] = {} + # Initialize EHR dataset if root is provided if ehr_root: - logger.info(f"Initializing MIMIC4EHRDataset with tables: {ehr_tables} (dev mode: {dev})") + logger.info( + f"Initializing MIMIC4EHRDataset with tables: {ehr_tables} (dev mode: {dev})" + ) + ehr_cache_dir = None if cache_dir is None else f"{cache_dir}/ehr" self.sub_datasets["ehr"] = MIMIC4EHRDataset( root=ehr_root, tables=ehr_tables, config_path=ehr_config_path, + cache_dir=ehr_cache_dir, + dev=dev, ) log_memory_usage("After EHR dataset initialization") # Initialize Notes dataset if root is provided if note_root is not None and note_tables: - logger.info(f"Initializing MIMIC4NoteDataset with tables: {note_tables} (dev mode: {dev})") + logger.info( + f"Initializing MIMIC4NoteDataset with tables: {note_tables} (dev mode: {dev})" + ) + note_cache_dir = None if cache_dir is None else f"{cache_dir}/note" self.sub_datasets["note"] = MIMIC4NoteDataset( root=note_root, tables=note_tables, config_path=note_config_path, + cache_dir=note_cache_dir, + dev=dev, ) log_memory_usage("After Note dataset initialization") # Initialize CXR dataset if root is provided if cxr_root is not None: - logger.info(f"Initializing MIMIC4CXRDataset with tables: {cxr_tables} (dev mode: {dev})") + logger.info( + f"Initializing MIMIC4CXRDataset with tables: {cxr_tables} (dev mode: {dev})" + ) + cxr_cache_dir = None if cache_dir is None else f"{cache_dir}/cxr" self.sub_datasets["cxr"] = MIMIC4CXRDataset( root=cxr_root, tables=cxr_tables, config_path=cxr_config_path, + cache_dir=cxr_cache_dir, + dev=dev, ) log_memory_usage("After CXR dataset initialization") - # Combine data from all sub-datasets - log_memory_usage("Before combining data") - self.global_event_df = 
self._combine_data() - log_memory_usage("After combining data") - - # Cache attributes - self._collected_global_event_df = None - self._unique_patient_ids = None - log_memory_usage("Completed MIMIC4Dataset init") - def _combine_data(self) -> pl.LazyFrame: + def load_data(self) -> pl.LazyFrame: """ Combines data from all initialized sub-datasets into a unified global event dataframe. @@ -286,7 +316,7 @@ def _combine_data(self) -> pl.LazyFrame: # Collect global event dataframes from all sub-datasets for dataset_type, dataset in self.sub_datasets.items(): logger.info(f"Combining data from {dataset_type} dataset") - frames.append(dataset.global_event_df) + frames.append(dataset.load_data()) # Concatenate all frames logger.info("Creating combined dataframe") diff --git a/pyhealth/datasets/omop.py b/pyhealth/datasets/omop.py index 675210c14..bb9b0f1ed 100644 --- a/pyhealth/datasets/omop.py +++ b/pyhealth/datasets/omop.py @@ -2,7 +2,7 @@ from pathlib import Path from typing import List, Optional -import polars as pl +import narwhals as pl from .base_dataset import BaseDataset @@ -155,15 +155,19 @@ def preprocess_person(self, df: pl.LazyFrame) -> pl.LazyFrame: df = df.with_columns( [ ( - pl.col("year_of_birth").cast(pl.Utf8) + pl.col("year_of_birth").cast(pl.String) + "-" - + pl.when(pl.col("month_of_birth").is_null()) - .then(pl.lit("01")) - .otherwise(pl.col("month_of_birth").cast(pl.Utf8).str.zfill(2)) + + ( + pl.when(pl.col("month_of_birth").is_null()) + .then(pl.lit("01")) + .otherwise(pl.col("month_of_birth").cast(pl.String).str.zfill(2)) + ).cast(pl.String) + "-" - + pl.when(pl.col("day_of_birth").is_null()) - .then(pl.lit("01")) - .otherwise(pl.col("day_of_birth").cast(pl.Utf8).str.zfill(2)) + + ( + pl.when(pl.col("day_of_birth").is_null()) + .then(pl.lit("01")) + .otherwise(pl.col("day_of_birth").cast(pl.String).str.zfill(2)) + ).cast(pl.String) + " 00:00:00" ).alias("birth_datetime") ] diff --git a/pyhealth/datasets/sample_dataset.py b/pyhealth/datasets/sample_dataset.py index 54c40420c..e14e02ca9 100644 --- a/pyhealth/datasets/sample_dataset.py +++ b/pyhealth/datasets/sample_dataset.py @@ -1,172 +1,283 @@ -from typing import Any, Dict, List, Optional, Tuple, Union, Type +from collections.abc import Sequence +from pathlib import Path +import pickle +import tempfile +from typing import Any, Dict, Iterable, List, Optional, Tuple, Union, Type import inspect - -from torch.utils.data import Dataset -from tqdm import tqdm +import random +from bisect import bisect_right +import litdata +from litdata.utilities.train_test_split import deepcopy_dataset +import copy from ..processors import get_processor from ..processors.base_processor import FeatureProcessor -class SampleDataset(Dataset): - """Sample dataset class for handling and processing data samples. +class SampleBuilder: + """Fit feature processors and transform pickled samples without materializing a dataset. - Attributes: - samples (List[Dict]): List of data samples. - input_schema (Dict[str, Union[str, Type[FeatureProcessor], FeatureProcessor, Tuple[Union[str, Type[FeatureProcessor]], Dict[str, Any]]]]): - Schema for input data. Values can be string aliases, processor classes, processor instances, or tuples of (spec, kwargs_dict). - output_schema (Dict[str, Union[str, Type[FeatureProcessor], FeatureProcessor, Tuple[Union[str, Type[FeatureProcessor]], Dict[str, Any]]]]): - Schema for output data. Values can be string aliases, processor classes, processor instances, or tuples of (spec, kwargs_dict). 
- dataset_name (Optional[str]): Name of the dataset. - task_name (Optional[str]): Name of the task. + SampleBuilder is a lightweight helper used to: + - Fit feature processors from provided `input_schema` and `output_schema` on an + iterable of raw Python sample dictionaries. + - Build mappings from patient IDs and record IDs to sample indices. + - Transform pickled sample records into processed feature dictionaries using + the fitted processors. + + Typical usage: + builder = SampleBuilder(input_schema, output_schema) + builder.fit(samples) + builder.save(path) # writes a schema.pkl metadata file + + After saving the schema, `litdata.optimize` can be used with `builder.transform` + to serialize and chunk pickled sample items into a directory that can be + loaded via SampleDataset. """ def __init__( self, - samples: List[Dict], - input_schema: Dict[str, Union[str, Type[FeatureProcessor], FeatureProcessor, Tuple[Union[str, Type[FeatureProcessor]], Dict[str, Any]]]], - output_schema: Dict[str, Union[str, Type[FeatureProcessor], FeatureProcessor, Tuple[Union[str, Type[FeatureProcessor]], Dict[str, Any]]]], - dataset_name: Optional[str] = None, - task_name: Optional[str] = None, + input_schema: Dict[str, Any], + output_schema: Dict[str, Any], input_processors: Optional[Dict[str, FeatureProcessor]] = None, output_processors: Optional[Dict[str, FeatureProcessor]] = None, ) -> None: - """Initializes the SampleDataset with samples and schemas. - - Args: - samples (List[Dict]): List of data samples. - input_schema (Dict[str, Union[str, Type[FeatureProcessor], FeatureProcessor, Tuple[Union[str, Type[FeatureProcessor]], Dict[str, Any]]]]): - Schema for input data. Values can be string aliases, processor classes, processor instances, or tuples of (spec, kwargs_dict) for instantiation. - output_schema (Dict[str, Union[str, Type[FeatureProcessor], FeatureProcessor, Tuple[Union[str, Type[FeatureProcessor]], Dict[str, Any]]]]): - Schema for output data. Values can be string aliases, processor classes, processor instances, or tuples of (spec, kwargs_dict) for instantiation. - dataset_name (Optional[str], optional): Name of the dataset. - Defaults to None. - task_name (Optional[str], optional): Name of the task. - Defaults to None. - input_processors (Optional[Dict[str, FeatureProcessor]], - optional): Pre-fitted input processors. If provided, these - will be used instead of creating new ones from input_schema. - Defaults to None. - output_processors (Optional[Dict[str, FeatureProcessor]], - optional): Pre-fitted output processors. If provided, these - will be used instead of creating new ones from output_schema. - Defaults to None. 
- """ - if dataset_name is None: - dataset_name = "" - if task_name is None: - task_name = "" - self.samples = samples self.input_schema = input_schema self.output_schema = output_schema - self.input_processors = input_processors if input_processors is not None else {} - self.output_processors = ( + self._input_processors = ( + input_processors if input_processors is not None else {} + ) + self._output_processors = ( output_processors if output_processors is not None else {} ) - self.dataset_name = dataset_name - self.task_name = task_name - # Create patient_to_index and record_to_index mappings - self.patient_to_index = {} - self.record_to_index = {} + self._patient_to_index: Dict[str, List[int]] = {} + self._record_to_index: Dict[str, List[int]] = {} + self._fitted = False - for i, sample in enumerate(samples): - # Create patient_to_index mapping - patient_id = sample.get("patient_id") - if patient_id is not None: - if patient_id not in self.patient_to_index: - self.patient_to_index[patient_id] = [] - self.patient_to_index[patient_id].append(i) - - # Create record_to_index mapping (optional) - record_id = sample.get("record_id", sample.get("visit_id")) - if record_id is not None: - if record_id not in self.record_to_index: - self.record_to_index[record_id] = [] - self.record_to_index[record_id].append(i) + @property + def input_processors(self) -> Dict[str, FeatureProcessor]: + if not self._fitted: + raise RuntimeError( + "SampleBuilder.fit must be called before accessing input_processors." + ) + return self._input_processors - self.validate() - self.build() + @property + def output_processors(self) -> Dict[str, FeatureProcessor]: + if not self._fitted: + raise RuntimeError( + "SampleBuilder.fit must be called before accessing output_processors." + ) + return self._output_processors - def _get_processor_instance(self, processor_spec): - """Get processor instance from either string alias, class reference, processor instance, or tuple with kwargs. + @property + def patient_to_index(self) -> Dict[str, List[int]]: + if not self._fitted: + raise RuntimeError( + "SampleBuilder.fit must be called before accessing patient_to_index." + ) + return self._patient_to_index - Args: - processor_spec: Either a string alias, a processor class, a processor instance, or a tuple (spec, kwargs_dict) + @property + def record_to_index(self) -> Dict[str, List[int]]: + if not self._fitted: + raise RuntimeError( + "SampleBuilder.fit must be called before accessing record_to_index." 
+ ) + return self._record_to_index - Returns: - Instance of the processor - """ + def _get_processor_instance(self, processor_spec): + """Instantiate a processor using the same resolution logic as SampleDataset.""" if isinstance(processor_spec, tuple): spec, kwargs = processor_spec if isinstance(spec, str): return get_processor(spec)(**kwargs) - elif inspect.isclass(spec) and issubclass(spec, FeatureProcessor): + if inspect.isclass(spec) and issubclass(spec, FeatureProcessor): return spec(**kwargs) - else: - raise ValueError( - f"Processor spec in tuple must be either a string alias or a " - f"FeatureProcessor class, got {type(spec)}" - ) + raise ValueError( + "Processor spec in tuple must be either a string alias or a " + f"FeatureProcessor class, got {type(spec)}" + ) if isinstance(processor_spec, str): - # Use existing registry system for string aliases return get_processor(processor_spec)() - elif inspect.isclass(processor_spec) and issubclass( + if inspect.isclass(processor_spec) and issubclass( processor_spec, FeatureProcessor ): - # Direct class reference return processor_spec() - elif isinstance(processor_spec, FeatureProcessor): - # Already an instance + if isinstance(processor_spec, FeatureProcessor): return processor_spec - else: - raise ValueError( - f"Processor spec must be either a string alias, a " - f"FeatureProcessor class, or a tuple (spec, kwargs_dict), got {type(processor_spec)}" - ) + raise ValueError( + "Processor spec must be either a string alias, a FeatureProcessor " + f"class, or a tuple (spec, kwargs_dict), got {type(processor_spec)}" + ) + + def fit( + self, + samples: Iterable[Dict[str, Any]], + ) -> None: + """Fit processors and build mapping indices from an iterator of samples. + + Args: + samples: Iterable of sample dictionaries (e.g., python dicts). Each + sample should contain keys covering both the configured + `input_schema` and `output_schema`. These samples are not + required to be pickled; `fit` operates on in-memory dicts. - def validate(self) -> None: - """Validates that the samples match the input and output schemas.""" + Behavior: + - Validates the samples contain all keys specified by the input + and output schemas. + - Builds `patient_to_index` and `record_to_index` mappings by + recording the sample indices associated with `patient_id` and + `record_id`/`visit_id` fields. + - Instantiates and fits input/output processors from the provided + schemas (unless pre-fitted processors were supplied to the + constructor). + """ + # Validate the samples input_keys = set(self.input_schema.keys()) output_keys = set(self.output_schema.keys()) - for s in self.samples: - assert input_keys.issubset(s.keys()), "Input schema does not match samples." - assert output_keys.issubset(s.keys()), ( - "Output schema does not match samples." 
- ) - return - - def build(self) -> None: - """Builds the processors for input and output data based on schemas.""" - # Only fit if processors weren't provided - if not self.input_processors: - for k, v in self.input_schema.items(): - self.input_processors[k] = self._get_processor_instance(v) - self.input_processors[k].fit(self.samples, k) - if not self.output_processors: - for k, v in self.output_schema.items(): - self.output_processors[k] = self._get_processor_instance(v) - self.output_processors[k].fit(self.samples, k) - # Always process samples with the (fitted) processors - for sample in tqdm(self.samples, desc="Processing samples"): - for k, v in sample.items(): - if k in self.input_processors: - sample[k] = self.input_processors[k].process(v) - elif k in self.output_processors: - sample[k] = self.output_processors[k].process(v) - return - - def __getitem__(self, index: int) -> Dict: - """Returns a sample by index. + for sample in samples: + assert input_keys.issubset( + sample.keys() + ), "Input schema does not match samples." + assert output_keys.issubset( + sample.keys() + ), "Output schema does not match samples." + + # Build index mappings + self._patient_to_index = {} + self._record_to_index = {} + for i, sample in enumerate(samples): + patient_id = sample.get("patient_id") + if patient_id is not None: + self._patient_to_index.setdefault(patient_id, []).append(i) + record_id = sample.get("record_id", sample.get("visit_id")) + if record_id is not None: + self._record_to_index.setdefault(record_id, []).append(i) + + # Fit processors if they were not provided + if not self._input_processors: + for key, spec in self.input_schema.items(): + processor = self._get_processor_instance(spec) + processor.fit(samples, key) + self._input_processors[key] = processor + if not self._output_processors: + for key, spec in self.output_schema.items(): + processor = self._get_processor_instance(spec) + processor.fit(samples, key) + self._output_processors[key] = processor + + self._fitted = True + + def transform(self, sample: dict[str, bytes]) -> Dict[str, Any]: + """Transform a single serialized (pickled) sample using fitted processors. Args: - index (int): Index of the sample to retrieve. + sample: A mapping with a single key `"sample"` whose value is a + pickled Python dictionary (produced by `pickle.dumps`). The + pickled dictionary should mirror the schema that was used to + fit this builder. Returns: - Dict: A dict with patient_id, visit_id/record_id, and other - task-specific attributes as key. Conversion to index/tensor - will be done in the model. + A Python dictionary where each key is either an input or output + feature name. Values for keys present in the corresponding fitted + processors have been processed through their FeatureProcessor and + are returned as the output of that processor. Keys not covered by + the input/output processors are returned unchanged. + """ + if not self._fitted: + raise RuntimeError("SampleBuilder.fit must be called before transform().") + + transformed: Dict[str, Any] = {} + for key, value in pickle.loads(sample["sample"]).items(): + if key in self._input_processors: + transformed[key] = self._input_processors[key].process(value) + elif key in self._output_processors: + transformed[key] = self._output_processors[key].process(value) + else: + transformed[key] = value + return transformed + + def save(self, path: str) -> None: + """Save fitted metadata to the given path as a pickled file. 
+ + Args: + path: Location where the builder will write a pickled metadata file + (commonly named `schema.pkl`). The saved metadata contains + the fitted input/output schemas, processors, and index + mappings. This file is read by `SampleDataset` during + construction. + """ + if not self._fitted: + raise RuntimeError("SampleBuilder.fit must be called before save().") + metadata = { + "input_schema": self.input_schema, + "output_schema": self.output_schema, + "input_processors": self._input_processors, + "output_processors": self._output_processors, + "patient_to_index": self._patient_to_index, + "record_to_index": self._record_to_index, + } + with open(path, "wb") as f: + pickle.dump(metadata, f) + + +class SampleDataset(litdata.StreamingDataset): + """A streaming dataset that loads sample metadata and processors from disk. + + SampleDataset expects the `path` directory to contain a `schema.pkl` + file created by a `SampleBuilder.save(...)` call. The `schema.pkl` must + include the fitted `input_schema`, `output_schema`, `input_processors`, + `output_processors`, `patient_to_index` and `record_to_index` mappings. + + Attributes: + input_schema: The configuration used to instantiate processors for + input features (string aliases or processor specs). + output_schema: The configuration used to instantiate processors for + output features. + input_processors: A mapping of input feature names to fitted + FeatureProcessor instances. + output_processors: A mapping of output feature names to fitted + FeatureProcessor instances. + patient_to_index: Dictionary mapping patient IDs to the list of + sample indices associated with that patient. + record_to_index: Dictionary mapping record/visit IDs to the list of + sample indices associated with that record. + dataset_name: Optional human friendly dataset name. + task_name: Optional human friendly task name. + """ + + def __init__( + self, + path: str, + dataset_name: Optional[str] = None, + task_name: Optional[str] = None, + **kwargs, + ) -> None: + """Initialize a SampleDataset pointing at a directory created by SampleBuilder. + + Args: + path: Path to a directory containing a `schema.pkl` produced by + `SampleBuilder.save` and associated pickled sample files. + dataset_name: Optional human-friendly dataset name. + task_name: Optional human-friendly task name. + **kwargs: Extra keyword arguments forwarded to + `litdata.StreamingDataset` (such as streaming options). """ - return self.samples[index] + super().__init__(path, **kwargs) + + self.dataset_name = "" if dataset_name is None else dataset_name + self.task_name = "" if task_name is None else task_name + + with open(f"{path}/schema.pkl", "rb") as f: + metadata = pickle.load(f) + + self.input_schema = metadata["input_schema"] + self.output_schema = metadata["output_schema"] + self.input_processors = metadata["input_processors"] + self.output_processors = metadata["output_processors"] + + self.patient_to_index = metadata["patient_to_index"] + self.record_to_index = metadata["record_to_index"] def __str__(self) -> str: """Returns a string representation of the dataset. 
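The pieces above fit together as a fit / save / optimize / load round trip. Below is a minimal sketch (editor's illustration only); it mirrors what `set_task` earlier in this diff and the `create_sample_dataset` helper later in this diff do internally, and it assumes `SampleBuilder` and `SampleDataset` are importable from `pyhealth.datasets.sample_dataset`:

.. code-block:: python

    import pickle
    import tempfile
    from pathlib import Path

    import litdata

    from pyhealth.datasets.sample_dataset import SampleBuilder, SampleDataset

    samples = [
        {"patient_id": "p0", "conditions": ["c1", "c2"], "label": 1},
        {"patient_id": "p1", "conditions": ["c3"], "label": 0},
    ]
    path = Path(tempfile.mkdtemp())

    # Fit processors and index mappings, then persist them as schema.pkl.
    builder = SampleBuilder(
        input_schema={"conditions": "sequence"},
        output_schema={"label": "binary"},
    )
    builder.fit(samples)
    builder.save(str(path / "schema.pkl"))

    # Apply the fitted processors to pickled samples and write litdata chunks.
    litdata.optimize(
        fn=builder.transform,
        inputs=[{"sample": pickle.dumps(s)} for s in samples],
        output_dir=str(path),
        chunk_bytes="64MB",
        num_workers=0,
    )

    # Stream the processed samples back; schema.pkl restores the fitted processors.
    dataset = SampleDataset(path=str(path), dataset_name="demo", task_name="demo")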
@@ -176,10 +287,254 @@ def __str__(self) -> str: """ return f"Sample dataset {self.dataset_name} {self.task_name}" + def subset(self, indices: Union[Sequence[int], slice]) -> "SampleDataset": + """Create a StreamingDataset restricted to the provided indices.""" + + new_dataset = deepcopy_dataset(self) + + if len(new_dataset.subsampled_files) != len(new_dataset.region_of_interest): + raise ValueError( + "The provided dataset has mismatched subsampled_files and region_of_interest lengths." + ) + + dataset_length = sum( + end - start for start, end in new_dataset.region_of_interest + ) + + if isinstance(indices, slice): + indices = range(*indices.indices(dataset_length)) + + if any(idx < 0 or idx >= dataset_length for idx in indices): + raise ValueError( + f"Subset indices must be in [0, {dataset_length - 1}] for the provided dataset." + ) + + # Build chunk boundaries so we can translate global indices into + # chunk-local (start, end) pairs that litdata understands. + chunk_starts: List[int] = [] + chunk_boundaries: List[Tuple[str, int, int, int, int]] = [] + cursor = 0 + for filename, (roi_start, roi_end) in zip( + new_dataset.subsampled_files, new_dataset.region_of_interest + ): + chunk_len = roi_end - roi_start + if chunk_len <= 0: + continue + chunk_starts.append(cursor) + chunk_boundaries.append( + (filename, roi_start, roi_end, cursor, cursor + chunk_len) + ) + cursor += chunk_len + + new_subsampled_files: List[str] = [] + new_roi: List[Tuple[int, int]] = [] + prev_chunk_idx: Optional[int] = None + + for idx in indices: + chunk_idx = bisect_right(chunk_starts, idx) - 1 + if chunk_idx < 0 or idx >= chunk_boundaries[chunk_idx][4]: + raise ValueError(f"Index {idx} is out of bounds for the dataset.") + + filename, roi_start, _, global_start, _ = chunk_boundaries[chunk_idx] + offset_in_chunk = roi_start + (idx - global_start) + + if ( + new_roi + and prev_chunk_idx == chunk_idx + and offset_in_chunk == new_roi[-1][1] + ): + new_roi[-1] = (new_roi[-1][0], new_roi[-1][1] + 1) + else: + new_subsampled_files.append(filename) + new_roi.append((offset_in_chunk, offset_in_chunk + 1)) + + prev_chunk_idx = chunk_idx + + new_dataset.subsampled_files = new_subsampled_files + new_dataset.region_of_interest = new_roi + new_dataset.reset() + + return new_dataset + + +class InMemorySampleDataset(SampleDataset): + """A SampleDataset that loads all samples into memory for fast access. + + InMemorySampleDataset extends SampleDataset by eagerly loading and + transforming all samples into memory during initialization. This allows + for fast, repeated access to samples without disk I/O, at the cost of + higher memory usage. + + Note: + This class is intended for testing and debugging purposes where + dataset sizes are small enough to fit into memory. + """ + + def __init__( + self, + samples: List[Dict[str, Any]], + input_schema: Dict[str, Any], + output_schema: Dict[str, Any], + dataset_name: Optional[str] = None, + task_name: Optional[str] = None, + input_processors: Optional[Dict[str, FeatureProcessor]] = None, + output_processors: Optional[Dict[str, FeatureProcessor]] = None, + ) -> None: + """Initialize an InMemorySampleDataset from in-memory samples. + + This constructor fits a SampleBuilder on the provided samples, + transforms all samples into memory, and sets up the dataset attributes. + + Args: + samples: A list of sample dictionaries (in-memory). + input_schema: Schema describing how input keys should be handled. + output_schema: Schema describing how output keys should be handled. 
+ dataset_name: Optional human-friendly dataset name. + task_name: Optional human-friendly task name. + input_processors: Optional pre-fitted input processors to use instead + of creating new ones from the input_schema. + output_processors: Optional pre-fitted output processors to use + instead of creating new ones from the output_schema. + """ + builder = SampleBuilder( + input_schema=input_schema, + output_schema=output_schema, + input_processors=input_processors, + output_processors=output_processors, + ) + builder.fit(samples) + + self.dataset_name = "" if dataset_name is None else dataset_name + self.task_name = "" if task_name is None else task_name + + self.input_schema = builder.input_schema + self.output_schema = builder.output_schema + self.input_processors = builder.input_processors + self.output_processors = builder.output_processors + + self.patient_to_index = builder.patient_to_index + self.record_to_index = builder.record_to_index + + self._data = [builder.transform({"sample": pickle.dumps(s)}) for s in samples] + + self._shuffle = False + + def set_shuffle(self, shuffle: bool) -> None: + self._shuffle = shuffle + def __len__(self) -> int: """Returns the number of samples in the dataset. Returns: - int: The number of samples. + int: The total number of samples. + """ + return len(self._data) + + def __getitem__(self, index: int) -> Dict[str, Any]: # type: ignore + """Retrieve a processed sample by index. + + Args: + index: The index of the sample to retrieve. + + Returns: + A dictionary containing processed input and output features. + """ + return self._data[index] + + def __iter__(self) -> Iterable[Dict[str, Any]]: # type: ignore + """Returns an iterator over all samples in the dataset. + + Returns: + An iterator yielding processed sample dictionaries. """ - return len(self.samples) + if self._shuffle: + shuffled_data = self._data[:] + random.shuffle(shuffled_data) + return iter(shuffled_data) + else: + return iter(self._data) + + def subset(self, indices: Union[Sequence[int], slice]) -> SampleDataset: + if isinstance(indices, slice): + samples = self._data[indices] + else: + samples = [self._data[i] for i in indices] + + new_dataset = copy.deepcopy(self) + new_dataset._data = samples + return new_dataset + + +def create_sample_dataset( + samples: List[Dict[str, Any]], + input_schema: Dict[str, Any], + output_schema: Dict[str, Any], + dataset_name: Optional[str] = None, + task_name: Optional[str] = None, + input_processors: Optional[Dict[str, FeatureProcessor]] = None, + output_processors: Optional[Dict[str, FeatureProcessor]] = None, + in_memory: bool = True, +): + """Convenience helper to create an on-disk SampleDataset from in-memory samples. + + This helper will: + - Create a temporary directory for the dataset output. + - Fit a `SampleBuilder` with the provided schemas and samples. + - Save the fitted `schema.pkl` to the temporary directory. + - Use `litdata.optimize` with `builder.transform` to write serialized + and chunked sample files into the directory. + - Return a `SampleDataset` instance pointed at the temporary directory. + + Args: + samples: A list of Python dictionaries representing raw samples. + input_schema: Schema describing how input keys should be handled. + output_schema: Schema describing how output keys should be handled. + dataset_name: Optional dataset name to attach to the returned + SampleDataset instance. + task_name: Optional task name to attach to the returned SampleDataset + instance. 
+ input_processors: Optional pre-fitted input processors to use instead + of creating new ones from the input_schema. + output_processors: Optional pre-fitted output processors to use + instead of creating new ones from the output_schema. + in_memory: If True, returns an InMemorySampleDataset instead of + a disk-backed SampleDataset. + + Returns: + An instance of `SampleDataset` loaded from the temporary directory + containing the optimized, chunked samples and `schema.pkl` metadata. + """ + if in_memory: + return InMemorySampleDataset( + samples=samples, + input_schema=input_schema, + output_schema=output_schema, + dataset_name=dataset_name, + task_name=task_name, + input_processors=input_processors, + output_processors=output_processors, + ) + else: + path = Path(tempfile.mkdtemp()) + + builder = SampleBuilder( + input_schema=input_schema, # type: ignore + output_schema=output_schema, # type: ignore + input_processors=input_processors, + output_processors=output_processors, + ) + builder.fit(samples) + builder.save(str(path / "schema.pkl")) + litdata.optimize( + fn=builder.transform, + inputs=[{"sample": pickle.dumps(x)} for x in samples], + output_dir=str(path), + chunk_bytes="64MB", + num_workers=0, + ) + + return SampleDataset( + path=str(path), + dataset_name=dataset_name, + task_name=task_name, + ) diff --git a/pyhealth/datasets/splitter.py b/pyhealth/datasets/splitter.py index c70df5660..cbaaca7c0 100644 --- a/pyhealth/datasets/splitter.py +++ b/pyhealth/datasets/splitter.py @@ -40,9 +40,9 @@ def split_by_visit( int(len(dataset) * ratios[0]) : int(len(dataset) * (ratios[0] + ratios[1])) ] test_index = index[int(len(dataset) * (ratios[0] + ratios[1])) :] - train_dataset = torch.utils.data.Subset(dataset, train_index) - val_dataset = torch.utils.data.Subset(dataset, val_index) - test_dataset = torch.utils.data.Subset(dataset, test_index) + train_dataset = dataset.subset(train_index) # type: ignore + val_dataset = dataset.subset(val_index) # type: ignore + test_dataset = dataset.subset(test_index) # type: ignore return train_dataset, val_dataset, test_dataset @@ -82,9 +82,9 @@ def split_by_patient( ) val_index = list(chain(*[dataset.patient_to_index[i] for i in val_patient_indx])) test_index = list(chain(*[dataset.patient_to_index[i] for i in test_patient_indx])) - train_dataset = torch.utils.data.Subset(dataset, train_index) - val_dataset = torch.utils.data.Subset(dataset, val_index) - test_dataset = torch.utils.data.Subset(dataset, test_index) + train_dataset = dataset.subset(train_index) # type: ignore + val_dataset = dataset.subset(val_index) # type: ignore + test_dataset = dataset.subset(test_index) # type: ignore return train_dataset, val_dataset, test_dataset @@ -119,9 +119,9 @@ def split_by_sample( int(len(dataset) * ratios[0]) : int(len(dataset) * (ratios[0] + ratios[1])) ] test_index = index[int(len(dataset) * (ratios[0] + ratios[1])) :] - train_dataset = torch.utils.data.Subset(dataset, train_index) - val_dataset = torch.utils.data.Subset(dataset, val_index) - test_dataset = torch.utils.data.Subset(dataset, test_index) + train_dataset = dataset.subset(train_index) # type: ignore + val_dataset = dataset.subset(val_index) # type: ignore + test_dataset = dataset.subset(test_index) # type: ignore if get_index: return ( @@ -172,10 +172,10 @@ def split_by_visit_conformal( cal_index = index[val_end:cal_end] test_index = index[cal_end:] - train_dataset = torch.utils.data.Subset(dataset, train_index) - val_dataset = torch.utils.data.Subset(dataset, val_index) - cal_dataset = 
torch.utils.data.Subset(dataset, cal_index) - test_dataset = torch.utils.data.Subset(dataset, test_index) + train_dataset = dataset.subset(train_index) # type: ignore + val_dataset = dataset.subset(val_index) # type: ignore + cal_dataset = dataset.subset(cal_index) # type: ignore + test_dataset = dataset.subset(test_index) # type: ignore return train_dataset, val_dataset, cal_dataset, test_dataset @@ -227,10 +227,10 @@ def split_by_patient_conformal( cal_index = list(chain(*[dataset.patient_to_index[i] for i in cal_patient_indx])) test_index = list(chain(*[dataset.patient_to_index[i] for i in test_patient_indx])) - train_dataset = torch.utils.data.Subset(dataset, train_index) - val_dataset = torch.utils.data.Subset(dataset, val_index) - cal_dataset = torch.utils.data.Subset(dataset, cal_index) - test_dataset = torch.utils.data.Subset(dataset, test_index) + train_dataset = dataset.subset(train_index) # type: ignore + val_dataset = dataset.subset(val_index) # type: ignore + cal_dataset = dataset.subset(cal_index) # type: ignore + test_dataset = dataset.subset(test_index) # type: ignore return train_dataset, val_dataset, cal_dataset, test_dataset @@ -285,8 +285,8 @@ def split_by_sample_conformal( torch.tensor(test_index), ) else: - train_dataset = torch.utils.data.Subset(dataset, train_index) - val_dataset = torch.utils.data.Subset(dataset, val_index) - cal_dataset = torch.utils.data.Subset(dataset, cal_index) - test_dataset = torch.utils.data.Subset(dataset, test_index) + train_dataset = dataset.subset(train_index) # type: ignore + val_dataset = dataset.subset(val_index) # type: ignore + cal_dataset = dataset.subset(cal_index) # type: ignore + test_dataset = dataset.subset(test_index) # type: ignore return train_dataset, val_dataset, cal_dataset, test_dataset diff --git a/pyhealth/datasets/utils.py b/pyhealth/datasets/utils.py index 63ca4152a..af7babe19 100644 --- a/pyhealth/datasets/utils.py +++ b/pyhealth/datasets/utils.py @@ -5,6 +5,7 @@ from typing import Any, Dict, List, Optional, Tuple import torch +import litdata from dateutil.parser import parse as dateutil_parse from torch.nn.utils.rnn import pad_sequence from torch.utils.data import DataLoader @@ -319,7 +320,7 @@ def collate_fn_dict_with_padding(batch: List[dict]) -> dict: def get_dataloader( - dataset: torch.utils.data.Dataset, batch_size: int, shuffle: bool = False + dataset: litdata.StreamingDataset, batch_size: int, shuffle: bool = False ) -> DataLoader: """Creates a DataLoader for a given dataset. @@ -331,10 +332,10 @@ def get_dataloader( Returns: A DataLoader instance for the dataset. """ + dataset.set_shuffle(shuffle) dataloader = DataLoader( dataset, batch_size=batch_size, - shuffle=shuffle, collate_fn=collate_fn_dict_with_padding, ) diff --git a/pyhealth/models/adacare.py b/pyhealth/models/adacare.py index 730ff36ec..5e4ee806b 100644 --- a/pyhealth/models/adacare.py +++ b/pyhealth/models/adacare.py @@ -289,7 +289,7 @@ class AdaCare(BaseModel): Examples: - >>> from pyhealth.datasets import SampleDataset + >>> from pyhealth.datasets import create_sample_dataset >>> samples = [ ... { ... "patient_id": "patient-0", @@ -321,7 +321,7 @@ class AdaCare(BaseModel): ... "label": 0, ... }, ... ] - >>> dataset = SampleDataset(samples=samples, + >>> dataset = create_sample_dataset(samples=samples, ... input_schema={ ... 'vector': 'nested_sequence_floats', ... 'list_codes': 'sequence', @@ -329,7 +329,8 @@ class AdaCare(BaseModel): ... 'list_vectors': 'nested_sequence_floats', ... 'list_list_vectors': 'deep_nested_sequence_floats' ... 
}, - ... output_schema={'label': 'binary'} + ... output_schema={'label': 'binary'}, + ... dataset_name='test' ... ) >>> >>> from pyhealth.models import AdaCare diff --git a/pyhealth/models/cnn.py b/pyhealth/models/cnn.py index 74d5fd522..9e69de3bd 100644 --- a/pyhealth/models/cnn.py +++ b/pyhealth/models/cnn.py @@ -159,7 +159,7 @@ class CNN(BaseModel): **kwargs: Additional keyword arguments forwarded to :class:`CNNLayer`. Example: - >>> from pyhealth.datasets import SampleDataset + >>> from pyhealth.datasets import create_sample_dataset >>> samples = [ ... { ... "patient_id": "p0", @@ -176,7 +176,7 @@ class CNN(BaseModel): ... "label": 0, ... }, ... ] - >>> dataset = SampleDataset( + >>> dataset = create_sample_dataset( ... samples=samples, ... input_schema={"conditions": "sequence", "labs": "tensor"}, ... output_schema={"label": "binary"}, @@ -262,7 +262,7 @@ def _determine_input_channels(self, feature_key: str, spatial_dim: int) -> int: return self.embedding_dim min_dim = spatial_dim + 1 - for sample in self.dataset.samples: + for sample in self.dataset: if feature_key not in sample: continue feature = self._extract_feature_tensor(sample[feature_key]) @@ -326,7 +326,7 @@ def forward(self, **kwargs) -> Dict[str, torch.Tensor]: if __name__ == "__main__": - from pyhealth.datasets import SampleDataset + from pyhealth.datasets import create_sample_dataset from pyhealth.datasets import get_dataloader samples = [ @@ -348,7 +348,7 @@ def forward(self, **kwargs) -> Dict[str, torch.Tensor]: input_schema = {"conditions": "sequence", "labs": "tensor"} output_schema = {"label": "binary"} - dataset = SampleDataset( + dataset = create_sample_dataset( samples=samples, input_schema=input_schema, # type: ignore[arg-type] output_schema=output_schema, # type: ignore[arg-type] diff --git a/pyhealth/models/contrawr.py b/pyhealth/models/contrawr.py index 227e8b1b5..6caa139d3 100644 --- a/pyhealth/models/contrawr.py +++ b/pyhealth/models/contrawr.py @@ -173,7 +173,7 @@ def __init__( self.fc = nn.Linear(emb_size, output_size) def _determine_input_channels_length(self) -> int: - for sample in self.dataset.samples: + for sample in self.dataset: if self.feature_keys[0] not in sample: continue diff --git a/pyhealth/models/embedding.py b/pyhealth/models/embedding.py index 3b5a39d3c..a3aad8244 100644 --- a/pyhealth/models/embedding.py +++ b/pyhealth/models/embedding.py @@ -100,7 +100,7 @@ def __init__(self, dataset: SampleDataset, embedding_dim: int = 128): elif isinstance(processor, TensorProcessor): # Infer size from first sample sample_tensor = None - for sample in dataset.samples: + for sample in dataset: if field_name in sample: sample_tensor = processor.process(sample[field_name]) break diff --git a/pyhealth/models/gamenet.py b/pyhealth/models/gamenet.py index feacb59f9..46afe057f 100644 --- a/pyhealth/models/gamenet.py +++ b/pyhealth/models/gamenet.py @@ -241,14 +241,14 @@ class GAMENet(BaseModel): **kwargs: other parameters for the GAMENet layer. Examples: - >>> from pyhealth.datasets import SampleDataset + >>> from pyhealth.datasets import create_sample_dataset >>> from pyhealth.tasks import drug_recommendation_mimic3_fn >>> from pyhealth.models import GAMENet >>> from pyhealth.datasets import split_by_patient, get_dataloader >>> from pyhealth.trainer import Trainer >>> >>> # Load MIMIC-III dataset - >>> dataset = SampleDataset( + >>> dataset = create_sample_dataset( ... samples=[ ... { ... 
"patient_id": "patient-0", @@ -360,7 +360,7 @@ def generate_ehr_adj(self) -> torch.tensor: label_vocab = self.dataset.output_processors[self.label_key].label_vocab label_size = len(label_vocab) ehr_adj = torch.zeros((label_size, label_size)) - for sample in self.dataset.samples: + for sample in self.dataset: curr_drugs = sample["drugs"] if isinstance(curr_drugs, torch.Tensor): continue diff --git a/pyhealth/models/gnn.py b/pyhealth/models/gnn.py index c4a0eed00..9ef323cb3 100644 --- a/pyhealth/models/gnn.py +++ b/pyhealth/models/gnn.py @@ -324,9 +324,29 @@ class GCN(BaseModel): num_layers: Number of GCN layers. Defaults to 2. Examples: - >>> from pyhealth.datasets import SampleDataset, get_dataloader - >>> samples = [...] - >>> dataset = SampleDataset(...) + >>> from pyhealth.datasets import create_sample_dataset, get_dataloader + >>> samples = [ + ... { + ... "patient_id": "patient-0", + ... "visit_id": "visit-0", + ... "diagnoses": ["A", "B", "C"], + ... "procedures": ["X", "Y"], + ... "label": 1, + ... }, + ... { + ... "patient_id": "patient-1", + ... "visit_id": "visit-0", + ... "diagnoses": ["D", "E"], + ... "procedures": ["Z"], + ... "label": 0, + ... }, + ... ] + >>> dataset = create_sample_dataset( + ... samples=samples, + ... input_schema={"diagnoses": "sequence", "procedures": "sequence"}, + ... output_schema={"label": "binary"}, + ... dataset_name="test", + ... ) >>> model = GCN(dataset=dataset) >>> loader = get_dataloader(dataset, batch_size=2, shuffle=True) >>> batch = next(iter(loader)) @@ -503,9 +523,29 @@ class GAT(BaseModel): num_layers: Number of GAT layers. Defaults to 2. Examples: - >>> from pyhealth.datasets import SampleDataset, get_dataloader - >>> samples = [...] - >>> dataset = SampleDataset(...) + >>> from pyhealth.datasets import create_sample_dataset, get_dataloader + >>> samples = [ + ... { + ... "patient_id": "patient-0", + ... "visit_id": "visit-0", + ... "diagnoses": ["A", "B", "C"], + ... "procedures": ["X", "Y"], + ... "label": 1, + ... }, + ... { + ... "patient_id": "patient-1", + ... "visit_id": "visit-0", + ... "diagnoses": ["D", "E"], + ... "procedures": ["Z"], + ... "label": 0, + ... }, + ... ] + >>> dataset = create_sample_dataset( + ... samples=samples, + ... input_schema={"diagnoses": "sequence", "procedures": "sequence"}, + ... output_schema={"label": "binary"}, + ... dataset_name="test", + ... ) >>> model = GAT(dataset=dataset) >>> loader = get_dataloader(dataset, batch_size=2, shuffle=True) >>> batch = next(iter(loader)) @@ -666,7 +706,7 @@ def forward(self, **kwargs) -> Dict[str, torch.Tensor]: if __name__ == "__main__": - from pyhealth.datasets import SampleDataset, get_dataloader + from pyhealth.datasets import create_sample_dataset, get_dataloader samples = [ { @@ -688,7 +728,7 @@ def forward(self, **kwargs) -> Dict[str, torch.Tensor]: input_schema = {"diagnoses": "sequence", "procedures": "sequence"} output_schema = {"label": "binary"} - dataset = SampleDataset( + dataset = create_sample_dataset( samples=samples, input_schema=input_schema, output_schema=output_schema, diff --git a/pyhealth/models/logistic_regression.py b/pyhealth/models/logistic_regression.py index 9b99c491e..8155d101f 100644 --- a/pyhealth/models/logistic_regression.py +++ b/pyhealth/models/logistic_regression.py @@ -29,7 +29,7 @@ class LogisticRegression(BaseModel): **kwargs: other parameters (for compatibility). Examples: - >>> from pyhealth.datasets import SampleDataset + >>> from pyhealth.datasets import create_sample_dataset >>> samples = [ ... { ... 
"patient_id": "patient-0", @@ -49,7 +49,7 @@ class LogisticRegression(BaseModel): >>> input_schema = {"conditions": "sequence", ... "procedures": "tensor"} >>> output_schema = {"label": "binary"} - >>> dataset = SampleDataset(samples=samples, + >>> dataset = create_sample_dataset(samples=samples, ... input_schema=input_schema, ... output_schema=output_schema, ... dataset_name="test") @@ -211,7 +211,7 @@ def forward(self, **kwargs) -> Dict[str, torch.Tensor]: if __name__ == "__main__": - from pyhealth.datasets import SampleDataset + from pyhealth.datasets import create_sample_dataset samples = [ { @@ -238,7 +238,7 @@ def forward(self, **kwargs) -> Dict[str, torch.Tensor]: output_schema = {"label": "binary"} # binary classification # dataset - dataset = SampleDataset( + dataset = create_sample_dataset( samples=samples, input_schema=input_schema, output_schema=output_schema, diff --git a/pyhealth/models/micron.py b/pyhealth/models/micron.py index fb9269bf7..ffee2612c 100644 --- a/pyhealth/models/micron.py +++ b/pyhealth/models/micron.py @@ -159,7 +159,7 @@ class MICRON(BaseModel): - output_schema should include 'drugs' as a multilabel/multihot feature Example: - >>> from pyhealth.datasets import SampleDataset + >>> from pyhealth.datasets import create_sample_dataset >>> samples = [ ... { ... "patient_id": "patient-0", @@ -169,10 +169,11 @@ class MICRON(BaseModel): ... "drugs": ["metformin", "lisinopril"] ... } ... ] - >>> dataset = SampleDataset( + >>> dataset = create_sample_dataset( ... samples=samples, ... input_schema={"conditions": "sequence", "procedures": "sequence"}, - ... output_schema={"drugs": "multilabel"} + ... output_schema={"drugs": "multilabel"}, + ... dataset_name="test", ... ) >>> model = MICRON(dataset=dataset) """ diff --git a/pyhealth/models/mlp.py b/pyhealth/models/mlp.py index 413fb6f34..598b45778 100644 --- a/pyhealth/models/mlp.py +++ b/pyhealth/models/mlp.py @@ -55,7 +55,7 @@ class MLP(BaseModel): **kwargs: other parameters for the MLP layer. Examples: - >>> from pyhealth.datasets import SampleDataset + >>> from pyhealth.datasets import create_sample_dataset >>> samples = [ ... { ... "patient_id": "patient-0", @@ -75,7 +75,7 @@ class MLP(BaseModel): >>> input_schema = {"conditions": "sequence", ... "procedures": "timeseries"} >>> output_schema = {"label": "binary"} - >>> dataset = SampleDataset(samples=samples, + >>> dataset = create_sample_dataset(samples=samples, ... input_schema=input_schema, ... output_schema=output_schema, ... dataset_name="test") @@ -364,7 +364,7 @@ def forward(self, **kwargs) -> Dict[str, torch.Tensor]: if __name__ == "__main__": - from pyhealth.datasets import SampleDataset + from pyhealth.datasets import create_sample_dataset samples = [ { @@ -391,7 +391,7 @@ def forward(self, **kwargs) -> Dict[str, torch.Tensor]: output_schema = {"label": "binary"} # binary classification # dataset - dataset = SampleDataset( + dataset = create_sample_dataset( samples=samples, input_schema=input_schema, output_schema=output_schema, diff --git a/pyhealth/models/molerec.py b/pyhealth/models/molerec.py index 4cf6c5a28..68561b548 100644 --- a/pyhealth/models/molerec.py +++ b/pyhealth/models/molerec.py @@ -498,7 +498,7 @@ class MoleRec(BaseModel): ... ] >>> >>> # dataset - >>> dataset = SampleDataset( + >>> dataset = create_sample_dataset( ... samples=samples, ... input_schema={ ... 
"conditions": "nested_sequence", diff --git a/pyhealth/models/retain.py b/pyhealth/models/retain.py index b9aac4309..b272995ec 100644 --- a/pyhealth/models/retain.py +++ b/pyhealth/models/retain.py @@ -132,7 +132,7 @@ class RETAIN(BaseModel): **kwargs: other parameters for the RETAIN layer. Examples: - >>> from pyhealth.datasets import SampleDataset + >>> from pyhealth.datasets import create_sample_dataset >>> samples = [ ... { ... "patient_id": "patient-0", @@ -149,7 +149,7 @@ class RETAIN(BaseModel): ... "label": 0, ... }, ... ] - >>> dataset = SampleDataset( + >>> dataset = create_sample_dataset( ... samples=samples, ... input_schema={ ... "conditions": "nested_sequence", @@ -281,7 +281,7 @@ def forward(self, **kwargs) -> Dict[str, torch.Tensor]: if __name__ == "__main__": - from pyhealth.datasets import SampleDataset + from pyhealth.datasets import create_sample_dataset samples = [ { @@ -303,7 +303,7 @@ def forward(self, **kwargs) -> Dict[str, torch.Tensor]: ] # dataset - dataset = SampleDataset( + dataset = create_sample_dataset( samples=samples, input_schema={ "conditions": "nested_sequence", diff --git a/pyhealth/models/rnn.py b/pyhealth/models/rnn.py index fe6d7aff7..e89d58bd1 100644 --- a/pyhealth/models/rnn.py +++ b/pyhealth/models/rnn.py @@ -137,13 +137,50 @@ class RNN(BaseModel): Args: dataset (SampleDataset): the dataset to train the model. It is used to query certain information such as the set of all tokens. - feature_keys (List[str]): list of keys in samples to use as features, - e.g. ["conditions", "procedures"]. - label_key (str): key in samples to use as label (e.g., "drugs"). - mode (str): one of "binary", "multiclass", or "multilabel". embedding_dim (int): the embedding dimension. Default is 128. hidden_dim (int): the hidden dimension. Default is 128. - **kwargs: other parameters for the RNN layer. + **kwargs: other parameters for the RNN layer (e.g., rnn_type, num_layers, dropout, bidirectional). + + Examples: + >>> from pyhealth.datasets import create_sample_dataset + >>> samples = [ + ... { + ... "patient_id": "patient-0", + ... "visit_id": "visit-0", + ... "conditions": ["cond-33", "cond-86", "cond-80"], + ... "procedures": ["proc-12", "proc-45"], + ... "label": 1, + ... }, + ... { + ... "patient_id": "patient-1", + ... "visit_id": "visit-1", + ... "conditions": ["cond-12", "cond-52"], + ... "procedures": ["proc-23"], + ... "label": 0, + ... }, + ... ] + >>> dataset = create_sample_dataset( + ... samples=samples, + ... input_schema={"conditions": "sequence", "procedures": "sequence"}, + ... output_schema={"label": "binary"}, + ... dataset_name="test" + ... ) + >>> + >>> from pyhealth.datasets import get_dataloader + >>> train_loader = get_dataloader(dataset, batch_size=2, shuffle=True) + >>> + >>> model = RNN(dataset=dataset, embedding_dim=128, hidden_dim=64) + >>> + >>> data_batch = next(iter(train_loader)) + >>> + >>> ret = model(**data_batch) + >>> print(ret) + { + 'loss': tensor(...), + 'y_prob': tensor(...), + 'y_true': tensor(...), + 'logit': tensor(...) + } """ def __init__( diff --git a/pyhealth/models/stagenet.py b/pyhealth/models/stagenet.py index 884e1262a..1892e8f3b 100644 --- a/pyhealth/models/stagenet.py +++ b/pyhealth/models/stagenet.py @@ -285,7 +285,7 @@ class StageNet(BaseModel): **kwargs: other parameters for the StageNet layer. Examples: - >>> from pyhealth.datasets import SampleDataset + >>> from pyhealth.datasets import create_sample_dataset >>> samples = [ ... { ... "patient_id": "patient-0", @@ -316,7 +316,7 @@ class StageNet(BaseModel): ... 
] >>> >>> # dataset - >>> dataset = SampleDataset( + >>> dataset = create_sample_dataset( ... samples=samples, ... input_schema={ ... "codes": "stagenet", @@ -616,7 +616,7 @@ def forward(self, **kwargs) -> Dict[str, torch.Tensor]: if __name__ == "__main__": - from pyhealth.datasets import SampleDataset + from pyhealth.datasets import create_sample_dataset samples = [ { @@ -654,7 +654,7 @@ def forward(self, **kwargs) -> Dict[str, torch.Tensor]: ] # dataset - dataset = SampleDataset( + dataset = create_sample_dataset( samples=samples, input_schema={ "codes": "stagenet", diff --git a/pyhealth/models/transformer.py b/pyhealth/models/transformer.py index d1e87abea..9f0feeb38 100644 --- a/pyhealth/models/transformer.py +++ b/pyhealth/models/transformer.py @@ -353,7 +353,7 @@ class Transformer(BaseModel): num_layers (int): number of transformer blocks per feature stream. Examples: - >>> from pyhealth.datasets import SampleDataset, get_dataloader + >>> from pyhealth.datasets import create_sample_dataset, get_dataloader >>> samples = [ ... { ... "patient_id": "patient-0", @@ -372,7 +372,7 @@ class Transformer(BaseModel): ... ] >>> input_schema = {"diagnoses": "sequence", "procedures": "sequence"} >>> output_schema = {"label": "binary"} - >>> dataset = SampleDataset( + >>> dataset = create_sample_dataset( ... samples, ... input_schema, ... output_schema, @@ -661,7 +661,7 @@ def forward(self, **kwargs) -> Dict[str, torch.Tensor]: if __name__ == "__main__": - from pyhealth.datasets import SampleDataset, get_dataloader + from pyhealth.datasets import create_sample_dataset, get_dataloader samples = [ { @@ -686,7 +686,7 @@ def forward(self, **kwargs) -> Dict[str, torch.Tensor]: } output_schema: Dict[str, Union[str, type[FeatureProcessor]]] = {"label": "binary"} - dataset = SampleDataset( + dataset = create_sample_dataset( samples=samples, input_schema=input_schema, output_schema=output_schema, diff --git a/pyhealth/processors/base_processor.py b/pyhealth/processors/base_processor.py index 050cb5357..823fcafec 100644 --- a/pyhealth/processors/base_processor.py +++ b/pyhealth/processors/base_processor.py @@ -1,5 +1,5 @@ from abc import ABC, abstractmethod -from typing import Any, Dict, List, Optional +from typing import Any, Dict, List, Iterable class Processor(ABC): @@ -33,7 +33,7 @@ class FeatureProcessor(Processor): Example: Tokenization, image loading, normalization. """ - def fit(self, samples: List[Dict[str, Any]], field: str) -> None: + def fit(self, samples: Iterable[Dict[str, Any]], field: str) -> None: """Fit the processor to the samples. 
Args: diff --git a/pyhealth/processors/image_processor.py b/pyhealth/processors/image_processor.py index 17c6d5e97..d174529a7 100644 --- a/pyhealth/processors/image_processor.py +++ b/pyhealth/processors/image_processor.py @@ -1,3 +1,4 @@ +from functools import partial from pathlib import Path from typing import Any, List, Optional, Union @@ -59,7 +60,9 @@ def __init__( def _build_transform(self) -> transforms.Compose: transform_list = [] if self.mode is not None: - transform_list.append(transforms.Lambda(lambda img: img.convert(self.mode))) + transform_list.append( + transforms.Lambda(partial(_convert_mode, mode=self.mode)) + ) if self.image_size is not None: transform_list.append( transforms.Resize((self.image_size, self.image_size)) @@ -98,3 +101,8 @@ def __repr__(self) -> str: f"to_tensor={self.to_tensor}, normalize={self.normalize}, " f"mean={self.mean}, std={self.std}, mode={self.mode})" ) + + +def _convert_mode(img: Image.Image, mode: str) -> Image.Image: + """Convert a PIL image to the requested mode.""" + return img.convert(mode) diff --git a/pyhealth/processors/label_processor.py b/pyhealth/processors/label_processor.py index ad2df1897..ff32dabf8 100644 --- a/pyhealth/processors/label_processor.py +++ b/pyhealth/processors/label_processor.py @@ -1,5 +1,5 @@ import logging -from typing import Any, Dict, List +from typing import Any, Dict, List, Iterable import torch @@ -19,7 +19,7 @@ def __init__(self): super().__init__() self.label_vocab: Dict[Any, int] = {} - def fit(self, samples: List[Dict[str, Any]], field: str) -> None: + def fit(self, samples: Iterable[Dict[str, Any]], field: str) -> None: all_labels = set([sample[field] for sample in samples]) if len(all_labels) != 2: raise ValueError(f"Expected 2 unique labels, got {len(all_labels)}") @@ -54,7 +54,7 @@ def __init__(self): super().__init__() self.label_vocab: Dict[Any, int] = {} - def fit(self, samples: List[Dict[str, Any]], field: str) -> None: + def fit(self, samples: Iterable[Dict[str, Any]], field: str) -> None: all_labels = set([sample[field] for sample in samples]) num_classes = len(all_labels) if all_labels == set(range(num_classes)): @@ -89,7 +89,7 @@ def __init__(self): super().__init__() self.label_vocab: Dict[Any, int] = {} - def fit(self, samples: List[Dict[str, Any]], field: str) -> None: + def fit(self, samples: Iterable[Dict[str, Any]], field: str) -> None: all_labels = set() for sample in samples: for label in sample[field]: diff --git a/pyhealth/processors/nested_sequence_processor.py b/pyhealth/processors/nested_sequence_processor.py index f89b5b127..80a2567bc 100644 --- a/pyhealth/processors/nested_sequence_processor.py +++ b/pyhealth/processors/nested_sequence_processor.py @@ -1,4 +1,4 @@ -from typing import Any, Dict, List +from typing import Any, Dict, List, Iterable import torch @@ -51,7 +51,7 @@ def __init__(self, padding: int = 0): self._max_inner_len = 1 # Maximum length of inner sequences self._padding = padding # Additional padding beyond observed max - def fit(self, samples: List[Dict[str, Any]], field: str) -> None: + def fit(self, samples: Iterable[Dict[str, Any]], field: str) -> None: """Build vocabulary and determine maximum inner sequence length. 
Args: @@ -187,7 +187,7 @@ def __init__(self, forward_fill: bool = True, padding: int = 0): self.forward_fill = forward_fill self._padding = padding # Additional padding beyond observed max - def fit(self, samples: List[Dict[str, Any]], field: str) -> None: + def fit(self, samples: Iterable[Dict[str, Any]], field: str) -> None: """Determine maximum inner sequence length. Args: diff --git a/pyhealth/processors/sequence_processor.py b/pyhealth/processors/sequence_processor.py index 06577b793..32bb9e2cf 100644 --- a/pyhealth/processors/sequence_processor.py +++ b/pyhealth/processors/sequence_processor.py @@ -1,4 +1,4 @@ -from typing import Any, Dict, List +from typing import Any, Dict, List, Iterable import torch @@ -20,8 +20,7 @@ def __init__(self): self.code_vocab: Dict[Any, int] = {"": 0} self._next_index = 1 - - def fit(self, samples: List[Dict[str, Any]], field: str) -> None: + def fit(self, samples: Iterable[Dict[str, Any]], field: str) -> None: for sample in samples: for token in sample[field]: if token is None: @@ -32,7 +31,6 @@ def fit(self, samples: List[Dict[str, Any]], field: str) -> None: self.code_vocab[""] = len(self.code_vocab) - def process(self, value: Any) -> torch.Tensor: """Process token value(s) into tensor of indices. diff --git a/pyhealth/processors/stagenet_processor.py b/pyhealth/processors/stagenet_processor.py index 143eb03f2..fc3bf3d01 100644 --- a/pyhealth/processors/stagenet_processor.py +++ b/pyhealth/processors/stagenet_processor.py @@ -1,4 +1,4 @@ -from typing import Any, Dict, List, Optional, Tuple +from typing import Any, Dict, List, Optional, Tuple, Iterable import torch @@ -63,7 +63,7 @@ def __init__(self, padding: int = 0): self._max_nested_len = None self._padding = padding # Additional padding beyond observed max - def fit(self, samples: List[Dict], key: str) -> None: + def fit(self, samples: Iterable[Dict[str, Any]], field: str) -> None: """Build vocabulary and determine input structure. Args: @@ -72,9 +72,9 @@ def fit(self, samples: List[Dict], key: str) -> None: """ # Examine first non-None sample to determine structure for sample in samples: - if key in sample and sample[key] is not None: + if field in sample and sample[field] is not None: # Unpack tuple: (time, values) - time_data, value_data = sample[key] + time_data, value_data = sample[field] # Determine nesting level for codes if isinstance(value_data, list) and len(value_data) > 0: @@ -92,9 +92,9 @@ def fit(self, samples: List[Dict], key: str) -> None: # Build vocabulary for codes and find max nested length max_inner_len = 0 for sample in samples: - if key in sample and sample[key] is not None: + if field in sample and sample[field] is not None: # Unpack tuple: (time, values) - time_data, value_data = sample[key] + time_data, value_data = sample[field] if self._is_nested: # Nested codes @@ -262,7 +262,7 @@ def __init__(self): self._size = None # Feature dimension (set during fit) self._is_nested = None - def fit(self, samples: List[Dict], key: str) -> None: + def fit(self, samples: Iterable[Dict[str, Any]], field: str) -> None: """Determine input structure. 
Args: @@ -271,9 +271,9 @@ def fit(self, samples: List[Dict], key: str) -> None: """ # Examine first non-None sample to determine structure for sample in samples: - if key in sample and sample[key] is not None: + if field in sample and sample[field] is not None: # Unpack tuple: (time, values) - time_data, value_data = sample[key] + time_data, value_data = sample[field] # Determine nesting level for numerics if isinstance(value_data, list) and len(value_data) > 0: diff --git a/pyhealth/tasks/medical_coding.py b/pyhealth/tasks/medical_coding.py index 45b8d9cd1..23b3cb2e4 100644 --- a/pyhealth/tasks/medical_coding.py +++ b/pyhealth/tasks/medical_coding.py @@ -37,6 +37,7 @@ def pre_filter(self, df: pl.LazyFrame) -> pl.LazyFrame: df.filter(pl.col("event_type") == "noteevents") .select("patient_id") .unique() + .collect() .to_series() ) ) diff --git a/pyproject.toml b/pyproject.toml index 6ab207053..aa96d0e09 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -35,10 +35,14 @@ dependencies = [ "urllib3~=2.5.0", "numpy~=1.26.4", "tqdm", - "polars~=1.31.0", + "polars~=1.35.2", "pandas~=2.3.1", "pandarallel~=1.6.5", "pydantic~=2.11.7", + "dask[complete]~=2025.11.0", + "litdata~=0.2.58", + "pyarrow~=22.0.0", + "narwhals~=2.13.0", ] license = "BSD-3-Clause" license-files = ["LICENSE.md"] diff --git a/tests/core/test_adacare.py b/tests/core/test_adacare.py index 19754cb1d..daf60ecee 100644 --- a/tests/core/test_adacare.py +++ b/tests/core/test_adacare.py @@ -1,7 +1,7 @@ import unittest import torch -from pyhealth.datasets import SampleDataset, get_dataloader +from pyhealth.datasets import create_sample_dataset, get_dataloader from pyhealth.models import AdaCare @@ -51,7 +51,7 @@ def setUp(self): } self.output_schema = {"label": "binary"} - self.dataset = SampleDataset( + self.dataset = create_sample_dataset( samples=self.samples, input_schema=self.input_schema, output_schema=self.output_schema, diff --git a/tests/core/test_base_dataset.py b/tests/core/test_base_dataset.py new file mode 100644 index 000000000..7e69e7c84 --- /dev/null +++ b/tests/core/test_base_dataset.py @@ -0,0 +1,134 @@ +import tempfile +import unittest +from unittest.mock import patch + +import polars as pl +import pandas as pd +import dask.dataframe as dd + +from pyhealth.datasets.base_dataset import BaseDataset + + +class MockDataset(BaseDataset): + """Dataset that bypasses file loading for tests.""" + + def __init__(self, data: dd.DataFrame, **kwargs): + self._data = data + super().__init__(**kwargs) + + def load_data(self) -> dd.DataFrame: + return self._data + + +class TestBaseDataset(unittest.TestCase): + def _single_row_data(self) -> dd.DataFrame: + return dd.from_pandas( + pd.DataFrame( + { + "patient_id": ["1"], + "event_type": ["test"], + "timestamp": [None], + "test/value": [0], + } + ), + npartitions=1, + ) + + def test_cache_dir_varies_with_core_identifiers(self): + base_kwargs = dict( + tables=["table_a"], + dataset_name="CacheDataset", + dev=False, + ) + + with tempfile.TemporaryDirectory() as cache_root, patch( + "pyhealth.datasets.base_dataset.platformdirs.user_cache_dir", + return_value=cache_root, + ): + datasets = [ + MockDataset( + data=self._single_row_data(), + root="/data/root_a", + **base_kwargs, + ), + MockDataset( + data=self._single_row_data(), + root="/data/root_b", # different root + **base_kwargs, + ), + MockDataset( + data=self._single_row_data(), + root="/data/root_a", + tables=["table_b"], # different tables + dataset_name="CacheDataset", + dev=False, + ), + MockDataset( + 
data=self._single_row_data(), + root="/data/root_a", + tables=["table_a"], + dataset_name="OtherDataset", # different dataset name + dev=False, + ), + MockDataset( + data=self._single_row_data(), + root="/data/root_a", + tables=["table_a"], + dataset_name="CacheDataset", + dev=True, # different dev flag + ), + ] + + cache_dirs = [ds.cache_dir for ds in datasets] + self.assertEqual( + len(cache_dirs), + len(set(cache_dirs)), + "cache_dir should change when root/tables/dataset_name/dev change", + ) + + def test_event_df_cache_is_physically_sorted(self): + unsorted_data = dd.from_pandas( + pd.DataFrame( + { + "patient_id": ["3", "1", "2", "1"], + "event_type": ["test"] * 4, + "timestamp": [None] * 4, + "test/value": [10, 20, 30, 40], + } + ), + npartitions=1, + ) + original_order = unsorted_data["patient_id"].compute().tolist() + + with tempfile.TemporaryDirectory() as cache_root, patch( + "pyhealth.datasets.base_dataset.platformdirs.user_cache_dir", + return_value=cache_root, + ): + dataset = MockDataset( + data=unsorted_data, + root="/data/root_sort", + tables=["table_a"], + dataset_name="SortingDataset", + dev=False, + ) + + # Trigger caching of global_event_df.parquet + _ = dataset.global_event_df + cache_path = dataset.cache_dir / "global_event_df.parquet" + self.assertTrue(cache_path.exists(), "global_event_df cache should be created") + + cached_df = pl.read_parquet(cache_path) + cached_order = cached_df["patient_id"].to_list() + + self.assertNotEqual( + cached_order, original_order, "cache should not keep the unsorted order" + ) + self.assertEqual( + cached_order, + sorted(cached_order), + "cached global_event_df parquet must be sorted by patient_id", + ) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/core/test_caching.py b/tests/core/test_caching.py index b9b355bce..5bdab5640 100644 --- a/tests/core/test_caching.py +++ b/tests/core/test_caching.py @@ -1,15 +1,16 @@ import unittest import tempfile -import os +import shutil from pathlib import Path from unittest.mock import patch import polars as pl +import dask.dataframe as dd +import torch from tests.base import BaseTestCase from pyhealth.datasets.base_dataset import BaseDataset from pyhealth.tasks.base_task import BaseTask from pyhealth.datasets.sample_dataset import SampleDataset -from pyhealth.data import Patient class MockTask(BaseTask): @@ -19,11 +20,13 @@ def __init__(self, task_name="test_task"): self.task_name = task_name self.input_schema = {"test_attribute": "raw"} self.output_schema = {"test_label": "binary"} + self.call_count = 0 def __call__(self, patient): """Return mock samples based on patient data.""" # Extract patient's test data from the patient's data source patient_data = patient.data_source + self.call_count += 1 samples = [] for row in patient_data.iter_rows(named=True): @@ -40,45 +43,35 @@ def __call__(self, patient): class MockDataset(BaseDataset): """Mock dataset for testing purposes.""" - def __init__(self): - # Initialize without calling parent __init__ to avoid file dependencies - self.dataset_name = "TestDataset" - self.dev = False - - # Create realistic test data with patient_id, test_attribute, and test_label - self._collected_global_event_df = pl.DataFrame( - { - "patient_id": ["1", "2", "1", "2"], - "event_type": ["test", "test", "test", "test"], - "timestamp": [None, None, None, None], - "test/test_attribute": [ - "pat_1_attr_1", - "pat_2_attr_1", - "pat_1_attr_2", - "pat_2_attr_2", - ], - "test/test_label": [0, 1, 1, 0], - } + def __init__(self, cache_dir: str | Path | None = 
None): + super().__init__( + root="", + tables=[], + dataset_name="TestDataset", + cache_dir=cache_dir, + dev=False, ) - self._unique_patient_ids = ["1", "2"] - - @property - def collected_global_event_df(self): - return self._collected_global_event_df - - @property - def unique_patient_ids(self): - return self._unique_patient_ids - - def iter_patients(self, df=None): - """Mock patient iterator that returns real Patient objects.""" - if df is None: - df = self.collected_global_event_df - grouped = df.group_by("patient_id") - for patient_id, patient_df in grouped: - patient_id = patient_id[0] - yield Patient(patient_id=patient_id, data_source=patient_df) + def load_data(self) -> dd.DataFrame: + import pandas as pd + + return dd.from_pandas( + pd.DataFrame( + { + "patient_id": ["1", "2", "1", "2"], + "event_type": ["test", "test", "test", "test"], + "timestamp": [None, None, None, None], + "test/test_attribute": [ + "pat_1_attr_1", + "pat_2_attr_1", + "pat_1_attr_2", + "pat_2_attr_2", + ], + "test/test_label": [0, 1, 1, 0], + } + ), + npartitions=1, + ) class TestCachingFunctionality(BaseTestCase): @@ -86,17 +79,19 @@ class TestCachingFunctionality(BaseTestCase): def setUp(self): """Set up test fixtures.""" - self.dataset = MockDataset() + self.temp_dir = Path(tempfile.mkdtemp()) + self.dataset = MockDataset(cache_dir=self.temp_dir) self.task = MockTask() - self.temp_dir = tempfile.mkdtemp() def tearDown(self): """Clean up test fixtures.""" - # Clean up temporary directory - import shutil - shutil.rmtree(self.temp_dir, ignore_errors=True) + def _task_cache_dir(self) -> Path: + cache_dir = self.temp_dir / "task_cache" + cache_dir.mkdir(parents=True, exist_ok=True) + return cache_dir + def test_set_task_signature(self): """Test that set_task has the correct method signature.""" import inspect @@ -104,7 +99,15 @@ def test_set_task_signature(self): sig = inspect.signature(BaseDataset.set_task) params = list(sig.parameters.keys()) - expected_params = ["self", "task", "num_workers", "cache_dir", "cache_format", "input_processors", "output_processors"] + expected_params = [ + "self", + "task", + "num_workers", + "cache_dir", + "cache_format", + "input_processors", + "output_processors", + ] self.assertEqual(params, expected_params) # Check default values @@ -115,158 +118,62 @@ def test_set_task_signature(self): self.assertEqual(sig.parameters["input_processors"].default, None) self.assertEqual(sig.parameters["output_processors"].default, None) - def test_set_task_no_caching(self): - """Test set_task without caching (cache_dir=None).""" - sample_dataset = self.dataset.set_task(self.task) + def test_set_task_writes_cache_and_metadata(self): + """Ensure set_task materializes cache files and schema metadata.""" + cache_dir = self._task_cache_dir() + sample_dataset = self.dataset.set_task( + self.task, cache_dir=cache_dir, cache_format="parquet" + ) self.assertIsInstance(sample_dataset, SampleDataset) - self.assertEqual(len(sample_dataset), 4) # Two patients, two samples each self.assertEqual(sample_dataset.dataset_name, "TestDataset") + self.assertEqual(sample_dataset.task_name, self.task.task_name) + self.assertEqual(len(sample_dataset), 4) + self.assertEqual(self.task.call_count, 2) + + # Cache artifacts should be present for StreamingDataset + self.assertTrue((cache_dir / "index.json").exists()) + self.assertTrue((cache_dir / "schema.pkl").exists()) - # Check that samples have the correct structure + # Check processed sample structure and metadata persisted sample = sample_dataset[0] 
self.assertIn("test_attribute", sample) self.assertIn("test_label", sample) self.assertIn("patient_id", sample) - - def test_full_parquet_caching_cycle(self): - """Test complete save and load cycle with parquet caching.""" - cache_path = Path(self.temp_dir) / f"{self.task.task_name}.parquet" - - # Step 1: First call - should generate samples and save to cache - self.assertFalse(cache_path.exists(), "Cache file should not exist initially") - - sample_dataset_1 = self.dataset.set_task( - self.task, cache_dir=self.temp_dir, cache_format="parquet" - ) - - # Verify cache file was created + self.assertIsInstance(sample["test_label"], torch.Tensor) + self.assertIn("test_attribute", sample_dataset.input_processors) + self.assertIn("test_label", sample_dataset.output_processors) + self.assertEqual(set(sample_dataset.patient_to_index), {"1", "2"}) self.assertTrue( - cache_path.exists(), "Cache file should be created after first call" - ) - - # Verify the sample dataset is correct - self.assertIsInstance(sample_dataset_1, SampleDataset) - self.assertEqual( - len(sample_dataset_1), 4 - ) # Should have 4 samples from our mock data - - # Step 2: Second call - should load from cache (not regenerate) - sample_dataset_2 = self.dataset.set_task( - self.task, cache_dir=self.temp_dir, cache_format="parquet" - ) - - # Verify the loaded dataset matches the original - self.assertIsInstance(sample_dataset_2, SampleDataset) - self.assertEqual(len(sample_dataset_2), 4) - - # Step 3: Verify the actual cached data is correct - # Load the parquet file directly to check its contents - cached_df = pl.read_parquet(cache_path) - cached_samples = cached_df.to_dicts() - - self.assertEqual(len(cached_samples), 4) - - # Verify sample content matches expected structure - for sample in cached_samples: - self.assertIn("test_attribute", sample) - self.assertIn("test_label", sample) - self.assertIn("patient_id", sample) - self.assertIn(sample["patient_id"], ["1", "2"]) - self.assertIn(sample["test_label"], [0, 1]) - - def test_full_pickle_caching_cycle(self): - """Test complete save and load cycle with pickle caching.""" - cache_path = Path(self.temp_dir) / f"{self.task.task_name}.pickle" - - # Step 1: First call - should generate samples and save to cache - self.assertFalse(cache_path.exists(), "Cache file should not exist initially") - - sample_dataset_1 = self.dataset.set_task( - self.task, cache_dir=self.temp_dir, cache_format="pickle" + all(len(indexes) == 2 for indexes in sample_dataset.patient_to_index.values()) ) + self.assertEqual(sample_dataset.record_to_index, {}) - # Verify cache file was created - self.assertTrue( - cache_path.exists(), "Cache file should be created after first call" - ) - - # Verify the sample dataset is correct - self.assertIsInstance(sample_dataset_1, SampleDataset) - self.assertEqual( - len(sample_dataset_1), 4 - ) # Should have 4 samples from our mock data - - # Step 2: Second call - should load from cache (not regenerate) - sample_dataset_2 = self.dataset.set_task( - self.task, cache_dir=self.temp_dir, cache_format="pickle" - ) - - # Verify the loaded dataset matches the original - self.assertIsInstance(sample_dataset_2, SampleDataset) - self.assertEqual(len(sample_dataset_2), 4) - - # Step 3: Verify the actual cached data is correct - # Load the pickle file directly to check its contents - import pickle - - with open(cache_path, "rb") as f: - cached_samples = pickle.load(f) - - self.assertEqual(len(cached_samples), 4) - - # Verify sample content matches expected structure - for sample in 
cached_samples: - self.assertIn("test_attribute", sample) - self.assertIn("test_label", sample) - self.assertIn("patient_id", sample) - self.assertIn(sample["patient_id"], ["1", "2"]) - self.assertIn(sample["test_label"], [0, 1]) - - def test_set_task_invalid_cache_format(self): - """Test set_task with invalid cache format.""" - # This should not raise an error during set_task call, - # but should log a warning when trying to save - sample_dataset = self.dataset.set_task( - self.task, cache_dir=self.temp_dir, cache_format="invalid_format" - ) - - self.assertIsInstance(sample_dataset, SampleDataset) - self.assertEqual(len(sample_dataset), 4) # Generated samples - - @patch("polars.read_parquet") - def test_set_task_cache_load_failure_fallback(self, mock_read_parquet): - """Test fallback to generation when cache loading fails.""" - # Make read_parquet raise an exception - mock_read_parquet.side_effect = Exception("Failed to read cache") - - # Create a dummy cache file - cache_path = Path(self.temp_dir) / f"{self.task.task_name}.parquet" - cache_path.touch() - - sample_dataset = self.dataset.set_task( - self.task, cache_dir=self.temp_dir, cache_format="parquet" - ) - - # Should still work by falling back to generation - self.assertIsInstance(sample_dataset, SampleDataset) - self.assertEqual(len(sample_dataset), 4) # Generated samples - - def test_cache_directory_creation(self): - """Test that cache directory is created if it doesn't exist.""" - nested_cache_dir = os.path.join(self.temp_dir, "nested", "cache", "dir") - - # Ensure the nested directory doesn't exist - self.assertFalse(os.path.exists(nested_cache_dir)) + def test_default_cache_dir_is_used(self): + """When cache_dir is omitted, default cache dir should be used.""" + task_cache = self.dataset.cache_dir / "tasks" / self.task.task_name + sample_dataset = self.dataset.set_task(self.task) - with patch("polars.DataFrame.write_parquet"): - sample_dataset = self.dataset.set_task( - self.task, cache_dir=nested_cache_dir, cache_format="parquet" + self.assertTrue(task_cache.exists()) + self.assertTrue((task_cache / "index.json").exists()) + self.assertTrue((self.dataset.cache_dir / "global_event_df.parquet").exists()) + self.assertEqual(len(sample_dataset), 4) + + def test_reuses_existing_cache_without_regeneration(self): + """Second call should reuse cached samples instead of recomputing.""" + cache_dir = self._task_cache_dir() + _ = self.dataset.set_task(self.task, cache_dir=cache_dir) + self.assertEqual(self.task.call_count, 2) + + with patch.object( + self.task, "__call__", side_effect=AssertionError("Task should not rerun") + ): + cached_dataset = self.dataset.set_task( + self.task, cache_dir=cache_dir, cache_format="parquet" ) - # Directory should be created - self.assertTrue(os.path.exists(nested_cache_dir)) - self.assertIsInstance(sample_dataset, SampleDataset) + self.assertEqual(len(cached_dataset), 4) + self.assertEqual(self.task.call_count, 2) if __name__ == "__main__": diff --git a/tests/core/test_chestxray14.py b/tests/core/test_chestxray14.py index 77454c1ba..ad2907cb9 100644 --- a/tests/core/test_chestxray14.py +++ b/tests/core/test_chestxray14.py @@ -7,6 +7,7 @@ import os import shutil import unittest +import tempfile import numpy as np from PIL import Image @@ -124,19 +125,22 @@ def test_task_given_invalid_disease(self): _ = ChestXray14BinaryClassification(disease="toothache") def test_task_classify_cardiomegaly(self): + cache_dir = tempfile.mkdtemp() task = ChestXray14BinaryClassification(disease="cardiomegaly") - samples = 
self.dataset.set_task(task) + samples = self.dataset.set_task(task, cache_dir=cache_dir) self.assertEqual(len(samples), 10) self.assertEqual(sum(sample["label"] for sample in samples), 3) def test_task_classify_hernia(self): + cache_dir = tempfile.mkdtemp() task = ChestXray14BinaryClassification(disease="hernia") - samples = self.dataset.set_task(task) + samples = self.dataset.set_task(task, cache_dir=cache_dir) self.assertEqual(len(samples), 10) self.assertEqual(sum(sample["label"] for sample in samples), 6) def test_task_classify_all(self): - samples = self.dataset.set_task() + cache_dir = tempfile.mkdtemp() + samples = self.dataset.set_task(cache_dir=cache_dir) self.assertEqual(len(samples), 10) actual_labels = [sample["labels"].tolist() for sample in samples] diff --git a/tests/core/test_cnn.py b/tests/core/test_cnn.py index 4bf0bfe89..4a53e5444 100644 --- a/tests/core/test_cnn.py +++ b/tests/core/test_cnn.py @@ -1,7 +1,7 @@ import unittest import torch -from pyhealth.datasets import SampleDataset, get_dataloader +from pyhealth.datasets import create_sample_dataset, get_dataloader from pyhealth.models import CNN @@ -33,7 +33,7 @@ def setUp(self): } self.output_schema = {"label": "binary"} - self.dataset = SampleDataset( + self.dataset = create_sample_dataset( samples=self.samples, input_schema=self.input_schema, output_schema=self.output_schema, @@ -147,7 +147,7 @@ def test_model_with_image_input(self): input_schema = {"image": "image"} output_schema = {"label": "binary"} - dataset = SampleDataset( + dataset = create_sample_dataset( samples=samples, input_schema=input_schema, output_schema=output_schema, @@ -209,7 +209,7 @@ def test_model_with_mixed_inputs(self): } output_schema = {"label": "binary"} - dataset = SampleDataset( + dataset = create_sample_dataset( samples=samples, input_schema=input_schema, output_schema=output_schema, diff --git a/tests/core/test_contrawr.py b/tests/core/test_contrawr.py index 94ad0f8da..25f1bb6f8 100644 --- a/tests/core/test_contrawr.py +++ b/tests/core/test_contrawr.py @@ -2,7 +2,7 @@ import numpy as np import torch -from pyhealth.datasets import SampleDataset, get_dataloader +from pyhealth.datasets import create_sample_dataset, get_dataloader from pyhealth.models import ContraWR @@ -33,7 +33,7 @@ def setUp(self): self.input_schema = {"epoch_signal": "tensor"} self.output_schema = {"label": "multiclass"} - self.dataset = SampleDataset( + self.dataset = create_sample_dataset( samples=self.samples, input_schema=self.input_schema, output_schema=self.output_schema, diff --git a/tests/core/test_covariate_label.py b/tests/core/test_covariate_label.py index e1c8af02b..b5aa0aaf6 100644 --- a/tests/core/test_covariate_label.py +++ b/tests/core/test_covariate_label.py @@ -2,7 +2,7 @@ import numpy as np import torch -from pyhealth.datasets import SampleDataset, get_dataloader +from pyhealth.datasets import create_sample_dataset, get_dataloader from pyhealth.models import MLP from pyhealth.calib.predictionset.covariate import CovariateLabel, fit_kde from pyhealth.calib.utils import extract_embeddings @@ -67,7 +67,7 @@ def setUp(self): self.output_schema = {"label": "multiclass"} # Create dataset - self.dataset = SampleDataset( + self.dataset = create_sample_dataset( samples=self.samples, input_schema=self.input_schema, output_schema=self.output_schema, @@ -150,7 +150,7 @@ def test_initialization_non_multiclass_raises_error(self): "label": 1, }, ] - binary_dataset = SampleDataset( + binary_dataset = create_sample_dataset( samples=binary_samples, 
input_schema={"conditions": "sequence", "procedures": "tensor"}, output_schema={"label": "binary"}, @@ -182,7 +182,7 @@ def test_calibrate_marginal(self): # Calibrate on first 4 samples cal_indices = [0, 1, 2, 3] - cal_dataset = torch.utils.data.Subset(self.dataset, cal_indices) + cal_dataset = self.dataset.subset(cal_indices) # Extract embeddings cal_embeddings = self._get_embeddings(cal_dataset) @@ -240,7 +240,7 @@ def test_forward_returns_predset(self): # Calibrate cal_indices = [0, 1, 2, 3] - cal_dataset = torch.utils.data.Subset(self.dataset, cal_indices) + cal_dataset = self.dataset.subset(cal_indices) # Extract embeddings cal_embeddings = self._get_embeddings(cal_dataset) @@ -281,7 +281,7 @@ def test_prediction_sets_nonempty(self): # Calibrate on first 4 samples cal_indices = [0, 1, 2, 3] - cal_dataset = torch.utils.data.Subset(self.dataset, cal_indices) + cal_dataset = self.dataset.subset(cal_indices) # Extract embeddings cal_embeddings = self._get_embeddings(cal_dataset) @@ -295,7 +295,7 @@ def test_prediction_sets_nonempty(self): # Test on remaining samples test_indices = [4, 5] - test_dataset = torch.utils.data.Subset(self.dataset, test_indices) + test_dataset = self.dataset.subset(test_indices) test_loader = get_dataloader(test_dataset, batch_size=2, shuffle=False) with torch.no_grad(): @@ -359,7 +359,7 @@ def test_edge_case_empty_class(self): # Use subset that doesn't have all classes cal_indices = [0, 1] # Only classes 0 and 1 - cal_dataset = torch.utils.data.Subset(self.dataset, cal_indices) + cal_dataset = self.dataset.subset(cal_indices) # Extract embeddings cal_embeddings = self._get_embeddings(cal_dataset) diff --git a/tests/core/test_early_stopping.py b/tests/core/test_early_stopping.py index 16e9afe75..7a2540e2f 100644 --- a/tests/core/test_early_stopping.py +++ b/tests/core/test_early_stopping.py @@ -1,6 +1,6 @@ import unittest -from pyhealth.datasets import SampleDataset, get_dataloader +from pyhealth.datasets import create_sample_dataset, get_dataloader from pyhealth.models import MLP from pyhealth.trainer import Trainer @@ -29,14 +29,14 @@ def setUp(self): self.output_schema = {"label": "binary"} # Split into train and val - self.train_dataset = SampleDataset( + self.train_dataset = create_sample_dataset( samples=self.samples[:80], input_schema=self.input_schema, output_schema=self.output_schema, dataset_name="train", ) - self.val_dataset = SampleDataset( + self.val_dataset = create_sample_dataset( samples=self.samples[80:], input_schema=self.input_schema, output_schema=self.output_schema, diff --git a/tests/core/test_gamenet.py b/tests/core/test_gamenet.py index 1bee65d8f..f456189c7 100644 --- a/tests/core/test_gamenet.py +++ b/tests/core/test_gamenet.py @@ -1,7 +1,7 @@ import unittest import torch -from pyhealth.datasets import SampleDataset, get_dataloader +from pyhealth.datasets import create_sample_dataset, get_dataloader from pyhealth.models import GAMENet @@ -38,7 +38,7 @@ def setUp(self): } self.output_schema = {"drugs": "multilabel"} - self.dataset = SampleDataset( + self.dataset = create_sample_dataset( samples=self.samples, input_schema=self.input_schema, output_schema=self.output_schema, diff --git a/tests/core/test_gnn.py b/tests/core/test_gnn.py index 97cdbce5c..ef1468ca9 100644 --- a/tests/core/test_gnn.py +++ b/tests/core/test_gnn.py @@ -3,7 +3,7 @@ import unittest import torch -from pyhealth.datasets import SampleDataset, get_dataloader +from pyhealth.datasets import create_sample_dataset, get_dataloader from pyhealth.models import GCN, GAT @@ 
-37,7 +37,7 @@ def setUp(self): self.output_schema = {"label": "binary"} # binary classification # Create dataset - self.dataset = SampleDataset( + self.dataset = create_sample_dataset( samples=self.samples, input_schema=self.input_schema, output_schema=self.output_schema, @@ -258,7 +258,7 @@ def setUp(self): self.output_schema = {"label": "binary"} # binary classification # Create dataset - self.dataset = SampleDataset( + self.dataset = create_sample_dataset( samples=self.samples, input_schema=self.input_schema, output_schema=self.output_schema, diff --git a/tests/core/test_integrated_gradients.py b/tests/core/test_integrated_gradients.py index 82365b235..8cbea6600 100644 --- a/tests/core/test_integrated_gradients.py +++ b/tests/core/test_integrated_gradients.py @@ -1,7 +1,7 @@ import unittest import torch -from pyhealth.datasets import SampleDataset, get_dataloader +from pyhealth.datasets import create_sample_dataset, get_dataloader from pyhealth.models import MLP, StageNet from pyhealth.interpret.methods import IntegratedGradients @@ -43,7 +43,7 @@ def setUp(self): self.output_schema = {"label": "binary"} # Create dataset - self.dataset = SampleDataset( + self.dataset = create_sample_dataset( samples=self.samples, input_schema=self.input_schema, output_schema=self.output_schema, @@ -249,7 +249,7 @@ def setUp(self): self.output_schema = {"label": "binary"} # Create dataset - self.dataset = SampleDataset( + self.dataset = create_sample_dataset( samples=self.samples, input_schema=self.input_schema, output_schema=self.output_schema, diff --git a/tests/core/test_interp_metrics.py b/tests/core/test_interp_metrics.py index 01a22184b..9c46cf4ee 100644 --- a/tests/core/test_interp_metrics.py +++ b/tests/core/test_interp_metrics.py @@ -11,7 +11,7 @@ import torch -from pyhealth.datasets import SampleDataset, get_dataloader +from pyhealth.datasets import create_sample_dataset, get_dataloader from pyhealth.interpret.methods import IntegratedGradients from pyhealth.metrics.interpretability import ( ComprehensivenessMetric, @@ -99,7 +99,7 @@ def setUp(self): self.output_schema = {"label": "binary"} # Create dataset - self.dataset = SampleDataset( + self.dataset = create_sample_dataset( samples=self.samples, input_schema=self.input_schema, output_schema=self.output_schema, diff --git a/tests/core/test_legacy_mode_resolution.py b/tests/core/test_legacy_mode_resolution.py index e00644f24..9e411b7b3 100644 --- a/tests/core/test_legacy_mode_resolution.py +++ b/tests/core/test_legacy_mode_resolution.py @@ -2,7 +2,7 @@ import torch import torch.nn as nn -from pyhealth.datasets.sample_dataset import SampleDataset +from pyhealth.datasets.sample_dataset import create_sample_dataset from pyhealth.models.base_model import BaseModel from pyhealth.processors import ( BinaryLabelProcessor, @@ -38,7 +38,7 @@ def _build_dataset(self, output_processor, key="label"): ] input_schema = {"text": "raw"} output_schema = {key: output_processor} - return SampleDataset(samples, input_schema, output_schema) + return create_sample_dataset(samples, input_schema, output_schema) def test_string_schema_sets_mode(self): ds = self._build_dataset("binary") @@ -66,7 +66,7 @@ def test_unregistered_processor_leaves_mode_none(self): def test_multiclass_loss_selection(self): samples = [{"label": i % 3, "text": f"row{i}"} for i in range(6)] - ds = SampleDataset( + ds = create_sample_dataset( samples, {"text": "raw"}, {"label": MultiClassLabelProcessor} ) model = BaseModel(dataset=ds) @@ -78,7 +78,7 @@ def test_multilabel_loss_selection(self): 
{"label": [0, 2], "text": "row0"}, {"label": [1], "text": "row1"}, ] - ds = SampleDataset(samples, {"text": "raw"}, {"label": MultiLabelProcessor}) + ds = create_sample_dataset(samples, {"text": "raw"}, {"label": MultiLabelProcessor}) model = BaseModel(dataset=ds) self.assertEqual(model.mode, "multilabel") self.assertEqual( @@ -90,7 +90,7 @@ def test_regression_loss_selection(self): {"label": 0.5, "text": "r0"}, {"label": 1.2, "text": "r1"}, ] - ds = SampleDataset( + ds = create_sample_dataset( samples, {"text": "raw"}, {"label": RegressionLabelProcessor} ) model = BaseModel(dataset=ds) diff --git a/tests/core/test_logistic_regression.py b/tests/core/test_logistic_regression.py index ff9746368..91ff954bb 100644 --- a/tests/core/test_logistic_regression.py +++ b/tests/core/test_logistic_regression.py @@ -2,7 +2,7 @@ import torch -from pyhealth.datasets import SampleDataset, get_dataloader +from pyhealth.datasets import create_sample_dataset, get_dataloader from pyhealth.models import LogisticRegression @@ -36,7 +36,7 @@ def setUp(self): self.output_schema = {"label": "binary"} # binary classification # Create dataset - self.dataset = SampleDataset( + self.dataset = create_sample_dataset( samples=self.samples, input_schema=self.input_schema, output_schema=self.output_schema, @@ -164,7 +164,7 @@ def test_regression_task(self): } output_schema = {"score": "regression"} - regression_dataset = SampleDataset( + regression_dataset = create_sample_dataset( samples=regression_samples, input_schema=input_schema, output_schema=output_schema, diff --git a/tests/core/test_micron.py b/tests/core/test_micron.py index 938923c6e..59107f907 100644 --- a/tests/core/test_micron.py +++ b/tests/core/test_micron.py @@ -4,7 +4,7 @@ import torch import numpy as np -from pyhealth.datasets import SampleDataset, get_dataloader +from pyhealth.datasets import create_sample_dataset, get_dataloader from pyhealth.models import MICRON from pyhealth.processors.base_processor import FeatureProcessor @@ -47,7 +47,7 @@ def setUp(self): "drugs": "multilabel" } - self.dataset = SampleDataset( + self.dataset = create_sample_dataset( samples=self.samples, input_schema=self.input_schema, output_schema=self.output_schema, diff --git a/tests/core/test_mlp.py b/tests/core/test_mlp.py index 2320b577c..892decb99 100644 --- a/tests/core/test_mlp.py +++ b/tests/core/test_mlp.py @@ -1,7 +1,7 @@ import unittest import torch -from pyhealth.datasets import SampleDataset, get_dataloader +from pyhealth.datasets import create_sample_dataset, get_dataloader from pyhealth.models import MLP @@ -35,7 +35,7 @@ def setUp(self): self.output_schema = {"label": "binary"} # binary classification # Create dataset - self.dataset = SampleDataset( + self.dataset = create_sample_dataset( samples=self.samples, input_schema=self.input_schema, output_schema=self.output_schema, diff --git a/tests/core/test_molerec.py b/tests/core/test_molerec.py index 1ebbb2a41..c55457044 100644 --- a/tests/core/test_molerec.py +++ b/tests/core/test_molerec.py @@ -1,7 +1,7 @@ import unittest import torch -from pyhealth.datasets import SampleDataset, get_dataloader +from pyhealth.datasets import create_sample_dataset, get_dataloader from pyhealth.models import MoleRec @@ -39,7 +39,7 @@ def setUp(self): } self.output_schema = {"drugs": "multilabel"} - self.dataset = SampleDataset( + self.dataset = create_sample_dataset( samples=self.samples, input_schema=self.input_schema, output_schema=self.output_schema, diff --git a/tests/core/test_multi_hot_embeddings.py 
b/tests/core/test_multi_hot_embeddings.py index fbb4bde20..0f2aee258 100644 --- a/tests/core/test_multi_hot_embeddings.py +++ b/tests/core/test_multi_hot_embeddings.py @@ -2,7 +2,7 @@ import torch -from pyhealth.datasets import SampleDataset, get_dataloader +from pyhealth.datasets import create_sample_dataset, get_dataloader from pyhealth.models.embedding import EmbeddingModel from pyhealth.models.mlp import MLP @@ -39,7 +39,7 @@ def setUp(self) -> None: self.input_schema = {"ethnicity": "multi_hot"} self.output_schema = {"label": "binary"} - self.dataset = SampleDataset( + self.dataset = create_sample_dataset( samples=self.samples, input_schema=self.input_schema, output_schema=self.output_schema, diff --git a/tests/core/test_processor_schemas.py b/tests/core/test_processor_schemas.py index 54aa3618c..a46820884 100644 --- a/tests/core/test_processor_schemas.py +++ b/tests/core/test_processor_schemas.py @@ -13,7 +13,7 @@ import numpy as np from pyhealth.datasets import MIMIC3Dataset -from pyhealth.datasets.sample_dataset import SampleDataset +from pyhealth.datasets.sample_dataset import create_sample_dataset from pyhealth.processors import TextProcessor, MultiLabelProcessor, TimeseriesProcessor from pyhealth.tasks.medical_coding import MIMIC3ICD9Coding from pyhealth.tasks.base_task import BaseTask @@ -45,6 +45,7 @@ def pre_filter(self, df: pl.LazyFrame) -> pl.LazyFrame: df.filter(pl.col("event_type") == "noteevents") .select("patient_id") .unique() + .collect() .to_series() ) ) @@ -116,6 +117,7 @@ def pre_filter(self, df: pl.LazyFrame) -> pl.LazyFrame: df.filter(pl.col("event_type") == "noteevents") .select("patient_id") .unique() + .collect() .to_series() ) ) @@ -326,7 +328,7 @@ def __call__(self, patient: Patient) -> List[Dict]: } task = TestTimeseriesTask() - sample_dataset = SampleDataset( + sample_dataset = create_sample_dataset( samples=samples, input_schema=task.input_schema, output_schema=task.output_schema, @@ -379,7 +381,7 @@ def __call__(self, patient: Patient) -> List[Dict]: } task = TestTimeseriesTask() - sample_dataset = SampleDataset( + sample_dataset = create_sample_dataset( samples=samples, input_schema=task.input_schema, output_schema=task.output_schema, diff --git a/tests/core/test_processor_transfer.py b/tests/core/test_processor_transfer.py index d7bb41627..07ed867b1 100644 --- a/tests/core/test_processor_transfer.py +++ b/tests/core/test_processor_transfer.py @@ -1,7 +1,7 @@ import unittest import torch -from pyhealth.datasets import SampleDataset, get_dataloader +from pyhealth.datasets import create_sample_dataset, get_dataloader from pyhealth.models import MLP @@ -63,7 +63,7 @@ def setUp(self): def test_basic_processor_transfer(self): """Test basic processor transfer from train to test dataset.""" # Create training dataset (processors will be fitted) - train_dataset = SampleDataset( + train_dataset = create_sample_dataset( samples=self.train_samples, input_schema=self.input_schema, output_schema=self.output_schema, @@ -71,7 +71,7 @@ def test_basic_processor_transfer(self): ) # Create test dataset with transferred processors - test_dataset = SampleDataset( + test_dataset = create_sample_dataset( samples=self.test_samples, input_schema=self.input_schema, output_schema=self.output_schema, @@ -97,7 +97,7 @@ def test_basic_processor_transfer(self): def test_processor_vocabulary_consistency(self): """Test that transferred processors maintain vocabulary consistency.""" # Create training dataset - train_dataset = SampleDataset( + train_dataset = create_sample_dataset( 
samples=self.train_samples, input_schema=self.input_schema, output_schema=self.output_schema, @@ -109,7 +109,7 @@ def test_processor_vocabulary_consistency(self): train_vocab = train_dataset.input_processors["conditions"].code_vocab # Create test dataset with transferred processors - test_dataset = SampleDataset( + test_dataset = create_sample_dataset( samples=self.test_samples, input_schema=self.input_schema, output_schema=self.output_schema, @@ -135,7 +135,7 @@ def test_processor_vocabulary_consistency(self): def test_model_training_with_transferred_processors(self): """Test end-to-end training and inference with processor transfer.""" # Create training dataset - train_dataset = SampleDataset( + train_dataset = create_sample_dataset( samples=self.train_samples, input_schema=self.input_schema, output_schema=self.output_schema, @@ -143,7 +143,7 @@ def test_model_training_with_transferred_processors(self): ) # Create test dataset with transferred processors - test_dataset = SampleDataset( + test_dataset = create_sample_dataset( samples=self.test_samples, input_schema=self.input_schema, output_schema=self.output_schema, @@ -176,7 +176,7 @@ def test_model_training_with_transferred_processors(self): def test_without_processor_transfer(self): """Test that without transfer, each dataset fits its own processors.""" # Create training dataset - train_dataset = SampleDataset( + train_dataset = create_sample_dataset( samples=self.train_samples, input_schema=self.input_schema, output_schema=self.output_schema, @@ -184,7 +184,7 @@ def test_without_processor_transfer(self): ) # Create test dataset WITHOUT transferred processors - test_dataset = SampleDataset( + test_dataset = create_sample_dataset( samples=self.test_samples, input_schema=self.input_schema, output_schema=self.output_schema, @@ -210,7 +210,7 @@ def test_without_processor_transfer(self): def test_partial_processor_transfer(self): """Test transferring only input processors, not output processors.""" # Create training dataset - train_dataset = SampleDataset( + train_dataset = create_sample_dataset( samples=self.train_samples, input_schema=self.input_schema, output_schema=self.output_schema, @@ -218,7 +218,7 @@ def test_partial_processor_transfer(self): ) # Create test dataset with only input processors transferred - test_dataset = SampleDataset( + test_dataset = create_sample_dataset( samples=self.test_samples, input_schema=self.input_schema, output_schema=self.output_schema, @@ -242,7 +242,7 @@ def test_partial_processor_transfer(self): def test_empty_processor_dict_transfer(self): """Test passing empty processor dictionaries.""" # Create dataset with empty processor dicts (should fit new ones) - dataset = SampleDataset( + dataset = create_sample_dataset( samples=self.train_samples, input_schema=self.input_schema, output_schema=self.output_schema, @@ -272,7 +272,7 @@ def test_cross_validation_scenario(self): ] # has 1 and 0 # Fold 1 as train - train_dataset = SampleDataset( + train_dataset = create_sample_dataset( samples=fold1, input_schema=self.input_schema, output_schema=self.output_schema, @@ -280,7 +280,7 @@ def test_cross_validation_scenario(self): ) # Fold 2 as validation with transferred processors - val_dataset = SampleDataset( + val_dataset = create_sample_dataset( samples=fold2, input_schema=self.input_schema, output_schema=self.output_schema, @@ -343,7 +343,7 @@ def test_multimodal_processor_transfer(self): } # Create train dataset - train_dataset = SampleDataset( + train_dataset = create_sample_dataset( samples=multi_train, 
input_schema=multi_input_schema, output_schema=self.output_schema, @@ -351,7 +351,7 @@ def test_multimodal_processor_transfer(self): ) # Create test dataset with transferred processors - test_dataset = SampleDataset( + test_dataset = create_sample_dataset( samples=multi_test, input_schema=multi_input_schema, output_schema=self.output_schema, @@ -377,13 +377,13 @@ def test_multimodal_processor_transfer(self): def test_backward_compatibility(self): """Test that existing code without processor transfer still works.""" # Old-style usage without any processor parameters - dataset1 = SampleDataset( + dataset1 = create_sample_dataset( samples=self.train_samples, input_schema=self.input_schema, output_schema=self.output_schema, ) - dataset2 = SampleDataset( + dataset2 = create_sample_dataset( samples=self.test_samples, input_schema=self.input_schema, output_schema=self.output_schema, @@ -428,7 +428,7 @@ def test_none_processors_explicitly(self): output_schema = {"label": "binary"} # Explicitly pass None - dataset = SampleDataset( + dataset = create_sample_dataset( samples=samples, input_schema=input_schema, output_schema=output_schema, @@ -473,14 +473,14 @@ def test_mixed_transfer_and_schema(self): output_schema = {"label": "binary"} # Create train dataset - train_dataset = SampleDataset( + train_dataset = create_sample_dataset( samples=train_samples, input_schema=input_schema, output_schema=output_schema, ) # Create test dataset with same schema - test_dataset = SampleDataset( + test_dataset = create_sample_dataset( samples=test_samples, input_schema=input_schema, output_schema=output_schema, diff --git a/tests/core/test_safedrug.py b/tests/core/test_safedrug.py index 707f164e1..75e278a19 100644 --- a/tests/core/test_safedrug.py +++ b/tests/core/test_safedrug.py @@ -1,7 +1,7 @@ import unittest import torch -from pyhealth.datasets import SampleDataset, get_dataloader +from pyhealth.datasets import create_sample_dataset, get_dataloader from pyhealth.models import SafeDrug @@ -33,7 +33,7 @@ def setUp(self): } self.output_schema = {"drugs": "multilabel"} - self.dataset = SampleDataset( + self.dataset = create_sample_dataset( samples=self.samples, input_schema=self.input_schema, output_schema=self.output_schema, diff --git a/tests/core/test_sample_builder.py b/tests/core/test_sample_builder.py new file mode 100644 index 000000000..446320ccf --- /dev/null +++ b/tests/core/test_sample_builder.py @@ -0,0 +1,83 @@ +import os +import pickle +import tempfile +import unittest + +from pyhealth.datasets.sample_dataset import SampleBuilder + + +class TestSampleBuilder(unittest.TestCase): + def setUp(self): + self.samples = [ + {"patient_id": "p1", "record_id": "r1", "feature": "a", "label": 1}, + {"patient_id": "p1", "record_id": "r2", "feature": "b", "label": 0}, + {"patient_id": "p2", "record_id": "r3", "feature": "c", "label": 1}, + ] + self.input_schema = {"feature": "raw"} + self.output_schema = {"label": "raw"} + + def test_fit_and_transform(self): + builder = SampleBuilder( + input_schema=self.input_schema, output_schema=self.output_schema + ) + + with self.assertRaises(RuntimeError): + _ = builder.input_processors # Access before fit should fail + + builder.fit(self.samples) + + self.assertIn("feature", builder.input_processors) + self.assertIn("label", builder.output_processors) + + self.assertEqual(builder.patient_to_index["p1"], [0, 1]) + self.assertEqual(builder.record_to_index["r3"], [2]) + + transformed = builder.transform({"sample": pickle.dumps(self.samples[0])}) + 
self.assertEqual(transformed["feature"], "a") + self.assertEqual(transformed["label"], 1) + self.assertEqual(transformed["patient_id"], "p1") + + def test_transform_requires_fit(self): + builder = SampleBuilder( + input_schema=self.input_schema, output_schema=self.output_schema + ) + with self.assertRaises(RuntimeError): + builder.transform({"sample": pickle.dumps(self.samples[0])}) + + def test_index_mappings(self): + builder = SampleBuilder( + input_schema=self.input_schema, output_schema=self.output_schema + ) + builder.fit(self.samples) + + expected_patient = {"p1": [0, 1], "p2": [2]} + expected_record = {"r1": [0], "r2": [1], "r3": [2]} + self.assertEqual(builder.patient_to_index, expected_patient) + self.assertEqual(builder.record_to_index, expected_record) + + def test_save_persists_fitted_state(self): + builder = SampleBuilder( + input_schema=self.input_schema, output_schema=self.output_schema + ) + builder.fit(self.samples) + + with tempfile.TemporaryDirectory() as tmpdir: + path = os.path.join(tmpdir, "schema.pkl") + builder.save(path) + + with open(path, "rb") as f: + metadata = pickle.load(f) + + self.assertEqual(metadata["input_schema"], self.input_schema) + self.assertEqual(metadata["output_schema"], self.output_schema) + self.assertEqual(metadata["patient_to_index"], builder.patient_to_index) + self.assertEqual(metadata["record_to_index"], builder.record_to_index) + + feature_processor = metadata["input_processors"]["feature"] + label_processor = metadata["output_processors"]["label"] + self.assertEqual(feature_processor.process("foo"), "foo") + self.assertEqual(label_processor.process(1), 1) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/core/test_sample_dataset.py b/tests/core/test_sample_dataset.py new file mode 100644 index 000000000..8dd4629c0 --- /dev/null +++ b/tests/core/test_sample_dataset.py @@ -0,0 +1,123 @@ +import unittest +import pickle +import random +from pyhealth.datasets.sample_dataset import create_sample_dataset + +class TestSampleDatasetParity(unittest.TestCase): + def setUp(self): + # Create a slightly larger dataset to make shuffling more obvious + self.samples = [ + {"patient_id": f"p{i}", "record_id": f"r{i}", "feature": i, "label": i % 2} + for i in range(20) + ] + self.input_schema = {"feature": "raw"} + self.output_schema = {"label": "raw"} + + def _get_datasets(self): + ds_disk = create_sample_dataset( + samples=self.samples, + input_schema=self.input_schema, + output_schema=self.output_schema, + in_memory=False + ) + ds_mem = create_sample_dataset( + samples=self.samples, + input_schema=self.input_schema, + output_schema=self.output_schema, + in_memory=True + ) + return ds_disk, ds_mem + + def test_len(self): + ds_disk, ds_mem = self._get_datasets() + self.assertEqual(len(ds_disk), 20) + self.assertEqual(len(ds_mem), 20) + self.assertEqual(len(ds_disk), len(ds_mem)) + + def test_getitem(self): + ds_disk, ds_mem = self._get_datasets() + for i in range(len(self.samples)): + item_disk = ds_disk[i] + item_mem = ds_mem[i] + self.assertEqual(item_disk["feature"], item_mem["feature"]) + self.assertEqual(item_disk["label"], item_mem["label"]) + self.assertEqual(item_disk["patient_id"], item_mem["patient_id"]) + + def test_iter(self): + ds_disk, ds_mem = self._get_datasets() + list_disk = list(ds_disk) + list_mem = list(ds_mem) + + self.assertEqual(len(list_disk), len(list_mem)) + for d, m in zip(list_disk, list_mem): + self.assertEqual(d["feature"], m["feature"]) + + def test_subset_indices(self): + ds_disk, ds_mem = 
self._get_datasets() + indices = [0, 5, 10, 15, 19] + + sub_disk = ds_disk.subset(indices) + sub_mem = ds_mem.subset(indices) + + self.assertEqual(len(sub_disk), len(sub_mem)) + self.assertEqual(len(sub_disk), 5) + + list_disk = list(sub_disk) + list_mem = list(sub_mem) + + for d, m in zip(list_disk, list_mem): + self.assertEqual(d["feature"], m["feature"]) + + def test_subset_slice(self): + ds_disk, ds_mem = self._get_datasets() + s = slice(2, 18, 2) + + sub_disk = ds_disk.subset(s) + sub_mem = ds_mem.subset(s) + + self.assertEqual(len(sub_disk), len(sub_mem)) + + list_disk = list(sub_disk) + list_mem = list(sub_mem) + + for d, m in zip(list_disk, list_mem): + self.assertEqual(d["feature"], m["feature"]) + + def test_set_shuffle(self): + ds_disk, ds_mem = self._get_datasets() + + # Test shuffle=True + ds_disk.set_shuffle(True) + ds_mem.set_shuffle(True) + + # Iterating should return all elements, but likely in different order than original + # and potentially different order between disk and mem (implementation detail) + # But the set of elements should be identical. + + items_disk = list(ds_disk) + items_mem = list(ds_mem) + + self.assertEqual(len(items_disk), 20) + self.assertEqual(len(items_mem), 20) + + # Check that we have the same set of features + features_disk = sorted([x["feature"] for x in items_disk]) + features_mem = sorted([x["feature"] for x in items_mem]) + features_orig = sorted([x["feature"] for x in self.samples]) + + self.assertEqual(features_disk, features_orig) + self.assertEqual(features_mem, features_orig) + + # Test shuffle=False resets to original order + ds_disk.set_shuffle(False) + ds_mem.set_shuffle(False) + + items_disk_ordered = list(ds_disk) + items_mem_ordered = list(ds_mem) + + for i in range(20): + self.assertEqual(items_disk_ordered[i]["feature"], i) + self.assertEqual(items_mem_ordered[i]["feature"], i) + +if __name__ == "__main__": + unittest.main() diff --git a/tests/core/test_stagenet.py b/tests/core/test_stagenet.py index e1a491b49..b3fd4b5a2 100644 --- a/tests/core/test_stagenet.py +++ b/tests/core/test_stagenet.py @@ -1,7 +1,7 @@ import unittest import torch -from pyhealth.datasets import SampleDataset, get_dataloader +from pyhealth.datasets import create_sample_dataset, get_dataloader from pyhealth.models import StageNet @@ -61,7 +61,7 @@ def setUp(self): self.output_schema = {"label": "binary"} # Create dataset - self.dataset = SampleDataset( + self.dataset = create_sample_dataset( samples=self.samples, input_schema=self.input_schema, output_schema=self.output_schema, @@ -223,7 +223,7 @@ def test_time_handling_with_none(self): }, ] - dataset_no_time = SampleDataset( + dataset_no_time = create_sample_dataset( samples=samples_no_time, input_schema={"codes": "stagenet"}, output_schema={"label": "binary"}, @@ -311,7 +311,7 @@ def test_single_feature_input(self): }, ] - dataset_single = SampleDataset( + dataset_single = create_sample_dataset( samples=samples_single, input_schema={"codes": "stagenet"}, output_schema={"label": "binary"}, @@ -451,7 +451,7 @@ def test_processor_padding_with_model_integration(self): # Create dataset - it will use default processor, but we can verify padding works # by checking the processor's configuration - dataset = SampleDataset( + dataset = create_sample_dataset( samples=samples_padding, input_schema={"procedures": "stagenet"}, output_schema={"label": "binary"}, diff --git a/tests/core/test_streaming_parquet_writer.py b/tests/core/test_streaming_parquet_writer.py new file mode 100644 index 000000000..c7c212bd5 --- 
/dev/null +++ b/tests/core/test_streaming_parquet_writer.py @@ -0,0 +1,55 @@ +import tempfile +from pathlib import Path +import unittest + +import pyarrow as pa +import pyarrow.parquet as pq + +from pyhealth.datasets.base_dataset import _ParquetWriter +from tests.base import BaseTestCase + + +class TestStreamingParquetWriter(BaseTestCase): + def setUp(self): + self.tmpdir = tempfile.TemporaryDirectory() + self.schema = pa.schema( + [ + ("id", pa.int64()), + ("value", pa.string()), + ] + ) + self.output_path = Path(self.tmpdir.name) / "stream.parquet" + + def tearDown(self): + self.tmpdir.cleanup() + + def test_append_flush_close_and_context_manager(self): + rows = [ + {"id": 1, "value": "a"}, + {"id": 2, "value": "b"}, + {"id": 3, "value": "c"}, + {"id": 4, "value": "d"}, + ] + + with _ParquetWriter( + self.output_path, self.schema, chunk_size=2 + ) as writer: + # First two appends trigger an automatic flush due to chunk_size=2. + writer.append(rows[0]) + writer.append(rows[1]) + + # Flush again after adding a third row to ensure flushing appends + # rather than overwriting previous row groups. + writer.append(rows[2]) + writer.flush() + + # Leave data in the buffer to verify close() flushes it. + writer.append(rows[3]) + + # Context manager should have closed and flushed remaining buffered rows. + self.assertTrue(self.output_path.exists()) + + written_rows = pq.read_table(self.output_path).to_pylist() + + # Every append should be present as a distinct row in order. + self.assertEqual(written_rows, rows) diff --git a/tests/core/test_support2.py b/tests/core/test_support2.py index 6a21c84f6..303e89c90 100644 --- a/tests/core/test_support2.py +++ b/tests/core/test_support2.py @@ -137,11 +137,10 @@ def test_survival_preprocess_2m(self): sample_dataset = dataset.set_task(task) self.assertIsNotNone(sample_dataset) - self.assertTrue(hasattr(sample_dataset, "samples")) - self.assertEqual(len(sample_dataset.samples), 3) + self.assertEqual(len(sample_dataset), 3) # Check first sample structure - sample = sample_dataset.samples[0] + sample = next(iter(sample_dataset)) required_keys = [ "patient_id", "demographics", @@ -156,7 +155,7 @@ def test_survival_preprocess_2m(self): self.assertIn(key, sample, f"Sample should contain key: {key}") # Verify survival probabilities are in valid range [0, 1] - for s in sample_dataset.samples: + for s in sample_dataset: survival_prob = s["survival_probability"] self.assertIsInstance(survival_prob, torch.Tensor) prob_value = survival_prob.item() @@ -186,10 +185,10 @@ def test_survival_preprocess_6m(self): sample_dataset = dataset.set_task(task) self.assertIsNotNone(sample_dataset) - self.assertEqual(len(sample_dataset.samples), 3) + self.assertEqual(len(sample_dataset), 3) # Verify all samples have valid survival probabilities - for s in sample_dataset.samples: + for s in sample_dataset: survival_prob = s["survival_probability"] self.assertIsInstance(survival_prob, torch.Tensor) prob_value = survival_prob.item() diff --git a/tests/core/test_transformer.py b/tests/core/test_transformer.py index 25a44495b..a5fa6cc6b 100644 --- a/tests/core/test_transformer.py +++ b/tests/core/test_transformer.py @@ -3,7 +3,7 @@ import torch -from pyhealth.datasets import SampleDataset, get_dataloader +from pyhealth.datasets import create_sample_dataset, get_dataloader from pyhealth.models import Transformer from pyhealth.processors.base_processor import FeatureProcessor from pyhealth.interpret.methods import CheferRelevance @@ -42,7 +42,7 @@ def setUp(self): "label": "binary" } - 
self.dataset = SampleDataset( + self.dataset = create_sample_dataset( samples=self.samples, input_schema=self.input_schema, output_schema=self.output_schema, diff --git a/tests/core/test_tsv_load.py b/tests/core/test_tsv_load.py index 779b928b6..25c87f1db 100644 --- a/tests/core/test_tsv_load.py +++ b/tests/core/test_tsv_load.py @@ -156,7 +156,7 @@ def test_tsv_load(self): self.assertIsNotNone(dataset.config) # Test that we can collect the dataframe - collected_df = dataset.collected_global_event_df + collected_df = dataset.global_event_df.collect() self.assertIsInstance(collected_df, pl.DataFrame) self.assertGreater( collected_df.height, 0, "Dataset should have at least one row" @@ -201,7 +201,7 @@ def test_tsv_file_detection(self): dev=False, ) - collected_df = dataset.collected_global_event_df + collected_df = dataset.global_event_df.collect() # Verify we have the expected number of patients self.assertEqual(collected_df["patient_id"].n_unique(), 5) @@ -231,7 +231,7 @@ def test_multiple_tsv_tables(self): dev=False, ) - collected_df = dataset.collected_global_event_df + collected_df = dataset.global_event_df.collect() # Should have data from both tables self.assertGreater(collected_df.height, 5) # More than just patients table
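
The rename from ``SampleDataset(...)`` to the ``create_sample_dataset(...)`` factory touches nearly every test and docstring above, and ``torch.utils.data.Subset`` is replaced by ``dataset.subset(...)``. A minimal end-to-end sketch of the new API, assembled from the keyword arguments visible in the tests; the concrete samples, and combining ``dataset_name`` with ``in_memory`` in one call, are illustrative assumptions:

.. code-block:: python

    from pyhealth.datasets import create_sample_dataset, get_dataloader

    # Illustrative samples; the "sequence"/"binary" schema strings mirror the tests above.
    samples = [
        {"patient_id": "p0", "conditions": ["A05B", "A06A"], "label": 1},
        {"patient_id": "p1", "conditions": ["A11D"], "label": 0},
    ]

    dataset = create_sample_dataset(
        samples=samples,
        input_schema={"conditions": "sequence"},
        output_schema={"label": "binary"},
        dataset_name="demo",
        in_memory=True,  # in_memory=False streams from the on-disk cache (see test_sample_dataset.py)
    )

    train = dataset.subset([0])  # subset() replaces torch.utils.data.Subset in the tests above
    loader = get_dataloader(train, batch_size=1, shuffle=False)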
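
``test_sample_builder.py`` pins down the public surface of the new ``SampleBuilder``: ``fit()`` builds the processors plus the ``patient_to_index``/``record_to_index`` maps, ``transform()`` consumes a dict holding a pickled sample, and ``save()`` persists the fitted state. A sketch that mirrors that test (toy samples, ``"raw"`` schemas as in the test):

.. code-block:: python

    import pickle

    from pyhealth.datasets.sample_dataset import SampleBuilder

    samples = [
        {"patient_id": "p1", "record_id": "r1", "feature": "a", "label": 1},
        {"patient_id": "p2", "record_id": "r2", "feature": "b", "label": 0},
    ]

    builder = SampleBuilder(input_schema={"feature": "raw"}, output_schema={"label": "raw"})
    builder.fit(samples)                                   # fits processors and index maps
    row = builder.transform({"sample": pickle.dumps(samples[0])})
    print(row["feature"], row["label"])                    # "a", 1
    print(builder.patient_to_index)                        # {"p1": [0], "p2": [1]}
    builder.save("schema.pkl")                             # persists schemas, processors, index maps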
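
The processor ``fit`` signatures move from ``List[Dict[str, Any]]`` to ``Iterable[Dict[str, Any]]``. Note that the StageNet processors above still make two passes over ``samples`` inside ``fit`` (one to detect the input structure, one to build the vocabulary), so the iterable should be re-iterable rather than a one-shot generator. A tiny hypothetical wrapper (not part of PyHealth) that provides that guarantee:

.. code-block:: python

    from typing import Any, Callable, Dict, Iterator


    class ReiterableSamples:
        """Hypothetical helper: each `for` loop gets a fresh iterator,
        so multi-pass fit() implementations see every sample."""

        def __init__(self, factory: Callable[[], Iterator[Dict[str, Any]]]):
            self._factory = factory

        def __iter__(self) -> Iterator[Dict[str, Any]]:
            return self._factory()


    def stream() -> Iterator[Dict[str, Any]]:
        yield {"conditions": ["A05B", "A06A"], "label": 1}
        yield {"conditions": ["A11D"], "label": 0}


    samples = ReiterableSamples(stream)
    # Two passes both see all rows, unlike a raw generator object.
    assert [s["label"] for s in samples] == [s["label"] for s in samples] == [1, 0]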
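
The ``.collect()`` inserted into the ``pre_filter`` chains of ``medical_coding.py`` and ``test_processor_schemas.py`` follows from the frame being a ``pl.LazyFrame``: the lazy query has to be materialized before ``.to_series()`` is available. The same chain on toy data (not from the PR):

.. code-block:: python

    import polars as pl

    lf = pl.LazyFrame(
        {
            "patient_id": ["1", "1", "2", "3"],
            "event_type": ["noteevents", "labevents", "noteevents", "labevents"],
        }
    )

    note_patients = set(
        lf.filter(pl.col("event_type") == "noteevents")
        .select("patient_id")
        .unique()
        .collect()      # materialize the LazyFrame before .to_series()
        .to_series()
    )
    print(note_patients)  # {"1", "2"}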
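
In ``image_processor.py`` the inline ``lambda img: img.convert(self.mode)`` becomes ``partial(_convert_mode, mode=self.mode)`` over a module-level helper. A plausible motivation — an assumption, since the diff does not state it — is picklability for multi-process data loading: a ``functools.partial`` of a module-level function pickles cleanly, while an inline lambda closing over ``self.mode`` does not. Sketch:

.. code-block:: python

    import pickle
    from functools import partial

    from PIL import Image


    def _convert_mode(img: Image.Image, mode: str) -> Image.Image:
        """Convert a PIL image to the requested mode (mirrors the new helper)."""
        return img.convert(mode)


    convert_rgb = partial(_convert_mode, mode="RGB")
    restored = pickle.loads(pickle.dumps(convert_rgb))   # round-trips fine
    print(restored(Image.new("L", (4, 4))).mode)         # RGB
    # An inline lambda in place of `convert_rgb` would fail to pickle here.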
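
``test_streaming_parquet_writer.py`` fixes the contract of ``_ParquetWriter``: rows are buffered, every ``chunk_size`` appends flush a new row group into the same file, and close / context-manager exit flushes whatever remains. A minimal sketch of a writer with that contract built on ``pyarrow.parquet.ParquetWriter`` — an illustration under that assumption, not the actual PyHealth implementation:

.. code-block:: python

    from pathlib import Path
    from typing import Any, Dict, List, Union

    import pyarrow as pa
    import pyarrow.parquet as pq


    class BufferedParquetWriter:
        """Buffer dict rows and append them to one parquet file in row groups."""

        def __init__(self, path: Union[str, Path], schema: pa.Schema, chunk_size: int = 1024):
            self._writer = pq.ParquetWriter(str(path), schema)
            self._schema = schema
            self._chunk_size = chunk_size
            self._buffer: List[Dict[str, Any]] = []

        def append(self, row: Dict[str, Any]) -> None:
            self._buffer.append(row)
            if len(self._buffer) >= self._chunk_size:
                self.flush()

        def flush(self) -> None:
            if self._buffer:
                table = pa.Table.from_pylist(self._buffer, schema=self._schema)
                self._writer.write_table(table)  # each flush appends a new row group
                self._buffer = []

        def close(self) -> None:
            self.flush()
            self._writer.close()

        def __enter__(self) -> "BufferedParquetWriter":
            return self

        def __exit__(self, *exc) -> None:
            self.close()

Used the same way as in the test: open with ``chunk_size=2``, ``append`` rows inside a ``with`` block, and ``pq.read_table(path).to_pylist()`` afterwards returns every appended row in order.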