Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
85 changes: 85 additions & 0 deletions tests/core/test_mimic3.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,85 @@
import unittest
import tempfile
import shutil
import subprocess
import os
from pathlib import Path

from pyhealth.datasets import MIMIC3Dataset


class TestMIMIC3Demo(unittest.TestCase):
    """Test MIMIC3Dataset against the MIMIC-III demo data from PhysioNet.

    The demo dataset is downloaded with ``wget`` into a temporary directory
    once for the whole class and shared by every test; the tests only read
    from the loaded dataset. If the download cannot be performed (``wget``
    missing, non-zero exit status, timeout, or unexpected directory layout)
    the tests are skipped rather than failed.
    """

    # Remote location of the demo dataset; wget mirrors it recursively.
    DEMO_DATASET_URL = "https://physionet.org/files/mimiciii-demo/1.4/"
    # Tables handed to MIMIC3Dataset when loading the demo data.
    TABLES = ("diagnoses_icd", "procedures_icd", "prescriptions", "noteevents")
    # Patient looked up by test_get_events.
    TEST_PATIENT_ID = "10006"
    # Upper bound on the wget run so a stalled transfer cannot hang the suite.
    DOWNLOAD_TIMEOUT_SECONDS = 600

    temp_dir = None
    demo_dataset_path = None
    dataset = None

    @classmethod
    def setUpClass(cls):
        """Download and load the demo dataset once for all tests."""
        cls.temp_dir = tempfile.mkdtemp()
        try:
            cls._download_demo_dataset()
            cls._load_dataset()
        except BaseException:
            # unittest does not call tearDownClass when setUpClass raises
            # (SkipTest included), so remove the temp directory here.
            cls._remove_temp_dir(ignore_errors=True)
            raise

    @classmethod
    def tearDownClass(cls):
        """Remove the downloaded dataset after the last test has run."""
        cls._remove_temp_dir()

    @classmethod
    def _remove_temp_dir(cls, ignore_errors=False):
        """Delete the temporary download directory if it still exists.

        Args:
            ignore_errors: Forwarded to ``shutil.rmtree``; set to True on
                failure paths so cleanup cannot mask the original exception.
        """
        if cls.temp_dir and os.path.exists(cls.temp_dir):
            shutil.rmtree(cls.temp_dir, ignore_errors=ignore_errors)
        cls.temp_dir = None

    @classmethod
    def _download_demo_dataset(cls):
        """Download the MIMIC-III demo dataset using wget.

        On success ``cls.demo_dataset_path`` is set to the directory that
        holds the downloaded files.

        Raises:
            unittest.SkipTest: If wget is not installed, exits with a
                non-zero status, exceeds ``DOWNLOAD_TIMEOUT_SECONDS``, or the
                dataset is not found in the expected location afterwards.
        """
        # Use wget to download the demo dataset recursively
        cmd = [
            "wget",
            "-r",
            "-N",
            "-c",
            "-np",
            "--directory-prefix",
            cls.temp_dir,
            cls.DEMO_DATASET_URL,
        ]

        try:
            subprocess.run(
                cmd,
                check=True,
                capture_output=True,
                text=True,
                timeout=cls.DOWNLOAD_TIMEOUT_SECONDS,
            )
        except subprocess.CalledProcessError as e:
            raise unittest.SkipTest(
                f"Failed to download MIMIC-III demo dataset: {e}"
            ) from e
        except subprocess.TimeoutExpired as e:
            raise unittest.SkipTest(
                f"Timed out downloading MIMIC-III demo dataset: {e}"
            ) from e
        except FileNotFoundError as e:
            raise unittest.SkipTest(
                "wget not available - skipping download test"
            ) from e

        # Find the downloaded dataset path: the recursive download is
        # expected under <prefix>/<host>/<url path>.
        physionet_dir = (
            Path(cls.temp_dir) / "physionet.org" / "files" / "mimiciii-demo" / "1.4"
        )
        if not physionet_dir.is_dir():
            raise unittest.SkipTest("Downloaded dataset not found in expected location")
        cls.demo_dataset_path = str(physionet_dir)

    @classmethod
    def _load_dataset(cls):
        """Load the downloaded demo data into ``cls.dataset``."""
        cls.dataset = MIMIC3Dataset(
            root=cls.demo_dataset_path, tables=list(cls.TABLES)
        )

    def test_stats(self):
        """Test .stats() method execution."""
        try:
            self.dataset.stats()
        except Exception as e:
            self.fail(f"dataset.stats() failed: {e}")

    def test_get_events(self):
        """Test get_patient and get_events methods with the demo patient."""
        # Test get_patient method
        patient = self.dataset.get_patient(self.TEST_PATIENT_ID)
        self.assertIsNotNone(
            patient,
            msg=f"Patient {self.TEST_PATIENT_ID} should exist in demo dataset",
        )

        # Test get_events method
        events = patient.get_events()
        self.assertIsNotNone(events, msg="get_events() should not return None")
        self.assertIsInstance(events, list, msg="get_events() should return a list")
        self.assertGreater(
            len(events), 0, msg="get_events() should not return an empty list"
        )


# Run the tests only when this file is executed as a script, not on import.
if __name__ == "__main__":
    unittest.main()