Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 7 additions & 8 deletions pyhealth/datasets/base_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -471,22 +471,21 @@ def load_table(self, table_name: str) -> dd.DataFrame:
attribute_cols = table_cfg.attributes

# Timestamp expression
# .astype(str) will convert `pd.NA` to "<NA>", which will raise error in to_datetime
# use .astype("string") instead, which keeps `pd.NA` as is.
if timestamp_col:
if isinstance(timestamp_col, list):
# Concatenate all timestamp parts in order with no separator
timestamp_series: dd.Series = functools.reduce(
operator.add, (df[col].astype(str) for col in timestamp_col)
operator.add, (df[col].astype("string") for col in timestamp_col)
)
else:
# Single timestamp column - don't convert to string yet
timestamp_series: dd.Series = df[timestamp_col]
timestamp_series: dd.Series = df[timestamp_col].astype("string")

# Convert to datetime, coercing NA/invalid values to NaT
# This avoids the "<NA>" string literal issue when NA is cast to str
timestamp_series: dd.Series = dd.to_datetime(
timestamp_series,
format=timestamp_format,
errors="raise", # Convert unparseable values to NaT instead of raising
errors="raise",
)
df: dd.DataFrame = df.assign(
timestamp=timestamp_series.astype("datetime64[ms]")
Expand All @@ -496,10 +495,10 @@ def load_table(self, table_name: str) -> dd.DataFrame:

# If patient_id_col is None, use row index as patient_id
if patient_id_col:
df: dd.DataFrame = df.assign(patient_id=df[patient_id_col].astype(str))
df: dd.DataFrame = df.assign(patient_id=df[patient_id_col].astype("string"))
else:
df: dd.DataFrame = df.reset_index(drop=True)
df: dd.DataFrame = df.assign(patient_id=df.index.astype(str))
df: dd.DataFrame = df.assign(patient_id=df.index.astype("string"))

df: dd.DataFrame = df.assign(event_type=table_name)

Expand Down
167 changes: 167 additions & 0 deletions tests/core/test_base_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -129,6 +129,173 @@ def test_event_df_cache_is_physically_sorted(self):
"cached global_event_df parquet must be sorted by patient_id",
)

def test_empty_string_handling(self):
import os
from dataclasses import dataclass
from typing import List

# Create a temporary directory and a CSV file
with tempfile.TemporaryDirectory() as tmp_dir:
csv_path = os.path.join(tmp_dir, "data.csv")
# Create CSV with empty strings
# pid, time, val
# p1, 2020-01-01, v1
# p2, "", v2 -> missing time
# "", 2020-01-02, v3 -> missing pid
# p3, 2020-01-03, "" -> missing val
with open(csv_path, "w") as f:
f.write("pid,time,val\n")
f.write("p1,2020-01-01,v1\n")
f.write("p2,,v2\n")
f.write(",2020-01-02,v3\n")
f.write("p3,2020-01-03,\n")

@dataclass
class TableConfig:
file_path: str
patient_id: str
timestamp: str
timestamp_format: str
attributes: List[str]
join: List = None

def __post_init__(self):
if self.join is None:
self.join = []

@dataclass
class Config:
tables: dict

config = Config(
tables={
"table1": TableConfig(
file_path="data.csv",
patient_id="pid",
timestamp="time",
timestamp_format="%Y-%m-%d",
attributes=["val"]
)
}
)

class ConcreteDataset(BaseDataset):
pass

dataset = ConcreteDataset(
root=tmp_dir,
tables=["table1"],
dataset_name="TestDataset",
cache_dir=tmp_dir
)
dataset.config = config

# Load data
# load_table returns a dask dataframe
df = dataset.load_table("table1")
# Compute to get pandas dataframe
pdf = df.compute()

# Verify
# Row 0: p1, 2020-01-01, v1
self.assertEqual(pdf.iloc[0]["patient_id"], "p1")
self.assertEqual(pdf.iloc[0]["timestamp"], pd.Timestamp("2020-01-01"))
self.assertEqual(pdf.iloc[0]["table1/val"], "v1")

# Row 1: p2, NaT, v2
self.assertEqual(pdf.iloc[1]["patient_id"], "p2")
self.assertTrue(pd.isna(pdf.iloc[1]["timestamp"]))
self.assertEqual(pdf.iloc[1]["table1/val"], "v2")

# Row 2: <NA>, 2020-01-02, v3
self.assertTrue(pd.isna(pdf.iloc[2]["patient_id"]))
self.assertEqual(pdf.iloc[2]["timestamp"], pd.Timestamp("2020-01-02"))
self.assertEqual(pdf.iloc[2]["table1/val"], "v3")

# Row 3: p3, 2020-01-03, <NA>
self.assertEqual(pdf.iloc[3]["patient_id"], "p3")
self.assertEqual(pdf.iloc[3]["timestamp"], pd.Timestamp("2020-01-03"))
self.assertTrue(pd.isna(pdf.iloc[3]["table1/val"]))

def test_empty_string_handling_composite_timestamp(self):
import os
from dataclasses import dataclass
from typing import List

# Create a temporary directory and a CSV file
with tempfile.TemporaryDirectory() as tmp_dir:
csv_path = os.path.join(tmp_dir, "data_composite.csv")
# Create CSV with empty strings in composite timestamp fields
# pid, year, month, day, val
# p1, 2020, 01, 01, v1 -> 2020-01-01
# p2, 2020, , 02, v2 -> missing month -> NaT
# p3, , 01, 03, v3 -> missing year -> NaT
with open(csv_path, "w") as f:
f.write("pid,year,month,day,val\n")
f.write("p1,2020,01,01,v1\n")
f.write("p2,2020,,02,v2\n")
f.write("p3,,01,03,v3\n")

@dataclass
class TableConfig:
file_path: str
patient_id: str
timestamp: List[str]
timestamp_format: str
attributes: List[str]
join: List = None

def __post_init__(self):
if self.join is None:
self.join = []

@dataclass
class Config:
tables: dict

config = Config(
tables={
"table1": TableConfig(
file_path="data_composite.csv",
patient_id="pid",
timestamp=["year", "month", "day"],
timestamp_format="%Y%m%d",
attributes=["val"]
)
}
)

class ConcreteDataset(BaseDataset):
pass

dataset = ConcreteDataset(
root=tmp_dir,
tables=["table1"],
dataset_name="TestDatasetComposite",
cache_dir=tmp_dir
)
dataset.config = config

# Load data
df = dataset.load_table("table1")
pdf = df.compute()

# Verify
# Row 0: p1, 2020-01-01
self.assertEqual(pdf.iloc[0]["patient_id"], "p1")
self.assertEqual(pdf.iloc[0]["timestamp"], pd.Timestamp("2020-01-01"))
self.assertEqual(pdf.iloc[0]["table1/val"], "v1")

# Row 1: p2, NaT
self.assertEqual(pdf.iloc[1]["patient_id"], "p2")
self.assertTrue(pd.isna(pdf.iloc[1]["timestamp"]))
self.assertEqual(pdf.iloc[1]["table1/val"], "v2")

# Row 2: p3, NaT
self.assertEqual(pdf.iloc[2]["patient_id"], "p3")
self.assertTrue(pd.isna(pdf.iloc[2]["timestamp"]))
self.assertEqual(pdf.iloc[2]["table1/val"], "v3")


if __name__ == "__main__":
unittest.main()