Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
107 changes: 67 additions & 40 deletions src/data/base_datamodule.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,11 @@
import copy
import os
import time
from functools import partial
from typing import Any, Dict, List, Tuple

import numpy as np
import pandas as pd
import torch
from lightning import LightningDataModule
from sklearn.cluster import DBSCAN
from sklearn.model_selection import GroupShuffleSplit
from torch.utils.data import DataLoader, random_split

Expand Down Expand Up @@ -48,8 +45,9 @@ def __init__(
:param save_split: if to save split file
:param saved_split_file_name: file name to save split file
:param caption_builder: instance of BaseCaptionBuilder for generating textual captions
:param spatial_split_distance_m: minimum distance in metres between clusters when
split_mode is 'spatial_clusters'. Default 1000 m.
:param spatial_split_distance_m: grid cell size in metres when split_mode is
'spatial_clusters'. Samples within the same cell are kept together and assigned to the
same split. Default 1000 m.
"""
super().__init__()
self.save_hyperparameters(logger=False)
Expand Down Expand Up @@ -122,33 +120,43 @@ def split_data(self) -> None:

elif self.hparams.split_mode == "spatial_clusters":
min_dist = self.hparams.spatial_split_distance_m
coords = np.array([self.dataset.df.lat, self.dataset.df.lon]).T
# Use records (not df): records is already filtered (e.g. missing tiles
# dropped), so len(records) <= len(df). Indices must be into records
# because __getitem__ and __len__ both operate on self.records.
# lat/lon come from df (always present) keyed by name_loc so the
# coordinate array stays aligned with records regardless of modalities.
records = self.dataset.records
_nl_to_coords = dict(
zip(
self.dataset.df["name_loc"],
zip(self.dataset.df["lat"], self.dataset.df["lon"]),
)
)
coords = np.array(
[
[_nl_to_coords[r["name_loc"]][0] for r in records],
[_nl_to_coords[r["name_loc"]][1] for r in records],
]
).T
n = len(coords)
# Grid-based spatial grouping: assign each sample to a geographic
# cell of size spatial_split_distance_m × spatial_split_distance_m.
# GroupShuffleSplit then distributes whole cells across splits, so
# geographically close samples stay together while split proportions
# remain balanced (unlike DBSCAN, which chain-links dense data into
# a few giant clusters and produces wildly uneven splits).
_METERS_PER_DEG_LAT = 111_000.0
lat_step = min_dist / _METERS_PER_DEG_LAT
lon_step = min_dist / (_METERS_PER_DEG_LAT * np.cos(np.radians(np.mean(coords[:, 0]))))
grid_ids = np.floor(coords[:, 0] / lat_step).astype(np.int64) * 1_000_000 + np.floor(
coords[:, 1] / lon_step
).astype(np.int64)
_, clusters = np.unique(grid_ids, return_inverse=True)
n_clusters = int(clusters.max()) + 1
print(
f"Splitting {n} samples into spatial clusters "
f"(eps={min_dist / 1000:.1f} km, haversine, n_jobs=-1)..."
f"Splitting {n} samples into {n_clusters} spatial grid cells "
f"(cell size ≈ {min_dist / 1000:.0f} km). Creating splits..."
)
# Convert (lat, lon) degrees to radians for sklearn's haversine metric.
# haversine returns arc length on the unit sphere, so eps must be in radians.
_EARTH_RADIUS_M = 6_371_000
coords_rad = np.radians(coords)
eps_rad = min_dist / _EARTH_RADIUS_M
t0 = time.time()
clustering = DBSCAN(
eps=eps_rad,
metric="haversine",
algorithm="ball_tree",
min_samples=2,
n_jobs=-1,
).fit(coords_rad)
print(f"DBSCAN done in {time.time() - t0:.1f}s. Creating splits...")
# Non-clustered points are labeled -1. Change to new cluster label.
clusters = copy.deepcopy(clustering.labels_)
new_cl = np.max(clusters) + 1
for i, cl in enumerate(clusters):
if cl == -1:
clusters[i] = new_cl
new_cl += 1

gss = GroupShuffleSplit(
n_splits=1,
Expand Down Expand Up @@ -197,13 +205,15 @@ def split_data(self) -> None:
)

print(
f"Created {len(train_indices)} train, {len(val_indices)} val, {len(test_indices)} test indices using DBSCAN spatial clustering with {min_dist} m minimum distance between clusters."
f"Created {len(train_indices)} train, {len(val_indices)} val, "
f"{len(test_indices)} test indices across {n_clusters} spatial grid cells "
f"(cell size ≈ {min_dist / 1000:.0f} km)."
)
if self.hparams.save_split:
split_indices = {
"train_indices": self.dataset.df.name_loc[train_indices],
"val_indices": self.dataset.df.name_loc[val_indices],
"test_indices": self.dataset.df.name_loc[test_indices],
"train_indices": pd.Series([records[i]["name_loc"] for i in train_indices]),
"val_indices": pd.Series([records[i]["name_loc"] for i in val_indices]),
"test_indices": pd.Series([records[i]["name_loc"] for i in test_indices]),
"clusters": clusters,
}

Expand Down Expand Up @@ -231,10 +241,21 @@ def split_data(self) -> None:
if test_indices is not None and not isinstance(test_indices, pd.Series):
raise NotImplementedError("Expected a pd series of name_locs for data splits.")

train_indices = np.where(self.dataset.df["name_loc"].isin(train_indices))[0]
val_indices = np.where(self.dataset.df["name_loc"].isin(val_indices))[0]
# Map name_locs → records-level indices (not df row indices).
# self.records may be shorter than self.df when records are
# dropped (e.g. missing tessera_prev tiles in config B), so
# df row indices would be out of range in __getitem__.
_name_loc_to_rec_idx = {r["name_loc"]: i for i, r in enumerate(self.dataset.records)}
train_indices = np.array(
[_name_loc_to_rec_idx[nl] for nl in train_indices if nl in _name_loc_to_rec_idx]
)
val_indices = np.array(
[_name_loc_to_rec_idx[nl] for nl in val_indices if nl in _name_loc_to_rec_idx]
)
if test_indices is not None:
test_indices = np.where(self.dataset.df["name_loc"].isin(test_indices))[0]
test_indices = np.array(
[_name_loc_to_rec_idx[nl] for nl in test_indices if nl in _name_loc_to_rec_idx]
)

print(f"Dataset was split using indices from file: {self.saved_split_file_path}")
else:
Expand Down Expand Up @@ -277,11 +298,14 @@ def _compute_tabular_normalisation_stats(self) -> None:
return

train_indices = self.data_train.indices
train_df = self.dataset.df.iloc[train_indices][feat_names]
train_df = pd.DataFrame(
[[self.dataset.records[i][k] for k in feat_names] for i in train_indices],
columns=feat_names,
)

mean = train_df.mean(axis=0).values
std = train_df.std(axis=0).values
std = np.where(std == 0, 1.0, std) # avoid division by zero for constant features
std = np.where((std == 0) | np.isnan(std), 1.0, std)

self.tabular_normalisation_stats = (
torch.tensor(mean, dtype=torch.float32),
Expand All @@ -304,11 +328,14 @@ def _compute_target_normalisation_stats(self) -> None:
return

train_indices = self.data_train.indices
train_df = self.dataset.df.iloc[train_indices][target_names]
train_df = pd.DataFrame(
[[self.dataset.records[i][k] for k in target_names] for i in train_indices],
columns=target_names,
)

mean = train_df.mean(axis=0).values
std = train_df.std(axis=0).values
std = np.where(std == 0, 1.0, std) # avoid division by zero for constant targets
std = np.where((std == 0) | np.isnan(std), 1.0, std)

self.target_normalisation_stats = (
torch.tensor(mean, dtype=torch.float32),
Expand Down
85 changes: 79 additions & 6 deletions src/data/yield_africa_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,9 +52,12 @@ class YieldAfricaDataset(BaseDataset):

Modality design note
--------------------
`implemented_mod = {"coords"}` because tabular features live directly in
the model-ready CSV and are picked up via the `feat_` column prefix.
They do NOT need to be listed in `modalities`.
Tabular features live directly in the model-ready CSV and are picked up
via the `feat_` column prefix. They do NOT need to be listed in
`modalities`. Implemented spatial modalities: ``coords``, ``tessera``
(year-Y embedding), ``tessera_prev`` (year-Y−1 embedding for dual-year
fusion). Adding ``tessera_prev`` to modalities activates dual-year
loading; single-year runs are unaffected.

In addition to the CSV feat_* columns, the following features are injected:
- ``feat_year`` : normalised year (zero-mean, unit-std)
Expand Down Expand Up @@ -88,6 +91,7 @@ def __init__(
years: List[int] | None = None,
exclude_countries: List[str] | None = None,
exclude_years: List[int] | None = None,
require_prev_year_tessera: bool = True,
) -> None:
super().__init__(
data_dir=data_dir,
Expand All @@ -97,11 +101,12 @@ def __init__(
dataset_name="yield_africa",
seed=seed,
cache_dir=cache_dir,
implemented_mod={"coords", "tessera"},
implemented_mod={"coords", "tessera", "tessera_prev"},
mock=mock,
use_features=use_features,
csv_name=csv_name,
)
self.require_prev_year_tessera = require_prev_year_tessera

# Inject year and country one-hot columns as feat_* so that
# get_records() picks them up automatically. Build all new columns in
Expand Down Expand Up @@ -144,6 +149,23 @@ def __init__(

self.df = pd.concat([self.df, pd.DataFrame(new_cols, index=self.df.index)], axis=1)

# Build a cross-year tessera path index from the full unfiltered df.
# Must happen before the country/year filter below so that year-Y records
# can resolve year-Y−1 paths even when those rows are excluded by a
# years= filter. Keys: (lat_rounded, lon_rounded, year) → path.
if "tessera_prev" in self.modalities:
_tessera_dir_full = os.path.join(self.data_dir, "eo", "tessera")
_year_path_index: dict[tuple[float, float, int], str] = {}
_name_loc_coords: dict[str, tuple[float, float, int]] = {}
for _, _r in self.df.iterrows():
_lat_r = round(float(_r["lat"]), 6)
_lon_r = round(float(_r["lon"]), 6)
_year_r = int(_r["year"])
_year_path_index[(_lat_r, _lon_r, _year_r)] = os.path.join(
_tessera_dir_full, f"tessera_{_r['name_loc']}_{_year_r}.npy"
)
_name_loc_coords[_r["name_loc"]] = (_lat_r, _lon_r, _year_r)

# Apply country/year filters to self.df and rebuild records.
# BaseDataset.__init__ has already loaded the CSV; filtering here avoids
# touching BaseDataset and keeps the logic use-case specific.
Expand Down Expand Up @@ -180,6 +202,21 @@ def __init__(
# self.feat_names and self.tabular_dim.
self.records = self.get_records()

# Rewrite tessera paths to the year-suffixed convention
# (tessera_{name_loc}_{year}.npy). BaseDataset.add_modality_paths_to_df()
# generates paths without a year; this override is local to
# YieldAfricaDataset and leaves BaseDataset unchanged.
if "tessera" in self.modalities:
_tessera_dir = os.path.join(self.data_dir, "eo", "tessera")
_name_loc_to_year: dict[str, int] = dict(
zip(self.df["name_loc"], self.df["year"].astype(int))
)
for rec in self.records:
year = _name_loc_to_year[rec["name_loc"]]
rec["tessera_path"] = os.path.join(
_tessera_dir, f"tessera_{rec['name_loc']}_{year}.npy"
)

# Drop records whose TESSERA tile is absent so the model is never
# trained or evaluated on zero-padded stand-ins.
if "tessera" in self.modalities:
Expand All @@ -193,6 +230,38 @@ def __init__(
before,
)

# Resolve tessera_prev_path for each record using the cross-year index.
# Records whose year-1 tile is absent are dropped when
# require_prev_year_tessera=True (default), or retained with
# tessera_prev_path=None when False.
if "tessera_prev" in self.modalities:
resolved = []
for rec in self.records:
lat_r, lon_r, year_r = _name_loc_coords[rec["name_loc"]]
key = (lat_r, lon_r, year_r - 1)
prev_path = _year_path_index.get(key)
if prev_path is not None and os.path.exists(prev_path):
resolved.append({**rec, "tessera_prev_path": prev_path})
else:
# Fall back to synthetic tile produced by --include-prev-year:
# tessera_{name_loc}_prev_{year-1}.npy
synth_path = os.path.join(
_tessera_dir_full,
f"tessera_{rec['name_loc']}_prev_{year_r - 1}.npy",
)
if os.path.exists(synth_path):
resolved.append({**rec, "tessera_prev_path": synth_path})
elif not self.require_prev_year_tessera:
resolved.append({**rec, "tessera_prev_path": None})
dropped = len(self.records) - len(resolved)
if dropped:
log.warning(
"Dropped %d/%d records: no year-1 TESSERA tile found.",
dropped,
len(self.records),
)
self.records = resolved

def setup(self) -> None:
"""Check for requested modality data; warn if TESSERA tiles are absent.

Expand All @@ -204,13 +273,14 @@ def setup(self) -> None:
single fixed year for bulk download, which is incompatible with the
multi-year nature of this dataset.
"""
if "tessera" in self.modalities:
if "tessera" in self.modalities or "tessera_prev" in self.modalities:
tessera_dir = os.path.join(self.data_dir, "eo", "tessera")
if not os.path.exists(tessera_dir) or len(os.listdir(tessera_dir)) == 0:
log.warning(
"TESSERA tiles not found at %s. "
"Run src/data_preprocessing/yield_africa_tessera_preprocess.py "
"to pre-fetch tiles. Records with missing tiles are excluded from the dataset.",
"to pre-fetch tiles. For tessera_prev, also pass --include-prev-year. "
"Records with missing tiles are excluded from the dataset.",
tessera_dir,
)

Expand All @@ -226,6 +296,9 @@ def __getitem__(self, idx: int) -> Dict[str, Any]:
)
elif modality == "tessera":
sample["eo"]["tessera"] = self.load_tessera(row["tessera_path"])
elif modality == "tessera_prev":
if row.get("tessera_prev_path") is not None:
sample["eo"]["tessera_prev"] = self.load_tessera(row["tessera_prev_path"])

if self.use_features and self.feat_names:
sample["eo"]["tabular"] = torch.tensor(
Expand Down
Loading
Loading