diff --git a/pyhealth/datasets/base_dataset.py b/pyhealth/datasets/base_dataset.py index 515944db..b85b97d5 100644 --- a/pyhealth/datasets/base_dataset.py +++ b/pyhealth/datasets/base_dataset.py @@ -471,22 +471,21 @@ def load_table(self, table_name: str) -> dd.DataFrame: attribute_cols = table_cfg.attributes # Timestamp expression + # .astype(str) will convert `pd.NA` to "", which will raise error in to_datetime + # use .astype("string") instead, which keeps `pd.NA` as is. if timestamp_col: if isinstance(timestamp_col, list): # Concatenate all timestamp parts in order with no separator timestamp_series: dd.Series = functools.reduce( - operator.add, (df[col].astype(str) for col in timestamp_col) + operator.add, (df[col].astype("string") for col in timestamp_col) ) else: - # Single timestamp column - don't convert to string yet - timestamp_series: dd.Series = df[timestamp_col] + timestamp_series: dd.Series = df[timestamp_col].astype("string") - # Convert to datetime, coercing NA/invalid values to NaT - # This avoids the "" string literal issue when NA is cast to str timestamp_series: dd.Series = dd.to_datetime( timestamp_series, format=timestamp_format, - errors="raise", # Convert unparseable values to NaT instead of raising + errors="raise", ) df: dd.DataFrame = df.assign( timestamp=timestamp_series.astype("datetime64[ms]") @@ -496,10 +495,10 @@ def load_table(self, table_name: str) -> dd.DataFrame: # If patient_id_col is None, use row index as patient_id if patient_id_col: - df: dd.DataFrame = df.assign(patient_id=df[patient_id_col].astype(str)) + df: dd.DataFrame = df.assign(patient_id=df[patient_id_col].astype("string")) else: df: dd.DataFrame = df.reset_index(drop=True) - df: dd.DataFrame = df.assign(patient_id=df.index.astype(str)) + df: dd.DataFrame = df.assign(patient_id=df.index.astype("string")) df: dd.DataFrame = df.assign(event_type=table_name) diff --git a/tests/core/test_base_dataset.py b/tests/core/test_base_dataset.py index 7e69e7c8..4f9bb1fd 100644 --- a/tests/core/test_base_dataset.py +++ b/tests/core/test_base_dataset.py @@ -129,6 +129,173 @@ def test_event_df_cache_is_physically_sorted(self): "cached global_event_df parquet must be sorted by patient_id", ) + def test_empty_string_handling(self): + import os + from dataclasses import dataclass + from typing import List + + # Create a temporary directory and a CSV file + with tempfile.TemporaryDirectory() as tmp_dir: + csv_path = os.path.join(tmp_dir, "data.csv") + # Create CSV with empty strings + # pid, time, val + # p1, 2020-01-01, v1 + # p2, "", v2 -> missing time + # "", 2020-01-02, v3 -> missing pid + # p3, 2020-01-03, "" -> missing val + with open(csv_path, "w") as f: + f.write("pid,time,val\n") + f.write("p1,2020-01-01,v1\n") + f.write("p2,,v2\n") + f.write(",2020-01-02,v3\n") + f.write("p3,2020-01-03,\n") + + @dataclass + class TableConfig: + file_path: str + patient_id: str + timestamp: str + timestamp_format: str + attributes: List[str] + join: List = None + + def __post_init__(self): + if self.join is None: + self.join = [] + + @dataclass + class Config: + tables: dict + + config = Config( + tables={ + "table1": TableConfig( + file_path="data.csv", + patient_id="pid", + timestamp="time", + timestamp_format="%Y-%m-%d", + attributes=["val"] + ) + } + ) + + class ConcreteDataset(BaseDataset): + pass + + dataset = ConcreteDataset( + root=tmp_dir, + tables=["table1"], + dataset_name="TestDataset", + cache_dir=tmp_dir + ) + dataset.config = config + + # Load data + # load_table returns a dask dataframe + df = dataset.load_table("table1") + # Compute to get pandas dataframe + pdf = df.compute() + + # Verify + # Row 0: p1, 2020-01-01, v1 + self.assertEqual(pdf.iloc[0]["patient_id"], "p1") + self.assertEqual(pdf.iloc[0]["timestamp"], pd.Timestamp("2020-01-01")) + self.assertEqual(pdf.iloc[0]["table1/val"], "v1") + + # Row 1: p2, NaT, v2 + self.assertEqual(pdf.iloc[1]["patient_id"], "p2") + self.assertTrue(pd.isna(pdf.iloc[1]["timestamp"])) + self.assertEqual(pdf.iloc[1]["table1/val"], "v2") + + # Row 2: , 2020-01-02, v3 + self.assertTrue(pd.isna(pdf.iloc[2]["patient_id"])) + self.assertEqual(pdf.iloc[2]["timestamp"], pd.Timestamp("2020-01-02")) + self.assertEqual(pdf.iloc[2]["table1/val"], "v3") + + # Row 3: p3, 2020-01-03, + self.assertEqual(pdf.iloc[3]["patient_id"], "p3") + self.assertEqual(pdf.iloc[3]["timestamp"], pd.Timestamp("2020-01-03")) + self.assertTrue(pd.isna(pdf.iloc[3]["table1/val"])) + + def test_empty_string_handling_composite_timestamp(self): + import os + from dataclasses import dataclass + from typing import List + + # Create a temporary directory and a CSV file + with tempfile.TemporaryDirectory() as tmp_dir: + csv_path = os.path.join(tmp_dir, "data_composite.csv") + # Create CSV with empty strings in composite timestamp fields + # pid, year, month, day, val + # p1, 2020, 01, 01, v1 -> 2020-01-01 + # p2, 2020, , 02, v2 -> missing month -> NaT + # p3, , 01, 03, v3 -> missing year -> NaT + with open(csv_path, "w") as f: + f.write("pid,year,month,day,val\n") + f.write("p1,2020,01,01,v1\n") + f.write("p2,2020,,02,v2\n") + f.write("p3,,01,03,v3\n") + + @dataclass + class TableConfig: + file_path: str + patient_id: str + timestamp: List[str] + timestamp_format: str + attributes: List[str] + join: List = None + + def __post_init__(self): + if self.join is None: + self.join = [] + + @dataclass + class Config: + tables: dict + + config = Config( + tables={ + "table1": TableConfig( + file_path="data_composite.csv", + patient_id="pid", + timestamp=["year", "month", "day"], + timestamp_format="%Y%m%d", + attributes=["val"] + ) + } + ) + + class ConcreteDataset(BaseDataset): + pass + + dataset = ConcreteDataset( + root=tmp_dir, + tables=["table1"], + dataset_name="TestDatasetComposite", + cache_dir=tmp_dir + ) + dataset.config = config + + # Load data + df = dataset.load_table("table1") + pdf = df.compute() + + # Verify + # Row 0: p1, 2020-01-01 + self.assertEqual(pdf.iloc[0]["patient_id"], "p1") + self.assertEqual(pdf.iloc[0]["timestamp"], pd.Timestamp("2020-01-01")) + self.assertEqual(pdf.iloc[0]["table1/val"], "v1") + + # Row 1: p2, NaT + self.assertEqual(pdf.iloc[1]["patient_id"], "p2") + self.assertTrue(pd.isna(pdf.iloc[1]["timestamp"])) + self.assertEqual(pdf.iloc[1]["table1/val"], "v2") + + # Row 2: p3, NaT + self.assertEqual(pdf.iloc[2]["patient_id"], "p3") + self.assertTrue(pd.isna(pdf.iloc[2]["timestamp"])) + self.assertEqual(pdf.iloc[2]["table1/val"], "v3") + if __name__ == "__main__": unittest.main()