From d993e873055f3944e3526a4556b1ff7679861d1a Mon Sep 17 00:00:00 2001 From: GabrieleTi Date: Wed, 25 Mar 2026 21:43:13 +0100 Subject: [PATCH 01/78] Log dynamic k per concept --- src/models/text_alignment_model.py | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/src/models/text_alignment_model.py b/src/models/text_alignment_model.py index 49b8a91..ad94fd5 100644 --- a/src/models/text_alignment_model.py +++ b/src/models/text_alignment_model.py @@ -117,6 +117,10 @@ def setup_encoders_adapters(self): def setup_retrieval_evaluation(self): self.concept_configs = self.trainer.datamodule.concept_configs self.concepts = [c["concept_caption"] for c in self.concept_configs] + self.concept_names = [ + f"{c['col'].replace('aux_', '')}_{'max' if c['is_max'] else 'min'}" + for c in self.concept_configs + ] self.contrastive_val = RetrievalContrastiveValidation(self.ks, self.concept_configs) self.outputs_epoch_memory = [] @@ -214,6 +218,8 @@ def _on_epoch_end(self, mode: str): for i, result in concept_scores.items(): print(f'\nConcept "{self.concepts[i]}" average top-k accuracies in {mode} split:') for k, v in result.items(): + if k == "dynamic_k": + self.log(f"dyn_k_{self.self.concept_names[i]}", v, **self.log_kwargs) print(f"Top-{k}: {v:.1f}%") avr_scores[f"{mode}_avr_top-{k}"].append(v) From 3122086c99d45872f9ece98397d2007f27c5f714 Mon Sep 17 00:00:00 2001 From: GabrieleTi Date: Wed, 25 Mar 2026 21:43:23 +0100 Subject: [PATCH 02/78] valuation configs --- configs/eval.yaml | 37 +++++++++++++++++++++++++++++++------ 1 file changed, 31 insertions(+), 6 deletions(-) diff --git a/configs/eval.yaml b/configs/eval.yaml index be31299..80436cd 100644 --- a/configs/eval.yaml +++ b/configs/eval.yaml @@ -1,18 +1,43 @@ # @package _global_ +# specify here default configuration +# order of defaults determines the order in which configs override each other defaults: - _self_ - - data: mnist # choose datamodule with `test_dataloader()` for evaluation - - model: mnist - - logger: null - - trainer: default - - paths: default + - data: butterfly_coords + - model: predictive_geoclip + - callbacks: default + - logger: ${oc.env:LOGGER,wandb} + - trainer: ${oc.env:TRAINER_PROFILE,default} + - paths: ${oc.env:STORAGE_MODE,local} - extras: default - hydra: default + - metrics: butterfly_predictive + # experiment configs allow for version control of specific hyperparameters + # e.g. best hyperparameters for given model and datamodule + - experiment: null + + # config for hyperparameter optimization + - hparams_search: null + + # optional local config for machine/user specific settings + # it's optional since it doesn't need to exist and is excluded from version control + - optional local: default + + # debugging config (enable through command line, e.g. `python train.py debug=default) + - debug: null + +# task name, determines output directory path task_name: "eval" +# tags to help you identify your experiments +# you can overwrite this in experiment configs +# overwrite from command line with `python train.py tags="[first_tag, second_tag]"` tags: ["dev"] +# seed for random number generators in pytorch, numpy and python.random +seed: 12345 + # passing checkpoint path is necessary for evaluation -ckpt_path: ??? +ckpt_path: aether/logs/train/runs/2026-03-15_13-15-08/checkpoints/epoch_098.ckpt From a9b443eb4cc54fba9a49663b58d8b5e618a3e38e Mon Sep 17 00:00:00 2001 From: GabrieleTi Date: Wed, 25 Mar 2026 21:44:25 +0100 Subject: [PATCH 03/78] alignment poster inspection code --- notebooks/08-GT-aligment-visualisation.ipynb | 330 +++++++++++++++++++ 1 file changed, 330 insertions(+) create mode 100644 notebooks/08-GT-aligment-visualisation.ipynb diff --git a/notebooks/08-GT-aligment-visualisation.ipynb b/notebooks/08-GT-aligment-visualisation.ipynb new file mode 100644 index 0000000..f8b4e9d --- /dev/null +++ b/notebooks/08-GT-aligment-visualisation.ipynb @@ -0,0 +1,330 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": null, + "id": "0", + "metadata": {}, + "outputs": [], + "source": [ + "import os\n", + "from typing import Any, Dict, List, Tuple\n", + "\n", + "import hydra\n", + "import torch\n", + "from lightning import LightningDataModule, LightningModule, Trainer\n", + "from lightning.pytorch.loggers import Logger\n", + "\n", + "if os.environ.get(\"TOKENIZERS_PARALLELISM\") is None:\n", + " os.environ[\"TOKENIZERS_PARALLELISM\"] = \"false\"\n", + "\n", + "from src.utils import (\n", + " instantiate_loggers,\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "1", + "metadata": {}, + "outputs": [], + "source": [ + "os.chdir(\"..\")\n", + "os.getcwd()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "2", + "metadata": {}, + "outputs": [], + "source": [ + "from hydra import compose, initialize\n", + "from hydra.core.hydra_config import HydraConfig\n", + "\n", + "with initialize(version_base=\"1.3\", config_path=\"../configs\"):\n", + " cfg = compose(\n", + " config_name=\"eval.yaml\",\n", + " return_hydra_config=True,\n", + " overrides=[\"experiment=alignment_v1\"], # 👈 loads configs/experiment/my_exp.yaml\n", + " )\n", + "\n", + "HydraConfig.instance().set_config(cfg)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "3", + "metadata": {}, + "outputs": [], + "source": [ + "assert cfg.ckpt_path" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "4", + "metadata": {}, + "outputs": [], + "source": [ + "datamodule: LightningDataModule = hydra.utils.instantiate(cfg.data)\n", + "\n", + "model: LightningModule = hydra.utils.instantiate(cfg.model)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "5", + "metadata": {}, + "outputs": [], + "source": [ + "model" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "6", + "metadata": {}, + "outputs": [], + "source": [ + "logger: List[Logger] = instantiate_loggers(cfg.get(\"logger\"))\n", + "\n", + "cfg[\"trainer\"][\"max_epochs\"] = 0\n", + "cfg[\"trainer\"][\"accelerator\"] = \"mps\"\n", + "trainer: Trainer = hydra.utils.instantiate(cfg.trainer, logger=logger)\n", + "\n", + "object_dict = {\n", + " \"cfg\": cfg,\n", + " \"datamodule\": datamodule,\n", + " \"model\": model,\n", + " \"logger\": logger,\n", + " \"trainer\": trainer,\n", + "}" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "7", + "metadata": {}, + "outputs": [], + "source": [ + "trainer.strategy.connect(model)\n", + "trainer._data_connector.attach_data(model, datamodule=datamodule)\n", + "\n", + "model.setup(stage=\"fit\")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "8", + "metadata": {}, + "outputs": [], + "source": [ + "path = \"logs/train/runs/2026-03-15_13-15-08/checkpoints/epoch_098.ckpt\"\n", + "ckpt = torch.load(path, map_location=\"cpu\", weights_only=False)\n", + "model.load_state_dict(ckpt[\"state_dict\"])" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "9", + "metadata": {}, + "outputs": [], + "source": [ + "data = datamodule.test_dataloader().dataset[2]\n", + "data" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "10", + "metadata": {}, + "outputs": [], + "source": [ + "import matplotlib.pyplot as plt\n", + "import numpy as np\n", + "\n", + "\n", + "def show_rgb(tensor, stretch=True):\n", + " \"\"\"\n", + " tensor: numpy array or torch tensor of shape (C, H, W)\n", + " bands: tuple of 3 indices (R, G, B)\n", + " stretch: whether to normalize values for display\n", + " \"\"\"\n", + " # Convert to numpy if it's a torch tensor\n", + " if \"torch\" in str(type(tensor)):\n", + " tensor = tensor.detach().cpu().numpy()\n", + "\n", + " # Select bands\n", + " rgb = tensor[:3, :, :]\n", + "\n", + " # Move to (H, W, 3)\n", + " rgb = np.transpose(rgb, (1, 2, 0))\n", + "\n", + " # Normalize for visualization\n", + " if stretch:\n", + " rgb = rgb.astype(np.float32)\n", + " for i in range(3):\n", + " band = rgb[:, :, i]\n", + " band_min, band_max = band.min(), band.max()\n", + " if band_max > band_min:\n", + " rgb[:, :, i] = (band - band_min) / (band_max - band_min)\n", + "\n", + " plt.imshow(rgb)\n", + " plt.axis(\"off\")\n", + " plt.show()\n", + "\n", + "\n", + "show_rgb(data[\"eo\"][\"aef\"])" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "11", + "metadata": {}, + "outputs": [], + "source": [ + "batch = data\n", + "batch[\"eo\"][\"aef\"] = batch[\"eo\"][\"aef\"].unsqueeze(0)\n", + "batch[\"eo\"][\"aef\"].shape" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "12", + "metadata": {}, + "outputs": [], + "source": [ + "geo_feats = model.geo_encoder(data).squeeze(1)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "13", + "metadata": {}, + "outputs": [], + "source": [ + "concepts = [\n", + " \"Densely populated area with many houses\",\n", + " \"Very sparsely populated area with few houses\",\n", + " \"Arable land with crops for agriculture\",\n", + " \"Pasture fields with grass for grazing animals\",\n", + " \"Forested area with many trees\",\n", + " \"Water bodies such as lakes, rivers and sea\",\n", + "]" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "14", + "metadata": {}, + "outputs": [], + "source": [ + "text_tokens = model.text_encoder.processor(\n", + " text=concepts,\n", + " return_tensors=\"pt\",\n", + " padding=True,\n", + " truncation=True,\n", + " max_length=77,\n", + ")\n", + "\n", + "device = model.text_encoder.device\n", + "text_tokens = {k: v.to(device) for k, v in text_tokens.items()}\n", + "\n", + "text_embeds = model.text_encoder.model.get_text_features(**text_tokens)\n", + "\n", + "# Project\n", + "if model.text_encoder.projector is not None:\n", + " text_embeds = model.text_encoder.projector(text_embeds)\n", + "\n", + "if model.text_encoder.extra_projector is not None:\n", + " text_embeds = model.text_encoder.extra_projector(text_embeds)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "15", + "metadata": {}, + "outputs": [], + "source": [ + "text_embeds.shape" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "16", + "metadata": {}, + "outputs": [], + "source": [ + "import torch\n", + "import torch.nn.functional as F\n", + "\n", + "\n", + "def cosine_scores(x, Y):\n", + " \"\"\"\n", + " x: tensor of shape (1, 512)\n", + " Y: tensor of shape (6, 512)\n", + " returns: tensor of shape (6,)\n", + " \"\"\"\n", + " # Normalize\n", + " x = F.normalize(x, dim=1) # (1, 512)\n", + " Y = F.normalize(Y, dim=1) # (6, 512)\n", + "\n", + " # Compute cosine similarity\n", + " scores = torch.matmul(Y, x.T) # (6, 1)\n", + "\n", + " return scores.squeeze(1)\n", + "\n", + "\n", + "cosine_scores(geo_feats, text_embeds.cpu())" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "17", + "metadata": {}, + "outputs": [], + "source": [ + "scores = {\n", + " k: round(v.detach().item(), 2)\n", + " for k, v in zip(concepts, cosine_scores(geo_feats, text_embeds.cpu()))\n", + "}\n", + "scores" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "18", + "metadata": {}, + "outputs": [], + "source": [] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3 (ipykernel)", + "language": "python", + "name": "python3" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} From 9d322776ecf2d6bc1da3e450319d4992a67e2b6f Mon Sep 17 00:00:00 2001 From: GabrieleTi Date: Tue, 31 Mar 2026 13:31:40 +0200 Subject: [PATCH 04/78] fix self.self --- src/models/text_alignment_model.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/models/text_alignment_model.py b/src/models/text_alignment_model.py index ad94fd5..2cc1048 100644 --- a/src/models/text_alignment_model.py +++ b/src/models/text_alignment_model.py @@ -219,7 +219,7 @@ def _on_epoch_end(self, mode: str): print(f'\nConcept "{self.concepts[i]}" average top-k accuracies in {mode} split:') for k, v in result.items(): if k == "dynamic_k": - self.log(f"dyn_k_{self.self.concept_names[i]}", v, **self.log_kwargs) + self.log(f"dyn_k_{self.concept_names[i]}", v, **self.log_kwargs) print(f"Top-{k}: {v:.1f}%") avr_scores[f"{mode}_avr_top-{k}"].append(v) From d2b4398e0db2e4de9eda6e530626c0c74c6fbf50 Mon Sep 17 00:00:00 2001 From: GabrieleTi Date: Tue, 31 Mar 2026 13:32:09 +0200 Subject: [PATCH 05/78] Remove redundant val_val in logging the metric names --- src/models/components/metrics/metrics_wrapper.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/models/components/metrics/metrics_wrapper.py b/src/models/components/metrics/metrics_wrapper.py index 33d86de..0d40212 100644 --- a/src/models/components/metrics/metrics_wrapper.py +++ b/src/models/components/metrics/metrics_wrapper.py @@ -19,6 +19,6 @@ def forward(self, mode="train", **kwargs) -> Dict[str, torch.float]: for metric in self.metrics: metric_results = metric(mode=mode, return_label=True, **kwargs) for k, v in metric_results.items(): - compiled_dict[f"{mode}_{k}"] = v + compiled_dict[k] = v return compiled_dict From beccf79052dc86d2ccdcb6b716b14bb7efb020e4 Mon Sep 17 00:00:00 2001 From: Thijs van der Plas Date: Tue, 31 Mar 2026 22:00:16 +0200 Subject: [PATCH 06/78] Index dynamic top k -- not tested yet --- src/models/text_alignment_model.py | 27 +++++++++++++++++++++++++++ 1 file changed, 27 insertions(+) diff --git a/src/models/text_alignment_model.py b/src/models/text_alignment_model.py index ad94fd5..3c95fbc 100644 --- a/src/models/text_alignment_model.py +++ b/src/models/text_alignment_model.py @@ -122,6 +122,29 @@ def setup_retrieval_evaluation(self): for c in self.concept_configs ] + dataset_names = ["train", "val", "test"] + self.dynamic_k_baselines = {} + for dataset_name in dataset_names: + if not hasattr(self.trainer.datamodule, f"{dataset_name}_dataloader"): + continue + tmp_ds = getattr(self.trainer.datamodule, f"{dataset_name}_dataloader")().dataset + n_ds = len(tmp_ds) + self.dynamic_k_baselines[dataset_name] = {} + for i_c, c in enumerate(self.concept_configs): + c_name = self.concept_names[i_c] + aux_col_id = c["id"] + if c["is_max"]: + n_baseline = sum( + tmp_ds[ii]["aux"][aux_col_id] >= c.get("theta_k", float("inf")) + for ii in range(len(tmp_ds)) + ) + else: + n_baseline = sum( + tmp_ds[ii]["aux"][aux_col_id] <= c.get("theta_k", float("inf")) + for ii in range(len(tmp_ds)) + ) + self.dynamic_k_baselines[dataset_name][c_name] = n_baseline / n_ds * 100 + self.contrastive_val = RetrievalContrastiveValidation(self.ks, self.concept_configs) self.outputs_epoch_memory = [] @@ -220,6 +243,10 @@ def _on_epoch_end(self, mode: str): for k, v in result.items(): if k == "dynamic_k": self.log(f"dyn_k_{self.self.concept_names[i]}", v, **self.log_kwargs) + indexed_v = (v - self.dynamic_k_baselines[mode][self.concept_names[i]]) / ( + 100 - self.dynamic_k_baselines[mode][self.concept_names[i]] + ) + self.log(f"dyn_k_index_{self.concept_names[i]}", indexed_v, **self.log_kwargs) print(f"Top-{k}: {v:.1f}%") avr_scores[f"{mode}_avr_top-{k}"].append(v) From 23e8b28c783b4548d52ffdbc992658059231e94a Mon Sep 17 00:00:00 2001 From: Thijs van der Plas Date: Wed, 1 Apr 2026 08:47:54 +0200 Subject: [PATCH 07/78] Elbow method for theta k -- not tested yet --- src/models/text_alignment_model.py | 33 ++++++++++++++++++++++-------- 1 file changed, 25 insertions(+), 8 deletions(-) diff --git a/src/models/text_alignment_model.py b/src/models/text_alignment_model.py index 3c95fbc..5d20d94 100644 --- a/src/models/text_alignment_model.py +++ b/src/models/text_alignment_model.py @@ -1,6 +1,7 @@ from io import text_encoding from typing import Dict, Tuple, override +import numpy as np import torch import torch.nn.functional as F @@ -133,16 +134,13 @@ def setup_retrieval_evaluation(self): for i_c, c in enumerate(self.concept_configs): c_name = self.concept_names[i_c] aux_col_id = c["id"] + aux_vals_current_ds = [tmp_ds[ii]["aux"][aux_col_id] for ii in range(len(tmp_ds))] + # theta_k = c['theta_k'] + theta_k = self.find_elbow_point(aux_vals_current_ds) if c["is_max"]: - n_baseline = sum( - tmp_ds[ii]["aux"][aux_col_id] >= c.get("theta_k", float("inf")) - for ii in range(len(tmp_ds)) - ) + n_baseline = sum(aux_val >= theta_k for aux_val in aux_vals_current_ds) else: - n_baseline = sum( - tmp_ds[ii]["aux"][aux_col_id] <= c.get("theta_k", float("inf")) - for ii in range(len(tmp_ds)) - ) + n_baseline = sum(aux_val <= theta_k for aux_val in aux_vals_current_ds) self.dynamic_k_baselines[dataset_name][c_name] = n_baseline / n_ds * 100 self.contrastive_val = RetrievalContrastiveValidation(self.ks, self.concept_configs) @@ -287,3 +285,22 @@ def concept_similarities(self, geo_embeds, concept=None) -> torch.Tensor: similarity_matrix = concept_embeds @ geo_embeds.T return similarity_matrix + + def find_elbow_point(vals): + vals = np.sort(vals) + x = np.arange(len(vals)) / len(vals) + y = vals + slope = (y[-1] - y[0]) / (x[-1] - x[0]) # diagonal from first to last point + intercept = y[0] - slope * x[0] + orthogonal_slope = -1 / slope + + intercepts_orthogonal = y - orthogonal_slope * x + intersection_diagonal_orthogonal = (intercepts_orthogonal - intercept) / ( + slope - orthogonal_slope + ) + distances = np.sqrt( + (x - intersection_diagonal_orthogonal) ** 2 + (y - (slope * x + intercept)) ** 2 + ) # distance to diagonal + elbow_index = np.argmax(distances) + elbow_point = y[elbow_index] + return elbow_point From ce5232c914ca8545a8ef550247e940d4cd9b523a Mon Sep 17 00:00:00 2001 From: Thijs van der Plas Date: Wed, 1 Apr 2026 18:48:54 +0200 Subject: [PATCH 08/78] update theta_k with calculated --- src/models/text_alignment_model.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/src/models/text_alignment_model.py b/src/models/text_alignment_model.py index 5d20d94..074e7d6 100644 --- a/src/models/text_alignment_model.py +++ b/src/models/text_alignment_model.py @@ -137,6 +137,9 @@ def setup_retrieval_evaluation(self): aux_vals_current_ds = [tmp_ds[ii]["aux"][aux_col_id] for ii in range(len(tmp_ds))] # theta_k = c['theta_k'] theta_k = self.find_elbow_point(aux_vals_current_ds) + self.concept_configs[i_c][ + "theta_k" + ] = theta_k # assign new theta_k to concept_configs for later use in validation if c["is_max"]: n_baseline = sum(aux_val >= theta_k for aux_val in aux_vals_current_ds) else: From 150f4f65059626dc2c9e4896fe2727fca1d9c1ff Mon Sep 17 00:00:00 2001 From: GabrieleTi Date: Thu, 2 Apr 2026 13:21:58 +0200 Subject: [PATCH 09/78] Add missing pass through the extra projector --- src/models/components/geo_encoders/average_encoder.py | 2 ++ src/models/components/geo_encoders/cnn_encoder.py | 3 +++ src/models/components/geo_encoders/geoclip.py | 2 ++ src/models/components/geo_encoders/mlp_projector.py | 5 ++++- src/models/components/geo_encoders/tabular_encoder.py | 3 +++ 5 files changed, 14 insertions(+), 1 deletion(-) diff --git a/src/models/components/geo_encoders/average_encoder.py b/src/models/components/geo_encoders/average_encoder.py index 0c0eaf6..b11e3ce 100644 --- a/src/models/components/geo_encoders/average_encoder.py +++ b/src/models/components/geo_encoders/average_encoder.py @@ -43,4 +43,6 @@ def forward(self, batch: Dict[str, torch.Tensor]) -> torch.Tensor: """Data forward pass through the encoder.""" tile = batch.get("eo", {}).get(self.geo_data_name) feats = self.geo_encoder(tile.mean(dim=(-2, -1))) + if self.extra_projector: + feats = self.extra_projector(feats) return feats diff --git a/src/models/components/geo_encoders/cnn_encoder.py b/src/models/components/geo_encoders/cnn_encoder.py index 34f8f48..6ec3609 100644 --- a/src/models/components/geo_encoders/cnn_encoder.py +++ b/src/models/components/geo_encoders/cnn_encoder.py @@ -159,6 +159,9 @@ def forward( # n_nans == 0 # ), f"CNNEncoder output contains {n_nans}/{feats.numel()} NaNs PRIOR to normalization with data min {eo_data[self.geo_data_name].min()} and max {eo_data[self.geo_data_name].max()}." + if self.extra_projector: + feats = self.extra_projector(feats) + return feats.to(dtype) diff --git a/src/models/components/geo_encoders/geoclip.py b/src/models/components/geo_encoders/geoclip.py index 38dffdb..cd246ef 100644 --- a/src/models/components/geo_encoders/geoclip.py +++ b/src/models/components/geo_encoders/geoclip.py @@ -39,6 +39,8 @@ def forward( if coords.dtype != dtype: coords = coords.to(dtype) feats = self.geo_encoder(coords) + if self.extra_projector: + feats = self.extra_projector(feats) return feats.to(dtype) diff --git a/src/models/components/geo_encoders/mlp_projector.py b/src/models/components/geo_encoders/mlp_projector.py index e622216..6c27345 100644 --- a/src/models/components/geo_encoders/mlp_projector.py +++ b/src/models/components/geo_encoders/mlp_projector.py @@ -49,4 +49,7 @@ def configure_nn(self) -> None: self.net = nn.Sequential(*layers) def forward(self, x: torch.Tensor) -> torch.Tensor: - return self.net(x) + feats = self.net(x) + if self.extra_projector: + feats = self.extra_projector(feats) + return feats diff --git a/src/models/components/geo_encoders/tabular_encoder.py b/src/models/components/geo_encoders/tabular_encoder.py index 1ae4b1d..314cf29 100644 --- a/src/models/components/geo_encoders/tabular_encoder.py +++ b/src/models/components/geo_encoders/tabular_encoder.py @@ -68,4 +68,7 @@ def forward(self, batch: Dict[str, torch.Tensor]) -> torch.Tensor: tab_data = tab_data.to(dtype) feats = self.geo_encoder(tab_data) + if self.extra_projector: + feats = self.extra_projector(feats) + return feats.to(dtype) From 83ef95d9ea03a4901c2cf8d470d6613ec9e69318 Mon Sep 17 00:00:00 2001 From: GabrieleTi Date: Thu, 2 Apr 2026 13:23:09 +0200 Subject: [PATCH 10/78] Add new type of error for missing file specification --- src/utils/errors.py | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/src/utils/errors.py b/src/utils/errors.py index 910dea5..aa88ce0 100644 --- a/src/utils/errors.py +++ b/src/utils/errors.py @@ -1,2 +1,10 @@ class IllegalArgumentCombination(ValueError): + """Error for illogical argument configuration.""" + + pass + + +class FileNotSpecified(ValueError): + """Error for missing file path specification.""" + pass From de7da4742819e8de8ade36732a8e0f0f89832913 Mon Sep 17 00:00:00 2001 From: GabrieleTi Date: Thu, 2 Apr 2026 13:23:40 +0200 Subject: [PATCH 11/78] Add missing pass through the extra projector --- src/models/components/geo_encoders/encoder_wrapper.py | 10 ++++++++-- 1 file changed, 8 insertions(+), 2 deletions(-) diff --git a/src/models/components/geo_encoders/encoder_wrapper.py b/src/models/components/geo_encoders/encoder_wrapper.py index 9cdf8ad..0523619 100644 --- a/src/models/components/geo_encoders/encoder_wrapper.py +++ b/src/models/components/geo_encoders/encoder_wrapper.py @@ -122,8 +122,14 @@ def forward(self, batch: Dict[str, torch.Tensor]) -> torch.Tensor: branch_feats.append(feats) if self.fusion_strategy == "concat": - return torch.cat(branch_feats, dim=1) - return torch.mean(branch_feats, dim=1) + feats = torch.cat(branch_feats, dim=1) + if self.extra_projector: + feats = self.extra_projector(feats) + else: + feats = torch.cat(branch_feats, dim=1) + if self.extra_projector: + feats = self.extra_projector(feats) + return feats @property def device(self): From 8c4b2ab43b55d026cfcc398e51ac78643f075663 Mon Sep 17 00:00:00 2001 From: GabrieleTi Date: Thu, 2 Apr 2026 15:25:42 +0200 Subject: [PATCH 12/78] fix if self.extra_projector missing --- src/models/components/geo_encoders/base_geo_encoder.py | 1 + 1 file changed, 1 insertion(+) diff --git a/src/models/components/geo_encoders/base_geo_encoder.py b/src/models/components/geo_encoders/base_geo_encoder.py index f79f3f4..6b65c75 100644 --- a/src/models/components/geo_encoders/base_geo_encoder.py +++ b/src/models/components/geo_encoders/base_geo_encoder.py @@ -14,6 +14,7 @@ def __init__(self) -> None: # placeholders self.allowed_geo_data_names: list[str] | None = None self.geo_data_name: str | None = None + self.extra_projector: nn.Module | None = None @abstractmethod def forward(self, batch: Dict[str, torch.Tensor]) -> torch.Tensor: From 047097ce6d82550eb0e854c8085a1a69a67a856f Mon Sep 17 00:00:00 2001 From: GabrieleTi Date: Tue, 7 Apr 2026 13:09:43 +0200 Subject: [PATCH 13/78] Formating hook changes --- .../yield_africa_spatial_splits.py | 13 +++++-------- src/models/components/loss_fns/rrmse_loss.py | 12 +++++------- 2 files changed, 10 insertions(+), 15 deletions(-) diff --git a/src/data_preprocessing/yield_africa_spatial_splits.py b/src/data_preprocessing/yield_africa_spatial_splits.py index 4c9aa98..d71b15b 100644 --- a/src/data_preprocessing/yield_africa_spatial_splits.py +++ b/src/data_preprocessing/yield_africa_spatial_splits.py @@ -79,13 +79,12 @@ def make_spatial_split( """Return a split-indices dict using DBSCAN spatial clustering. :param df: full model-ready dataframe (must contain 'lat', 'lon', 'name_loc') - :param distance_m: DBSCAN eps in metres — pairs of fields closer than this - value are assigned to the same cluster + :param distance_m: DBSCAN eps in metres — pairs of fields closer than this value are assigned + to the same cluster :param train_val_test_split: (train, val, test) proportions, must sum to 1.0 :param seed: random seed for GroupShuffleSplit - :return: dict with 'train_indices', 'val_indices', 'test_indices' as - pd.Series of name_loc strings, plus 'clusters' as a numpy array of - cluster labels (same length as df) + :return: dict with 'train_indices', 'val_indices', 'test_indices' as pd.Series of name_loc + strings, plus 'clusters' as a numpy array of cluster labels (same length as df) """ # Deduplicate to unique (lat, lon) locations before clustering. # yield_africa has ~9 rows per location (one per year); running DBSCAN on all @@ -271,9 +270,7 @@ def generate_splits( f"(train={n_train}, val={n_val}, test={n_test}, " f"total={n_train + n_val + n_test}/{len(df)})" ) - log.info( - f" {dist_km}km: train={n_train}, val={n_val}, test={n_test} -> {out_name}" - ) + log.info(f" {dist_km}km: train={n_train}, val={n_val}, test={n_test} -> {out_name}") def main() -> None: diff --git a/src/models/components/loss_fns/rrmse_loss.py b/src/models/components/loss_fns/rrmse_loss.py index 1720f1b..8c62f7e 100644 --- a/src/models/components/loss_fns/rrmse_loss.py +++ b/src/models/components/loss_fns/rrmse_loss.py @@ -8,15 +8,13 @@ class RRMSELoss(BaseLossFn): """Relative Root Mean Squared Error (RRMSE). - RRMSE = RMSE / mean(|labels|) + RRMSE = RMSE / mean(labels) - Normalises RMSE by the mean absolute value of the target, giving a - unit-free percentage error. This makes results comparable across crops - and regions with different absolute yield scales (e.g. t/ha ranges - differ significantly between maize in Zambia and rice in Rwanda). + Normalises RMSE by the mean absolute value of the target, giving a unit-free percentage error. + This makes results comparable across crops and regions with different absolute yield scales + (e.g. t/ha ranges differ significantly between maize in Zambia and rice in Rwanda). - Returns a fraction (e.g. 0.15 = 15 % error). Multiply by 100 for - percentage when reporting. + Returns a fraction (e.g. 0.15 = 15 % error). Multiply by 100 for percentage when reporting. """ def __init__(self) -> None: From fc832d44195397009455a3353da278dc65f2731a Mon Sep 17 00:00:00 2001 From: GabrieleTi Date: Tue, 7 Apr 2026 13:09:57 +0200 Subject: [PATCH 14/78] Fix hugging face dir --- configs/paths/shared.yaml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/configs/paths/shared.yaml b/configs/paths/shared.yaml index a4702b0..44218a8 100644 --- a/configs/paths/shared.yaml +++ b/configs/paths/shared.yaml @@ -21,4 +21,4 @@ work_dir: ${hydra:runtime.cwd} # huggingface cache directory # can be overridden via HF_HOME environment variable -huggingface_cache: ${oc.env:HF_HOME,oc.env:SHARED_CACHE/huggingface} +huggingface_cache: ${oc.env:HF_HOME,${oc.env:SHARED_CACHE}/huggingface} From 10030a755514045f43b334ac1f41d236535fc4df Mon Sep 17 00:00:00 2001 From: GabrieleTi Date: Tue, 7 Apr 2026 13:11:28 +0200 Subject: [PATCH 15/78] Prediction head initialisation print statements --- src/models/components/geo_encoders/mlp_projector.py | 2 +- src/models/components/pred_heads/linear_pred_head.py | 1 + src/models/components/pred_heads/mlp_pred_head.py | 1 + src/models/components/pred_heads/mlp_regression_head.py | 2 ++ 4 files changed, 5 insertions(+), 1 deletion(-) diff --git a/src/models/components/geo_encoders/mlp_projector.py b/src/models/components/geo_encoders/mlp_projector.py index 6c27345..4ab6e05 100644 --- a/src/models/components/geo_encoders/mlp_projector.py +++ b/src/models/components/geo_encoders/mlp_projector.py @@ -27,7 +27,7 @@ def __init__( @override def setup(self) -> List[str]: self.configure_nn() - print("Model setup with MLP projector") + print("Model set up with MLP projector") return ["net"] def set_input_dim(self, input_dim: int) -> None: diff --git a/src/models/components/pred_heads/linear_pred_head.py b/src/models/components/pred_heads/linear_pred_head.py index 94338a7..1dfa094 100644 --- a/src/models/components/pred_heads/linear_pred_head.py +++ b/src/models/components/pred_heads/linear_pred_head.py @@ -37,6 +37,7 @@ def setup(self) -> None: assert type(self.input_dim) is int, self.input_dim assert type(self.output_dim) is int, self.output_dim self.net = nn.Linear(self.input_dim, self.output_dim) + print("Model set up with linear prediction head") return diff --git a/src/models/components/pred_heads/mlp_pred_head.py b/src/models/components/pred_heads/mlp_pred_head.py index 6d4c124..65610a5 100644 --- a/src/models/components/pred_heads/mlp_pred_head.py +++ b/src/models/components/pred_heads/mlp_pred_head.py @@ -50,6 +50,7 @@ def setup(self) -> None: input_dim = self.hidden_dim layers.append(nn.Linear(input_dim, self.output_dim)) self.net = nn.Sequential(*layers) + print("Model set up with MLP prediction head") return diff --git a/src/models/components/pred_heads/mlp_regression_head.py b/src/models/components/pred_heads/mlp_regression_head.py index d9553ec..9679fa2 100644 --- a/src/models/components/pred_heads/mlp_regression_head.py +++ b/src/models/components/pred_heads/mlp_regression_head.py @@ -70,3 +70,5 @@ def setup(self) -> None: layers.append(nn.Linear(in_dim, self.output_dim)) self.net = nn.Sequential(*layers) + print("Model set up with MLP regression prediction head") + return From d4bbc9be105caf389fa9875aabc86702c04ce4f1 Mon Sep 17 00:00:00 2001 From: GabrieleTi Date: Tue, 7 Apr 2026 13:11:45 +0200 Subject: [PATCH 16/78] Add full freezer --- src/models/base_model.py | 12 ++++++++++++ 1 file changed, 12 insertions(+) diff --git a/src/models/base_model.py b/src/models/base_model.py index 89fbaf6..7cb7296 100644 --- a/src/models/base_model.py +++ b/src/models/base_model.py @@ -42,6 +42,18 @@ def setup(self, stage: str) -> None: called after trainer is initialized and datamodule is available.""" pass + @final + def full_freezer(self): + """Freeze the whole network.""" + # Freeze the whole network + for name, param in self.named_parameters(): + param.requires_grad = False + + for name, module in self.named_modules(): + module.eval() + + return + @final def freezer(self) -> None: """Freezes modules based on provided trainable modules.""" From 6c2e38d2d9f2c5a996db1f91a08acc1a3765659a Mon Sep 17 00:00:00 2001 From: GabrieleTi Date: Tue, 7 Apr 2026 13:12:53 +0200 Subject: [PATCH 17/78] Add configuration saving into ckpg This is needed for the inference model re-configuration --- src/models/base_model.py | 37 +++++++++++++++++++---- src/models/predictive_model.py | 26 ++++++++++++---- src/models/text_alignment_model.py | 48 +++++++++++++++--------------- 3 files changed, 75 insertions(+), 36 deletions(-) diff --git a/src/models/base_model.py b/src/models/base_model.py index 7cb7296..1a6a2b5 100644 --- a/src/models/base_model.py +++ b/src/models/base_model.py @@ -16,6 +16,8 @@ def __init__( scheduler: torch.optim.lr_scheduler, loss_fn: BaseLossFn, metrics: MetricsWrapper, + num_classes: int | None = None, + tabular_dim: int | None = None, ) -> None: """Interface for any model. @@ -24,17 +26,29 @@ def __init__( :param scheduler: scheduler for the model weight update :param loss_fn: loss function :param metrics: metrics to track for model performance estimation + :param num_classes: number of target classes + :param tabular_dim: number of tabular features """ super().__init__() self.save_hyperparameters( - ignore=["loss_fn", "geo_encoder", "prediction_head", "text_encoder", "metrics"] + ignore=[ + "loss_fn", + "geo_encoder", + "prediction_head", + "text_encoder", + "metrics", + "optimizer", + "scheduler", + ] ) self.trainable_modules = trainable_modules - self.num_classes: int | None = None - self.tabular_dim: int | None = None + self.num_classes = num_classes + self.tabular_dim = tabular_dim self.loss_fn = loss_fn self.metrics = metrics + self.optimizer = optimizer + self.scheduler = scheduler @abstractmethod def setup(self, stage: str) -> None: @@ -139,10 +153,10 @@ def test_step(self, batch: Dict[str, torch.Tensor], batch_idx: int) -> torch.Ten def configure_optimizers(self) -> Dict[str, Any]: """Configure optimizer and learning rate scheduler.""" - optimizer = self.hparams.optimizer(params=self.trainer.model.parameters()) + optimizer = self.optimizer(params=self.trainer.model.parameters()) - if self.hparams.scheduler is not None: - scheduler = self.hparams.scheduler(optimizer=optimizer) + if self.scheduler is not None: + scheduler = self.scheduler(optimizer=optimizer) return { "optimizer": optimizer, "lr_scheduler": { @@ -162,6 +176,14 @@ def on_save_checkpoint(self, checkpoint): if any(k.startswith(part) for part in self.trainable_modules) } + checkpoint["hyper_parameters"].update( + { + "num_classes": self.num_classes, + "tabular_dim": self.tabular_dim, + "trainable_modules": self.trainable_modules, + } + ) + def on_load_checkpoint(self, checkpoint): """Load only trainable parts of the model.""" missing_keys, unexpected_keys = self.load_state_dict( @@ -169,6 +191,9 @@ def on_load_checkpoint(self, checkpoint): ) print("Model loaded from a checkpoint.") + # self.tabular_dim = checkpoint['hyper_parameters']["tabular_dim"] + # self.num_classes = checkpoint["hyper_parameters"]["num_classes"] + if missing_keys: missing_keys = {".".join(i.split(".")[:3]) for i in missing_keys} print(f"The following keys are missing from the pretrained model: {missing_keys}") diff --git a/src/models/predictive_model.py b/src/models/predictive_model.py index 7b951c2..6e5cb90 100644 --- a/src/models/predictive_model.py +++ b/src/models/predictive_model.py @@ -25,6 +25,8 @@ def __init__( loss_fn: BaseLossFn, metrics: MetricsWrapper, normalize_features: bool = True, + num_classes: int | None = None, + tabular_dim: int | None = None, ) -> None: """Implementation of the predictive model with replaceable GEO encoder, and prediction head. @@ -36,13 +38,15 @@ def __init__( :param scheduler: scheduler to use for training :param loss_fn: loss function to use :param metrics: metrics to use for model performance evaluation - :param num_classes: number of target classes - :param tabular_dim: number of tabular features :param normalize_features: if True, apply L2 normalisation to encoder output before the prediction head (default: True) + :param num_classes: number of target classes + :param tabular_dim: number of tabular features """ - super().__init__(trainable_modules, optimizer, scheduler, loss_fn, metrics) + super().__init__( + trainable_modules, optimizer, scheduler, loss_fn, metrics, num_classes, tabular_dim + ) # Geo encoder configuration self.geo_encoder = geo_encoder @@ -55,8 +59,15 @@ def __init__( @override def setup(self, stage: str) -> None: - self.num_classes = self.trainer.datamodule.num_classes - self.tabular_dim = self.trainer.datamodule.tabular_dim + """Updates model based data-bound configurations (through datamodule), This method is + called after trainer is initialized and datamodule is available. + + Otherwise, some configuration variables must be made available + """ + + if self._trainer is not None: + self.num_classes = self.trainer.datamodule.num_classes + self.tabular_dim = self.trainer.datamodule.tabular_dim if stage != "fit": if isinstance(self.trainable_modules, tuple): @@ -67,7 +78,10 @@ def setup(self, stage: str) -> None: print("------------------------") # Freezing requested parts - self.freezer() + if stage in ["inference"]: + self.full_freezer() + else: + self.freezer() def setup_encoders_adapters(self): """Set up encoders and missing adapters/projectors.""" diff --git a/src/models/text_alignment_model.py b/src/models/text_alignment_model.py index 2cc1048..71737a7 100644 --- a/src/models/text_alignment_model.py +++ b/src/models/text_alignment_model.py @@ -1,4 +1,3 @@ -from io import text_encoding from typing import Dict, Tuple, override import torch @@ -32,6 +31,8 @@ def __init__( prediction_head: BasePredictionHead | None = None, ks: list[int] | None = [5, 10, 15], match_to_geo: bool = True, + num_classes: int | None = None, + tabular_dim: int | None = None, ) -> None: """Implementation of contrastive text-eo modality alignment model. @@ -42,14 +43,16 @@ def __init__( :param loss_fn: loss function to use (contrastive) :param trainable_modules: list of modules to train (parts/modules or modules, modules) :param metrics: metrics to use for model performance evaluation - :param num_classes: number of target classes - :param tabular_dim: number of tabular features :param prediction_head: prediction head :param ks: list of ks :param match_to_geo: whether to match dimensions of text encoder to geo_encoder or visa- versa + :param num_classes: number of target classes + :param tabular_dim: number of tabular features """ - super().__init__(trainable_modules, optimizer, scheduler, loss_fn, metrics) + super().__init__( + trainable_modules, optimizer, scheduler, loss_fn, metrics, num_classes, tabular_dim + ) # Metrics self.ks = ks @@ -64,31 +67,33 @@ def __init__( self.prediction_head = prediction_head @override - def setup(self, stage: str) -> None: - self.num_classes = self.trainer.datamodule.num_classes - self.tabular_dim = self.trainer.datamodule.tabular_dim + def setup(self, stage: str = "fit") -> None: + """Updates model based data-bound configurations (through datamodule), This method is + called after trainer is initialized and datamodule is available. + + Otherwise, some configuration variables must be made available + """ + + if self._trainer is not None: + self.num_classes = self.trainer.datamodule.num_classes + self.tabular_dim = self.trainer.datamodule.tabular_dim # Set up encoders and missing adapters/projectors print("-------Model------------") self.setup_encoders_adapters() print("------------------------") - # Freeze requested parts - self.freezer() + # Freeze not requested parts + if stage in ["inference"]: + self.full_freezer() + else: + self.freezer() - # Configure contrastive retrieval evaluation - self.setup_retrieval_evaluation() + # Configure contrastive retrieval evaluation + self.setup_retrieval_evaluation() def setup_encoders_adapters(self): """Set up encoders and missing adapters/projectors.""" - # We don't use tabular encoders for wrapping - # if ( - # isinstance(self.geo_encoder, MultiModalEncoder) - # and self.geo_encoder.use_tabular - # and not self.geo_encoder._tabular_ready - # ): - # self.geo_encoder.build_tabular_branch(self.tabular_dim) - # Setup encoders that need data-depended configurations new_modules = [f"geo_encoder.{i}" for i in self.geo_encoder.setup()] self.trainable_modules.extend(new_modules) @@ -109,11 +114,6 @@ def setup_encoders_adapters(self): ) self.prediction_head.setup() - # # Unify dtypes -> moving to data part, rather than changing parameter type - # if self.geo_encoder.dtype != self.text_encoder.dtype: - # self.geo_encoder = self.geo_encoder.to(self.text_encoder.dtype) - # print(f"Geo encoder dtype changed to {self.geo_encoder.dtype}") - def setup_retrieval_evaluation(self): self.concept_configs = self.trainer.datamodule.concept_configs self.concepts = [c["concept_caption"] for c in self.concept_configs] From 561411cf48a28f2544598e3f41b3d9f8bffd2c19 Mon Sep 17 00:00:00 2001 From: GabrieleTi Date: Tue, 7 Apr 2026 13:13:13 +0200 Subject: [PATCH 18/78] Introduce inference model --- configs/inference.yaml | 25 ++++ src/inference.py | 46 ++++++ src/models/inference_model.py | 262 ++++++++++++++++++++++++++++++++++ 3 files changed, 333 insertions(+) create mode 100644 configs/inference.yaml create mode 100644 src/inference.py create mode 100644 src/models/inference_model.py diff --git a/configs/inference.yaml b/configs/inference.yaml new file mode 100644 index 0000000..45f5bcc --- /dev/null +++ b/configs/inference.yaml @@ -0,0 +1,25 @@ +defaults: + - _self_ + - paths: ${oc.env:STORAGE_MODE,local} + - extras: default + - hydra: default + +# task name, determines output directory path +task_name: "inference" +tags: ["dev"] + +# If set, inference.py will load this merged checkpoint directly. +#inference_ckpt_path: ${paths.log_dir}${task_name}/2026-04-07_12-56-16/inference_model.ckpt + +# If `inference_ckpt_path` is not set, stitch the inference model from: +# - predictive ckpt (provides prediction_head weights) +# - alignment ckpt (provides text_encoder weights + geo_encoder) +predictive_ckpt_path: ${paths.log_dir}train/runs/2026-04-02_15-54-53/checkpoints/epoch_000.ckpt +alignment_ckpt_path: ${paths.log_dir}train/runs/2026-04-02_15-40-03/checkpoints/epoch_000.ckpt + +# If set, inference.py will save a merged inference checkpoint you can reload with +# `inference_ckpt_path`. +save_inference_ckpt_path: ${paths.log_dir}${task_name}/${now:%Y-%m-%d}_${now:%H-%M-%S}/inference_model.ckpt +#save_inference_ckpt_path: null + +training_order: ["alignment_model", "prediction_model"] diff --git a/src/inference.py b/src/inference.py new file mode 100644 index 0000000..dc15b50 --- /dev/null +++ b/src/inference.py @@ -0,0 +1,46 @@ +from typing import Optional + +import hydra +import rootutils +from dotenv import load_dotenv +from omegaconf import DictConfig + +from src.models.inference_model import load_inference_model, merge_inference_model +from src.utils import extras + +rootutils.setup_root(__file__, indicator=".project-root", pythonpath=True) +load_dotenv() + +# Disable tokenizers parallelism to avoid warnings when using multiprocessing +import os + +if os.environ.get("TOKENIZERS_PARALLELISM") is None: + os.environ["TOKENIZERS_PARALLELISM"] = "false" + + +@hydra.main(version_base="1.3", config_path="../configs", config_name="inference.yaml") +def main(cfg: DictConfig) -> Optional[float]: + """Main entry point for training. + + :param cfg: DictConfig configuration composed by Hydra. + :return: Optional[float] with optimized metric value. + """ + # apply extra utilities + # (e.g. ask for tags if none are provided in cfg, print cfg tree, etc.) + extras(cfg) + + # If a merged inference ckpt is provided, just load it. + inference_ckpt_path = cfg.get("inference_ckpt_path") + if inference_ckpt_path: + model = load_inference_model(inference_ckpt_path) + # Otherwise merge model from two checkpoints + else: + model = merge_inference_model(cfg, save_ckpt=True) + + # TODO: do what you need with the inference model + + return + + +if __name__ == "__main__": + main() diff --git a/src/models/inference_model.py b/src/models/inference_model.py new file mode 100644 index 0000000..93447be --- /dev/null +++ b/src/models/inference_model.py @@ -0,0 +1,262 @@ +import os +from typing import Dict, Tuple, override + +import hydra +import torch +import torch.nn.functional as F + +from src.models.base_model import BaseModel +from src.models.components.geo_encoders.base_geo_encoder import BaseGeoEncoder +from src.models.components.pred_heads.linear_pred_head import BasePredictionHead +from src.models.components.text_encoders.base_text_encoder import ( + BaseTextEncoder, +) +from src.utils import RankedLogger +from src.utils.errors import FileNotSpecified + +log = RankedLogger(__name__, rank_zero_only=True) + + +class InferenceModel(BaseModel): + def __init__( + self, + geo_encoder: BaseGeoEncoder, + text_encoder: BaseTextEncoder, + prediction_head: BasePredictionHead, + num_classes: int | None = None, + tabular_dim: int | None = None, + match_to_geo: bool = True, + **kwargs, + ) -> None: + super().__init__( + [], None, None, None, None, num_classes=num_classes, tabular_dim=tabular_dim + ) + + # Encoders configuration + self.geo_encoder = geo_encoder + self.text_encoder = text_encoder + + self.match_to_geo = match_to_geo + + # Prediction head + self.prediction_head = prediction_head + + @override + def setup(self, stage: str) -> None: + # During inference we need to ensure: + # - geo_encoder is fully initialized (sets geo_encoder.output_dim) + # - text/geo dims match (possibly via text_encoder.extra_projector) + # - prediction_head.net is created with correct (input_dim, output_dim) + if stage != "inference": + return + + # Configure geo encoder and its output_dim only if it wasn't already set up. + # (In the normal "stitch-from-ckpts" flow, the modules are already initialized.) + if getattr(self.geo_encoder, "output_dim", None) is None or ( + hasattr(self.geo_encoder, "geo_encoder") + and getattr(self.geo_encoder, "geo_encoder") is None + ): + self.geo_encoder.setup() + + # Configure optional extra projection so text embeddings match geo embeddings. + # Note: current codebase applies extra projector on text encoders (not on geo encoders) + # during forward, so `match_to_geo` is expected to be True. + if self.text_encoder.output_dim != self.geo_encoder.output_dim: + if not self.match_to_geo: + raise ValueError( + "match_to_geo=False is not supported for inference: geo extra projector " + "is not applied in geo encoder forward passes in this codebase." + ) + # If extra_projector already exists but output dims still mismatch, we recreate it. + # Otherwise, avoid overwriting weights unnecessarily. + if ( + getattr(self.text_encoder, "extra_projector", None) is None + or self.text_encoder.output_dim != self.geo_encoder.output_dim + ): + self.text_encoder.add_projector(projected_dim=self.geo_encoder.output_dim) + + # Configure prediction head (it creates `prediction_head.net` in setup()). + if self.prediction_head.net is None: + if self.num_classes is None: + raise ValueError( + "InferenceModel requires `num_classes` to build the prediction head." + ) + self.prediction_head.set_dim( + input_dim=self.geo_encoder.output_dim, output_dim=self.num_classes + ) + self.prediction_head.setup() + + # Freeze everything for pure inference. + self.full_freezer() + + @override + def _step( + self, + batch: Dict[str, torch.Tensor], + mode: str = "train", + ) -> torch.Tensor: + """Step forward computation of the model.""" + pass + + @override + def forward( + self, + batch: Dict[str, torch.Tensor], + mode: str = "train", + ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + """Model forward logic.""" + + # Embed modalities + geo_feats = self.geo_encoder(batch) + text_feats = self.text_encoder(batch, mode) + pred_feats = self.prediction_head(geo_feats) + + # Change dtype of geo data if it does not match text dtype + if geo_feats.dtype != text_feats.dtype: + geo_feats = geo_feats.to(text_feats.dtype) + return pred_feats, geo_feats, text_feats + + def concept_similarities(self, geo_embeds, concept=None) -> torch.Tensor: + # Get concept embeddings + if concept is not None: + # If only one concept is provided + if isinstance(concept, str): + concept = [concept] + with torch.no_grad(): + concept_embeds = self.text_encoder({"text": concept}, mode="train") + + elif self.concept_embeds is not None: + concept_embeds = self.concept_embeds + else: + with torch.no_grad(): + concept_embeds = self.text_encoder({"text": self.concepts}, mode="train") + + # Similarity + geo_embeds = F.normalize(geo_embeds, dim=1) + concept_embeds = F.normalize(concept_embeds, dim=1) + similarity_matrix = concept_embeds @ geo_embeds.T + + return similarity_matrix + + +def _is_prefix_trained(trainable_modules: list[str], prefix: str) -> bool: + """True if any trainable module starts with `prefix` (before dot).""" + return any(m.split(".")[0] == prefix for m in trainable_modules) + + +def _group_keys(keys: list[str]) -> dict[str, list[str]]: + """Groups module names (keys)""" + grouped: dict[str, list[str]] = {} + for k in keys: + top = k.split(".", 1)[0] if "." in k else k + grouped.setdefault(top, []).append(k) + return grouped + + +def _log_load_result(tag: str, result) -> None: + """Log missing/unexpected keys from `load_state_dict`.""" + missing_keys, unexpected_keys = result + if missing_keys: + grouped = _group_keys(list(missing_keys)) + summary = {k: len(v) for k, v in grouped.items()} + log.warning(f"[{tag}] Missing keys: {summary}") + if unexpected_keys: + grouped = _group_keys(list(unexpected_keys)) + summary = {k: len(v) for k, v in grouped.items()} + log.warning(f"[{tag}] Unexpected keys: {summary}") + + +def load_inference_model(inference_ckpt_path: str) -> InferenceModel: + """Loads inference model from a merged checkpoint. + + :param inference_ckpt_path: path to inference model weights + :return: an InferenceModel with pre-trained weights + """ + inference_ckpt = torch.load(inference_ckpt_path, map_location="cpu", weights_only=False) + model = hydra.utils.instantiate(inference_ckpt["hyper_parameters"]) + model.setup("inference") + res = model.load_state_dict(inference_ckpt["state_dict"], strict=False) + _log_load_result("inference_ckpt", res) + return model + + +def merge_inference_model(cfg, save_ckpt=False) -> InferenceModel | None: + """Configures the inference model from the predictive + alignment checkpoints. + + :param cfg: A DictConfig configuration composed by Hydra. + :param save_ckpt: Whether to save the model or not. + :return: an InferenceModel with pre-trained weights + """ + + # Stitch the inference model from the predictive + alignment checkpoints. + pred_ckpt_path = cfg.get("predictive_ckpt_path") or FileNotSpecified( + 'You must specify predictive model weight path as "predictive_ckpt_path"' + ) + align_ckpt_path = cfg.get("alignment_ckpt_path") or FileNotSpecified( + 'You must specify alignment model weight path as "alignment_ckpt_path"' + ) + # TODO: remove dataset saving into the checkpoint + pred_ckpt = torch.load(pred_ckpt_path, map_location="cpu", weights_only=False) + align_ckpt = torch.load(align_ckpt_path, map_location="cpu", weights_only=False) + + # Sanity check: ensure geo encoder configs match. + if pred_ckpt["hyper_parameters"].get("geo_encoder") != align_ckpt["hyper_parameters"].get( + "geo_encoder" + ): + log.warning("Geo encoder configs differ between checkpoints; results may be invalid.") + if input("Do you want to proceed? y/n").lower() == "n": + return None + + pred_trainable_modules = [ + pred_ckpt["hyper_parameters"].get("trainable_modules", [])[0] + ] # TODO fix + align_trainable_modules = align_ckpt["hyper_parameters"].get("trainable_modules", []) + + geo_pred_encoder_trained = _is_prefix_trained(pred_trainable_modules, "geo_encoder") + geo_align_encoder_trained = _is_prefix_trained(align_trainable_modules, "geo_encoder") + + if geo_pred_encoder_trained and geo_align_encoder_trained: + raise ValueError("Models are not aligned: both checkpoints trained geo_encoder.") + + # Instantiate InferenceModel via hydra, using alignment encoder configs with prediction model head configs + inference_hparams = align_ckpt["hyper_parameters"] + inference_hparams.update( + { + "_target_": "src.models.inference_model.InferenceModel", + "prediction_head": pred_ckpt["hyper_parameters"].get("prediction_head"), + "num_classes": pred_ckpt["hyper_parameters"].get("num_classes"), + } + ) + inference_hparams["text_encoder"]["hf_cache_dir"] = os.path.join( + cfg.paths.cache_dir, "huggingface" + ) + + model: InferenceModel = hydra.utils.instantiate(inference_hparams) + model.setup("inference") + + # Load alignment weights first (text encoder + geo encoder). + res = model.load_state_dict(align_ckpt["state_dict"], strict=False) + _log_load_result("alignment_ckpt", res) + + # Load prediction head weights from predictive ckpt. + head_state = { + k: v for k, v in pred_ckpt["state_dict"].items() if k.startswith("prediction_head.") + } + res = model.load_state_dict(head_state, strict=False) + _log_load_result("predictive_prediction_head_only", res) + + # Save model + if save_ckpt: + save_path = cfg.get("save_inference_ckpt_path") + if not save_path: + print("Model could not be saved as save_path was not provided") + + # Get `state_dict` + state_dict = model.state_dict() + + # Save + os.makedirs(os.path.dirname(save_path), exist_ok=True) + torch.save({"state_dict": state_dict, "hyper_parameters": inference_hparams}, save_path) + log.info(f"Saved merged inference checkpoint to: {save_path}") + + return model From 4ce2c21bdff6735ec151f20851a3c127082ff356 Mon Sep 17 00:00:00 2001 From: GabrieleTi Date: Tue, 7 Apr 2026 13:13:37 +0200 Subject: [PATCH 19/78] Add configuration saving into ckpg --- src/train.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/src/train.py b/src/train.py index 8348fbb..7a4992c 100644 --- a/src/train.py +++ b/src/train.py @@ -6,7 +6,7 @@ from dotenv import load_dotenv from lightning import Callback, LightningModule, Trainer from lightning.pytorch.loggers import Logger -from omegaconf import DictConfig +from omegaconf import DictConfig, OmegaConf from src.data.base_datamodule import BaseDataModule @@ -53,6 +53,10 @@ def train(cfg: DictConfig) -> Tuple[Dict[str, Any], Dict[str, Any]]: log.info(f"Instantiating model <{cfg.model._target_}>") model: LightningModule = hydra.utils.instantiate(cfg.model) + # Append model hparams from config to be saved in ckpg + raw_model_cfg = OmegaConf.to_container(cfg.model, resolve=True) + model.hparams.update(raw_model_cfg) + log.info("Instantiating callbacks...") callbacks: List[Callback] = instantiate_callbacks(cfg.get("callbacks")) From 14709536cc346a8923bc57a0de1378b96f6fedfc Mon Sep 17 00:00:00 2001 From: GabrieleTi Date: Tue, 7 Apr 2026 13:17:32 +0200 Subject: [PATCH 20/78] tessera fixes --- src/data_preprocessing/tessera_embeds.py | 86 ++++++++++++++++++------ 1 file changed, 64 insertions(+), 22 deletions(-) diff --git a/src/data_preprocessing/tessera_embeds.py b/src/data_preprocessing/tessera_embeds.py index c2414b8..1e1830d 100644 --- a/src/data_preprocessing/tessera_embeds.py +++ b/src/data_preprocessing/tessera_embeds.py @@ -1,3 +1,4 @@ +import concurrent.futures import math import os import threading @@ -23,6 +24,18 @@ ) +class PartialTileError(Exception): + def __init__(self, message): + self.message = message + super().__init__(self.message) + + +class NoTileError(Exception): + def __init__(self, message): + self.message = message + super().__init__(self.message) + + def reproject_dataset(src_raster: MemoryFile, dst_crs: str) -> MemoryFile: """Reprojects Memory file if it's not in dst_crs. @@ -70,6 +83,7 @@ def get_tessera_embeds( save_dir: str, tile_size: int, tessera_con: GeoTessera | None, + padding: int = 100, ) -> None: """Obtain tessera embedding tile with specified size for a given coordinates. @@ -80,9 +94,11 @@ def get_tessera_embeds( :param save_dir: data directory to save embeddings :param tile_size: tile size in pixels :param tessera_con: GeoTessera instance + :param padding: how many meters to pad initial bbox, fixes some inconsistencies when mosaicing :return: None """ + # Skip if tile exists embed_tile_name = os.path.join(save_dir, f"tessera_{name_loc}.npy") if os.path.exists(embed_tile_name): return @@ -92,7 +108,7 @@ def get_tessera_embeds( lon_utm, lat_utm = point_reprojection(lon, lat, "EPSG:4326", utm_crs) # Bounding box - radius = math.ceil(tile_size / 2) + 10 + radius = math.ceil(tile_size / 2) + padding bbox = create_bbox_with_radius(lon, lat, radius=radius, utm_crs=utm_crs, return_wgs=True) # Request to tessera @@ -126,13 +142,10 @@ def get_tessera_embeds( if reproject_memfile: memfiles.append(reproject_memfile) - if not tiles: - print( - f"No TESSERA tiles found for {name_loc} at ({lon:.4f}, {lat:.4f}) year={year}. Skipping." - ) + if len(tiles) == 0: for mf in memfiles: mf.close() - return + raise NoTileError(f"No tiles found for {name_loc}") # if no tiles, add to skipped.txt mosaic, mosaic_transform = merge(tiles) mosaic = mosaic.transpose(1, 2, 0) @@ -143,20 +156,29 @@ def get_tessera_embeds( mf.close() # Crop patch tile - col, row = crs_to_pixel_coords(lon_utm, lat_utm, mosaic_transform) + c, r = crs_to_pixel_coords(lon_utm, lat_utm, mosaic_transform) half = tile_size // 2 - row_min = row - half - row_max = row + tile_size - half # tile_size - half ensures correct size for odd tile_size - col_min = col - half - col_max = col + tile_size - half + row_min = r - half + row_max = r + half + col_min = c - half + col_max = c + half + + if row_min < 0 or row_max < 0 or col_min < 0 or col_max < 0: + # retry with bigger padding + if padding > 500: + raise NoTileError(f"Padding {padding} > 500") + get_tessera_embeds( + lon, lat, name_loc, year, save_dir, tile_size, tessera_con, padding=padding + 100 + ) + crop = mosaic[row_min:row_max, col_min:col_max, :] + if not crop.shape == (tile_size, tile_size, 128): + if crop.min() == 0.0 and crop.max() == 0.0: + raise NoTileError(f"No tiles found for {name_loc}") + raise PartialTileError(f"Crop {name_loc}, size is {crop.shape}") - if crop.shape[0] != tile_size or crop.shape[1] != tile_size: - print( - f"Unexpected crop shape {crop.shape} for {name_loc} " - f"(expected {tile_size}x{tile_size}). Skipping." - ) - return + if crop.min() == 0.0 and crop.max() == 0.0: + raise NoTileError(f"Crop {name_loc} has embeddings of 0.0s with tiles: {tiles_to_fetch}") # Save array os.makedirs(save_dir, exist_ok=True) @@ -186,6 +208,7 @@ def tessera_from_df( year: int, tile_size: int = 256, cache_dir: str = "temp/", + logs_dir: str = "logs", ) -> None: """Obtains Tessera embeddings from a CSV file for each (lon, lat). @@ -205,8 +228,17 @@ def tessera_from_df( n = len(model_ready_df) for i, row in model_ready_df.iterrows(): print(f"{i}/{n}") - # Get tessera embeds - get_tessera_embeds(row.lon, row.lat, row.name_loc, year, f"{data_dir}/", tile_size, gt) + if i < 238 or i in [1319]: + continue + try: + get_tessera_embeds(row.lon, row.lat, row.name_loc, year, f"{data_dir}/", tile_size, gt) + except Exception as e: + if isinstance(e, NoTileError): + path = os.path.join(logs_dir, "tessera_skipped.txt") + with open(path, "a") as f: + f.write(f"{row.name_loc}\n") + else: + print(f"{row.name_loc} did not get embedded because: {e}") def inspect_np_arr_as_tiff( @@ -270,10 +302,20 @@ def inspect_np_arr_as_tiff( if __name__ == "__main__": - os.chdir("../..") + print(os.getcwd()) + + # df = pd.read_csv("data/heat_guatemala/model_ready_heat_guatemala.csv") + df = pd.read_csv("/lustre/backup/SHARED/AIN/aether/data/s2bms/model_ready_s2bms.csv") + # df.sort_values(by="name_loc", inplace=True, ascending=False) + with open(os.path.join("logs", "tessera_skipped.txt")) as f: + skipped = set(f.read().splitlines()) - df = pd.read_csv("data/heat_guatemala/model_ready_heat_guatemala.csv") + df = df[~df.name_loc.isin(skipped)] tessera_from_df( - df, "data/heat_guatemala/eo/tessera_2024", year=2024, tile_size=10, cache_dir="data/cache" + df, + "/lustre/backup/SHARED/AIN/aether/data/s2bms/eo/tessera", + year=2024, + tile_size=256, + cache_dir="/lustre/backup/SHARED/AIN/aether/data/cache", ) From a4a94c5ba29a1c048fe382b5f84309ba91a2a183 Mon Sep 17 00:00:00 2001 From: Thijs van der Plas Date: Fri, 6 Mar 2026 13:46:07 +0100 Subject: [PATCH 21/78] minor gee update in case outside of radius_max, return radius_max --- src/data_preprocessing/gee_utils.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/data_preprocessing/gee_utils.py b/src/data_preprocessing/gee_utils.py index def17a9..f1f73bf 100644 --- a/src/data_preprocessing/gee_utils.py +++ b/src/data_preprocessing/gee_utils.py @@ -250,8 +250,8 @@ def get_distance_to_road_within_aoi(aoi, cell_size=30, radius_max=5000): reducer=ee.Reducer.mean(), geometry=aoi, scale=cell_size, maxPixels=1e9 ) return { - "maxdist_road": int(max_distance.get("distance").getInfo()), - "meandist_road": int(mean_distance.get("distance").getInfo()), + "maxdist_road": int(max_distance.get("distance").getInfo() or radius_max), + "meandist_road": int(mean_distance.get("distance").getInfo() or radius_max), } From 51aed7d719d7489d98f59fbe87baa4f230d27f80 Mon Sep 17 00:00:00 2001 From: Thijs van der Plas Date: Wed, 11 Mar 2026 14:52:11 +0100 Subject: [PATCH 22/78] concept captions manually picked relevant captions for biodiv UC, inspecting histograms to look for sensible thresholds. --- data/s2bms/concept_captions/v2.json | 103 ++++++++++++++++++++++++++++ 1 file changed, 103 insertions(+) create mode 100644 data/s2bms/concept_captions/v2.json diff --git a/data/s2bms/concept_captions/v2.json b/data/s2bms/concept_captions/v2.json new file mode 100644 index 0000000..9b2bdd5 --- /dev/null +++ b/data/s2bms/concept_captions/v2.json @@ -0,0 +1,103 @@ +[ + { + "concept_caption": "Densely populated area with many houses", + "is_max": true, + "theta_k": 0.3, + "col": "aux_corine_frac_11" + }, + { + "concept_caption": "Very sparsely populated area with few houses", + "is_max": false, + "theta_k": 0.05, + "col": "aux_corine_frac_11" + },{ + "concept_caption": "Area with infrastructure such as roads, railways, airport, ports and heavy industry.", + "is_max": true, + "theta_k": 0.1, + "col": "aux_corine_frac_12" + }, + { + "concept_caption": "Arable land with crops for agriculture", + "is_max": true, + "theta_k": 0.65, + "col": "aux_corine_frac_21" + }, + { + "concept_caption": "Pasture fields with grass for grazing animals", + "is_max": true, + "theta_k": 0.6, + "col": "aux_corine_frac_231" + }, + { + "concept_caption": "Agricultural land used for crops, pasture or mixed farming", + "is_max": true, + "theta_k": 0.05, + "col": "aux_corine_frac_24" + }, + { + "concept_caption": "Forested area with many trees", + "is_max": true, + "theta_k": 0.25, + "col": "aux_corine_frac_31" + }, + { + "concept_caption": "Scrub area with trees, shrub, moors.", + "is_max": true, + "theta_k": 0.2, + "col": "aux_corine_frac_32" + }, + { + "concept_caption": "Moorlands and heathlands with low vegetation", + "is_max": true, + "theta_k": 0.2, + "col": "aux_corine_frac_322" + }, + { + "concept_caption": "Wetlands such as marshes, swamps, mudflats and bogs.", + "is_max": true, + "theta_k": 0.2, + "col": "aux_corine_frac_4" + }, + { + "concept_caption": "Peat bogs", + "is_max": true, + "theta_k": 0.2, + "col": "aux_corine_frac_412" + }, + { + "concept_caption": "Water bodies such as lakes, rivers and sea", + "is_max": true, + "theta_k": 0.2, + "col": "aux_corine_frac_5" + }, + { + "concept_caption": "Warm area with high summer temperatures", + "is_max": true, + "theta_k": 22, + "col": "aux_bioclim_05" + }, + { + "concept_caption": "Cold area with low winter temperatures", + "is_max": false, + "theta_k": 0, + "col": "aux_bioclim_06" + }, + { + "concept_caption": "Wet area with a lot of rainfall", + "is_max": true, + "theta_k": 950, + "col": "aux_bioclim_12" + }, + { + "concept_caption": "Remote area far from roads and infrastructure", + "is_max": true, + "theta_k": 1500, + "col": "aux_meandist_road" + }, + { + "concept_caption": "Densely populated area with many houses", + "is_max": true, + "theta_k": 1500, + "col": "aux_pop_density" + } +] From 4f0ed1df449c3968af79e6f315aaae5a2374dcd3 Mon Sep 17 00:00:00 2001 From: Rob Knapen Date: Wed, 11 Mar 2026 10:45:20 +0100 Subject: [PATCH 23/78] Makes DBScan clustering more efficient and much faster. --- src/data/base_datamodule.py | 31 +++++++++++++++++++------------ 1 file changed, 19 insertions(+), 12 deletions(-) diff --git a/src/data/base_datamodule.py b/src/data/base_datamodule.py index 201c671..1ef0a48 100644 --- a/src/data/base_datamodule.py +++ b/src/data/base_datamodule.py @@ -1,12 +1,12 @@ import copy import os +import time from functools import partial from typing import Any, Dict, List, Tuple import numpy as np import pandas as pd import torch -from geopy.distance import distance as geodist # avoid naming confusion from lightning import LightningDataModule from sklearn.cluster import DBSCAN from sklearn.model_selection import GroupShuffleSplit @@ -121,20 +121,27 @@ def split_data(self) -> None: } elif self.hparams.split_mode == "spatial_clusters": - print("Splitting dataset using spatial clusters. This can take a while...") - coords = np.array([self.dataset.df.lat, self.dataset.df.lon]).T - if len(coords) > 2000: - print( - "Warning: DBSCAN clustering on more than 2000 samples may be slow. Maybe set n_jobs in DBScan?" - ) - # 4000 m distance between points. Use geodist to calculate true distance. min_dist = self.hparams.spatial_split_distance_m + coords = np.array([self.dataset.df.lat, self.dataset.df.lon]).T + n = len(coords) + print( + f"Splitting {n} samples into spatial clusters " + f"(eps={min_dist / 1000:.1f} km, haversine, n_jobs=-1)..." + ) + # Convert (lat, lon) degrees to radians for sklearn's haversine metric. + # haversine returns arc length on the unit sphere, so eps must be in radians. + _EARTH_RADIUS_M = 6_371_000 + coords_rad = np.radians(coords) + eps_rad = min_dist / _EARTH_RADIUS_M + t0 = time.time() clustering = DBSCAN( - eps=min_dist, - metric=lambda u, v: geodist(u, v).meters, + eps=eps_rad, + metric="haversine", + algorithm="ball_tree", min_samples=2, - ).fit(coords) - print("Clustering done. Creating splits and saving.") + n_jobs=-1, + ).fit(coords_rad) + print(f"DBSCAN done in {time.time() - t0:.1f}s. Creating splits...") # Non-clustered points are labeled -1. Change to new cluster label. clusters = copy.deepcopy(clustering.labels_) new_cl = np.max(clusters) + 1 From 6911022716463d7c6a9eaec1ed55ee5e179aeb46 Mon Sep 17 00:00:00 2001 From: Rob Knapen Date: Wed, 11 Mar 2026 10:53:12 +0100 Subject: [PATCH 24/78] Crop Yield use case: spatial splitting --- .../yield_africa_spatial_splits.py | 326 ++++++++++++++++++ 1 file changed, 326 insertions(+) create mode 100644 src/data_preprocessing/yield_africa_spatial_splits.py diff --git a/src/data_preprocessing/yield_africa_spatial_splits.py b/src/data_preprocessing/yield_africa_spatial_splits.py new file mode 100644 index 0000000..4c9aa98 --- /dev/null +++ b/src/data_preprocessing/yield_africa_spatial_splits.py @@ -0,0 +1,326 @@ +"""Generate spatial-cluster split files for the yield_africa dataset. + +Location: src/data_preprocessing/yield_africa_spatial_splits.py + +Uses DBSCAN with a haversine distance metric to group nearby field locations +into clusters, then assigns whole clusters to train/val/test so that no +geographically close points straddle a split boundary. + +One `.pth` file is written per distance threshold to +`{data_dir}/yield_africa/splits/split_spatial_{distance_km}km.pth`. + +Split layout +------------ +- train : ~70 % of records (cluster-aligned) +- val : ~15 % of records (cluster-aligned) +- test : ~15 % of records (cluster-aligned) + +Proportions are approximate because whole clusters are kept intact. + +The files are consumed by BaseDataModule when `split_mode: from_file` and +`saved_split_file_name: split_spatial_{distance_km}km.pth`. + +Usage +----- + # Generate the default set of splits (10 km, 25 km, 50 km) + python src/data_preprocessing/yield_africa_spatial_splits.py --data_dir data/ + + # Generate a single split at a specific distance + python src/data_preprocessing/yield_africa_spatial_splits.py --data_dir data/ --distance_km 25 + + # Generate multiple distances in one run + python src/data_preprocessing/yield_africa_spatial_splits.py --data_dir data/ --distance_km 10 25 50 + +Notes +----- +- DBSCAN uses sklearn's built-in haversine metric with a BallTree spatial index + and n_jobs=-1, which is significantly faster than a Python geodesic lambda. + Haversine vs. true geodesic error is < 0.1% at distances up to ~100 km. +- `min_samples=2` means a pair of fields within `distance_km` of each other + forms a cluster; isolated fields each become their own singleton cluster. +- All clusters are kept intact across the split boundary, so the test set + contains no locations geographically close to any training location. +""" + +import argparse +import copy +import logging +import time +from pathlib import Path + +import numpy as np +import pandas as pd +import torch +from sklearn.cluster import DBSCAN + +log = logging.getLogger(__name__) + +DATASET_NAME = "yield_africa" +MODEL_READY_CSV = f"model_ready_{DATASET_NAME}.csv" + +# Default distances to generate when no --distance_km is supplied. +DEFAULT_DISTANCES_KM = [10, 25, 50] + +# Split proportions (must sum to 1.0). +TRAIN_FRAC = 0.70 +VAL_FRAC = 0.15 +TEST_FRAC = 0.15 + +# Fixed random seed for GroupShuffleSplit. +SEED = 12345 + + +def make_spatial_split( + df: pd.DataFrame, + distance_m: int, + train_val_test_split: tuple[float, float, float] = (TRAIN_FRAC, VAL_FRAC, TEST_FRAC), + seed: int = SEED, +) -> dict: + """Return a split-indices dict using DBSCAN spatial clustering. + + :param df: full model-ready dataframe (must contain 'lat', 'lon', 'name_loc') + :param distance_m: DBSCAN eps in metres — pairs of fields closer than this + value are assigned to the same cluster + :param train_val_test_split: (train, val, test) proportions, must sum to 1.0 + :param seed: random seed for GroupShuffleSplit + :return: dict with 'train_indices', 'val_indices', 'test_indices' as + pd.Series of name_loc strings, plus 'clusters' as a numpy array of + cluster labels (same length as df) + """ + # Deduplicate to unique (lat, lon) locations before clustering. + # yield_africa has ~9 rows per location (one per year); running DBSCAN on all + # rows produces giant clusters whose row counts are unequal, causing + # GroupShuffleSplit (which splits by cluster count) to produce badly skewed + # train/val/test proportions. Clustering unique locations and propagating + # the split back to all rows fixes this. + unique_locs = df.drop_duplicates(subset=["lat", "lon"]).reset_index(drop=True) + n_unique = len(unique_locs) + n_total = len(df) + if n_unique < n_total: + print( + f" Deduplicating: {n_unique} unique locations from {n_total} rows " + f"(~{n_total / n_unique:.1f} rows/location)." + ) + + # Convert (lat, lon) degrees to radians for sklearn's haversine metric. + # haversine returns arc length on the unit sphere, so eps must be in radians. + # Error vs. true geodesic is < 0.1% at distances up to ~100 km. + _EARTH_RADIUS_M = 6_371_000 + coords_rad = np.radians(np.array([unique_locs["lat"].values, unique_locs["lon"].values]).T) + eps_rad = distance_m / _EARTH_RADIUS_M + + print( + f" Running DBSCAN (eps={distance_m / 1000:.1f} km, haversine, " + f"n={n_unique} locations, n_jobs=-1)..." + ) + t0 = time.time() + clustering = DBSCAN( + eps=eps_rad, + metric="haversine", + algorithm="ball_tree", + min_samples=2, + n_jobs=-1, + ).fit(coords_rad) + print(f" DBSCAN done in {time.time() - t0:.1f}s.") + + # Noise points (label -1) each become their own unique cluster so that + # GroupShuffleSplit can assign them individually to a split partition. + clusters = copy.deepcopy(clustering.labels_) + next_label = int(np.max(clusters)) + 1 + for i, label in enumerate(clusters): + if label == -1: + clusters[i] = next_label + next_label += 1 + + n_clusters = len(np.unique(clusters)) + n_noise = int(np.sum(clustering.labels_ == -1)) + print(f" Clustering done: {n_clusters} location clusters ({n_noise} singleton noise points).") + + train_prop, val_prop, test_prop = train_val_test_split + + # Greedy size-aware cluster assignment. + # + # GroupShuffleSplit splits by cluster *count*, not by sample count. When the + # cluster size distribution is heavily skewed (a few mega-clusters + many + # tiny 2-location clusters), this produces badly imbalanced splits. + # + # Instead: shuffle clusters for randomness, sort by size descending, then + # assign each cluster to whichever split is furthest below its sample-count + # target. Each cluster goes to exactly one split, so there is no overlap. + rng = np.random.default_rng(seed) + unique_clusters, cluster_sizes = np.unique(clusters, return_counts=True) + + # Shuffle first so ties are broken randomly, then sort by descending size. + shuffle_order = rng.permutation(len(unique_clusters)) + unique_clusters = unique_clusters[shuffle_order] + cluster_sizes = cluster_sizes[shuffle_order] + size_order = np.argsort(-cluster_sizes) + unique_clusters = unique_clusters[size_order] + cluster_sizes = cluster_sizes[size_order] + + target_train = n_unique * train_prop + target_val = n_unique * val_prop + target_test = n_unique * test_prop + train_clusters, val_clusters, test_clusters = [], [], [] + count_train, count_val, count_test = 0, 0, 0 + + for cluster_id, size in zip(unique_clusters, cluster_sizes): + deficit_train = target_train - count_train + deficit_val = target_val - count_val + deficit_test = target_test - count_test + if deficit_train >= deficit_val and deficit_train >= deficit_test: + train_clusters.append(cluster_id) + count_train += size + elif deficit_val >= deficit_test: + val_clusters.append(cluster_id) + count_val += size + else: + test_clusters.append(cluster_id) + count_test += size + + train_loc_mask = np.isin(clusters, train_clusters) + val_loc_mask = np.isin(clusters, val_clusters) + test_loc_mask = np.isin(clusters, test_clusters) + + # Sanity checks: every location assigned, no cluster in multiple splits. + assert train_loc_mask.sum() + val_loc_mask.sum() + test_loc_mask.sum() == n_unique + assert len(set(train_clusters) & set(val_clusters)) == 0 + assert len(set(train_clusters) & set(test_clusters)) == 0 + assert len(set(val_clusters) & set(test_clusters)) == 0 + + print( + f" Split (locations): train={train_loc_mask.sum()}, " + f"val={val_loc_mask.sum()}, test={test_loc_mask.sum()}" + ) + + # Propagate location-level split assignments back to all rows by (lat, lon). + train_latlon = set( + zip(unique_locs.loc[train_loc_mask, "lat"], unique_locs.loc[train_loc_mask, "lon"]) + ) + val_latlon = set( + zip(unique_locs.loc[val_loc_mask, "lat"], unique_locs.loc[val_loc_mask, "lon"]) + ) + test_latlon = set( + zip(unique_locs.loc[test_loc_mask, "lat"], unique_locs.loc[test_loc_mask, "lon"]) + ) + row_latlon = list(zip(df["lat"], df["lon"])) + train_mask = np.array([ll in train_latlon for ll in row_latlon]) + val_mask = np.array([ll in val_latlon for ll in row_latlon]) + test_mask = np.array([ll in test_latlon for ll in row_latlon]) + + assert train_mask.sum() + val_mask.sum() + test_mask.sum() == n_total, ( + "Not all rows were assigned to a split — check for (lat, lon) values that " + "don't match any unique location after deduplication." + ) + + name_locs = df["name_loc"].reset_index(drop=True) + return { + "train_indices": name_locs[train_mask].reset_index(drop=True), + "val_indices": name_locs[val_mask].reset_index(drop=True), + "test_indices": name_locs[test_mask].reset_index(drop=True), + "clusters": clusters, + } + + +def generate_splits( + data_dir: str, + distances_km: list[int] | None = None, + seed: int = SEED, +) -> None: + """Generate and save spatial-cluster split files for the requested distances. + + :param data_dir: root data directory (same as `paths.data_dir` in configs) + :param distances_km: list of DBSCAN cluster distances in kilometres; None + uses DEFAULT_DISTANCES_KM + :param seed: random seed for GroupShuffleSplit + """ + if distances_km is None: + distances_km = DEFAULT_DISTANCES_KM + + dataset_dir = Path(data_dir) / DATASET_NAME + csv_path = dataset_dir / MODEL_READY_CSV + splits_dir = dataset_dir / "splits" + + if not csv_path.exists(): + raise FileNotFoundError(f"Model-ready CSV not found: {csv_path}") + + splits_dir.mkdir(parents=True, exist_ok=True) + + df = pd.read_csv(csv_path) + for col in ("lat", "lon", "name_loc"): + if col not in df.columns: + raise ValueError(f"CSV must contain a '{col}' column") + + print(f"Loaded {len(df)} rows from {csv_path}") + + for dist_km in distances_km: + dist_m = dist_km * 1000 + print(f"\nGenerating spatial split at {dist_km} km ({dist_m} m)...") + + split = make_spatial_split(df, distance_m=dist_m, seed=seed) + n_train = len(split["train_indices"]) + n_val = len(split["val_indices"]) + n_test = len(split["test_indices"]) + + out_name = f"split_spatial_{dist_km}km.pth" + out_path = splits_dir / out_name + torch.save(split, out_path) + + print( + f" Saved {out_name} " + f"(train={n_train}, val={n_val}, test={n_test}, " + f"total={n_train + n_val + n_test}/{len(df)})" + ) + log.info( + f" {dist_km}km: train={n_train}, val={n_val}, test={n_test} -> {out_name}" + ) + + +def main() -> None: + logging.basicConfig(level=logging.INFO, format="%(message)s") + + parser = argparse.ArgumentParser( + description="Generate spatial-cluster split files for yield_africa.", + formatter_class=argparse.RawDescriptionHelpFormatter, + epilog=__doc__, + ) + parser.add_argument( + "--data_dir", + type=str, + default="data/", + help="Root data directory (same as paths.data_dir in configs). Default: data/", + ) + parser.add_argument( + "--distance_km", + type=int, + nargs="+", + default=None, + metavar="KM", + help=( + "Cluster distance threshold(s) in km. " + f"Omit to generate the default set: {DEFAULT_DISTANCES_KM} km." + ), + ) + parser.add_argument( + "--seed", + type=int, + default=SEED, + help=f"Random seed for GroupShuffleSplit. Default: {SEED}", + ) + args = parser.parse_args() + + distances = args.distance_km # None means use defaults + print( + f"Generating spatial splits data_dir={args.data_dir} " + f"distances_km={distances or DEFAULT_DISTANCES_KM} seed={args.seed}" + ) + generate_splits( + data_dir=args.data_dir, + distances_km=distances, + seed=args.seed, + ) + print("\nDone.") + + +if __name__ == "__main__": + main() From 22c9cbb299847f1bbef8014fd80000d4965c7cb3 Mon Sep 17 00:00:00 2001 From: Rob Knapen Date: Wed, 11 Mar 2026 10:54:17 +0100 Subject: [PATCH 25/78] Crop Yield use case: configs for various experiments --- configs/data/yield_africa_spatial.yaml | 33 ++++++++++++++++ configs/data/yield_africa_tessera_loco.yaml | 39 +++++++++++++++++++ .../data/yield_africa_tessera_spatial.yaml | 39 +++++++++++++++++++ .../experiment/yield_africa_fusion_loco.yaml | 33 ++++++++++++++++ .../yield_africa_fusion_spatial.yaml | 33 ++++++++++++++++ .../yield_africa_tabular_spatial.yaml | 33 ++++++++++++++++ .../yield_africa_tessera_fusion_loco.yaml | 38 ++++++++++++++++++ .../yield_africa_tessera_fusion_spatial.yaml | 38 ++++++++++++++++++ 8 files changed, 286 insertions(+) create mode 100644 configs/data/yield_africa_spatial.yaml create mode 100644 configs/data/yield_africa_tessera_loco.yaml create mode 100644 configs/data/yield_africa_tessera_spatial.yaml create mode 100644 configs/experiment/yield_africa_fusion_loco.yaml create mode 100644 configs/experiment/yield_africa_fusion_spatial.yaml create mode 100644 configs/experiment/yield_africa_tabular_spatial.yaml create mode 100644 configs/experiment/yield_africa_tessera_fusion_loco.yaml create mode 100644 configs/experiment/yield_africa_tessera_fusion_spatial.yaml diff --git a/configs/data/yield_africa_spatial.yaml b/configs/data/yield_africa_spatial.yaml new file mode 100644 index 0000000..9313100 --- /dev/null +++ b/configs/data/yield_africa_spatial.yaml @@ -0,0 +1,33 @@ +_target_: src.data.base_datamodule.BaseDataModule + +dataset: + _target_: src.data.yield_africa_dataset.YieldAfricaDataset + data_dir: ${paths.data_dir} + modalities: + coords: {} + use_target_data: true + use_features: true + use_aux_data: none + seed: ${seed} + cache_dir: ${paths.cache_dir} + # Include all countries and years so the split file determines the partition. + countries: ["BF", "BUR", "ETH", "KEN", "MAL", "RWA", "TAN", "ZAM"] + years: [2014, 2016, 2017, 2018, 2019, 2020, 2021, 2023, 2024] + exclude_countries: null + exclude_years: null + +batch_size: 64 +num_workers: 0 +pin_memory: false + +# Spatial-cluster split loaded from a pre-generated file. +# Generate split files first (produces 10 km, 25 km, and 50 km variants): +# python src/data_preprocessing/yield_africa_spatial_splits.py --data_dir +# +# Override saved_split_file_name at the command line to change the cluster distance: +# python src/train.py experiment=yield_africa_tabular_spatial \ +# data.saved_split_file_name=split_spatial_10km.pth +split_mode: "from_file" +saved_split_file_name: "split_spatial_25km.pth" +save_split: false +seed: ${seed} diff --git a/configs/data/yield_africa_tessera_loco.yaml b/configs/data/yield_africa_tessera_loco.yaml new file mode 100644 index 0000000..0be62c3 --- /dev/null +++ b/configs/data/yield_africa_tessera_loco.yaml @@ -0,0 +1,39 @@ +_target_: src.data.base_datamodule.BaseDataModule + +dataset: + _target_: src.data.yield_africa_dataset.YieldAfricaDataset + data_dir: ${paths.data_dir} + modalities: + tessera: + # size must match the tile_size used when running the preprocessing script. + # Default: 9 pixels (set by yield_africa_tessera_preprocess.py --tile_size). + size: 9 + format: npy + # year is intentionally omitted: yield_africa fetches per-record year tiles + # via the preprocessing script rather than a single bulk-year download. + use_target_data: true + use_features: true + use_aux_data: none + seed: ${seed} + cache_dir: ${paths.cache_dir} + # Include all countries and years so the split file determines the partition. + countries: ["BF", "BUR", "ETH", "KEN", "MAL", "RWA", "TAN", "ZAM"] + years: [2014, 2016, 2017, 2018, 2019, 2020, 2021, 2023, 2024] + exclude_countries: null + exclude_years: null + +batch_size: 64 +num_workers: 0 +pin_memory: false + +# Leave-one-country-out split loaded from a pre-generated file. +# Generate split files first: +# python src/data_preprocessing/yield_africa_loco_splits.py --data_dir +# +# Override saved_split_file_name at the command line to change the held-out country: +# python src/train.py experiment=yield_africa_tessera_fusion_loco \ +# data.saved_split_file_name=split_loco_RWA.pth +split_mode: "from_file" +saved_split_file_name: "split_loco_KEN.pth" +save_split: false +seed: ${seed} diff --git a/configs/data/yield_africa_tessera_spatial.yaml b/configs/data/yield_africa_tessera_spatial.yaml new file mode 100644 index 0000000..9424801 --- /dev/null +++ b/configs/data/yield_africa_tessera_spatial.yaml @@ -0,0 +1,39 @@ +_target_: src.data.base_datamodule.BaseDataModule + +dataset: + _target_: src.data.yield_africa_dataset.YieldAfricaDataset + data_dir: ${paths.data_dir} + modalities: + tessera: + # size must match the tile_size used when running the preprocessing script. + # Default: 9 pixels (set by yield_africa_tessera_preprocess.py --tile_size). + size: 9 + format: npy + # year is intentionally omitted: yield_africa fetches per-record year tiles + # via the preprocessing script rather than a single bulk-year download. + use_target_data: true + use_features: true + use_aux_data: none + seed: ${seed} + cache_dir: ${paths.cache_dir} + # Include all countries and years so the split file determines the partition. + countries: ["BF", "BUR", "ETH", "KEN", "MAL", "RWA", "TAN", "ZAM"] + years: [2014, 2016, 2017, 2018, 2019, 2020, 2021, 2023, 2024] + exclude_countries: null + exclude_years: null + +batch_size: 64 +num_workers: 0 +pin_memory: false + +# Spatial-cluster split loaded from a pre-generated file. +# Generate split files first (produces 10 km, 25 km, and 50 km variants): +# python src/data_preprocessing/yield_africa_spatial_splits.py --data_dir +# +# Override saved_split_file_name at the command line to change the cluster distance: +# python src/train.py experiment=yield_africa_tessera_fusion_spatial \ +# data.saved_split_file_name=split_spatial_10km.pth +split_mode: "from_file" +saved_split_file_name: "split_spatial_25km.pth" +save_split: false +seed: ${seed} diff --git a/configs/experiment/yield_africa_fusion_loco.yaml b/configs/experiment/yield_africa_fusion_loco.yaml new file mode 100644 index 0000000..2540642 --- /dev/null +++ b/configs/experiment/yield_africa_fusion_loco.yaml @@ -0,0 +1,33 @@ +# @package _global_ +# configs/experiment/yield_africa_fusion_loco.yaml +# GeoClip + tabular fusion model evaluated with leave-one-country-out split. +# Default held-out country: KEN (largest, most representative test set). +# +# Generate split files first: +# python src/data_preprocessing/yield_africa_loco_splits.py --data_dir +# +# To evaluate on a different held-out country: +# python src/train.py experiment=yield_africa_fusion_loco \ +# data.saved_split_file_name=split_loco_RWA.pth + +defaults: + - override /model: yield_fusion_reg + - override /data: yield_africa_loco + - override /metrics: yield_africa_regression + +tags: ["yield_africa", "fusion", "regression", "loco"] +seed: 12345 + +trainer: + min_epochs: 1 + max_epochs: 150 + +data: + batch_size: 64 + +logger: + wandb: + tags: ${tags} + group: "yield_africa" + aim: + experiment: "yield_africa" diff --git a/configs/experiment/yield_africa_fusion_spatial.yaml b/configs/experiment/yield_africa_fusion_spatial.yaml new file mode 100644 index 0000000..98c4221 --- /dev/null +++ b/configs/experiment/yield_africa_fusion_spatial.yaml @@ -0,0 +1,33 @@ +# @package _global_ +# configs/experiment/yield_africa_fusion_spatial.yaml +# GeoClip + tabular fusion model evaluated with a spatial-cluster split. +# Default cluster distance: 25 km (split_spatial_25km.pth). +# +# Generate split files first: +# python src/data_preprocessing/yield_africa_spatial_splits.py --data_dir +# +# To evaluate at a different cluster distance: +# python src/train.py experiment=yield_africa_fusion_spatial \ +# data.saved_split_file_name=split_spatial_10km.pth + +defaults: + - override /model: yield_fusion_reg + - override /data: yield_africa_spatial + - override /metrics: yield_africa_regression + +tags: ["yield_africa", "fusion", "regression", "spatial"] +seed: 12345 + +trainer: + min_epochs: 1 + max_epochs: 150 + +data: + batch_size: 64 + +logger: + wandb: + tags: ${tags} + group: "yield_africa" + aim: + experiment: "yield_africa" diff --git a/configs/experiment/yield_africa_tabular_spatial.yaml b/configs/experiment/yield_africa_tabular_spatial.yaml new file mode 100644 index 0000000..9c57961 --- /dev/null +++ b/configs/experiment/yield_africa_tabular_spatial.yaml @@ -0,0 +1,33 @@ +# @package _global_ +# configs/experiment/yield_africa_tabular_spatial.yaml +# Tabular-only model evaluated with a spatial-cluster split. +# Default cluster distance: 25 km (split_spatial_25km.pth). +# +# Generate split files first: +# python src/data_preprocessing/yield_africa_spatial_splits.py --data_dir +# +# To evaluate at a different cluster distance: +# python src/train.py experiment=yield_africa_tabular_spatial \ +# data.saved_split_file_name=split_spatial_10km.pth + +defaults: + - override /model: yield_tabular_reg + - override /data: yield_africa_spatial + - override /metrics: yield_africa_regression + +tags: ["yield_africa", "tabular_only", "regression", "spatial"] +seed: 12345 + +trainer: + min_epochs: 1 + max_epochs: 150 + +data: + batch_size: 64 + +logger: + wandb: + tags: ${tags} + group: "yield_africa" + aim: + experiment: "yield_africa" diff --git a/configs/experiment/yield_africa_tessera_fusion_loco.yaml b/configs/experiment/yield_africa_tessera_fusion_loco.yaml new file mode 100644 index 0000000..ee9aa9d --- /dev/null +++ b/configs/experiment/yield_africa_tessera_fusion_loco.yaml @@ -0,0 +1,38 @@ +# @package _global_ +# configs/experiment/yield_africa_tessera_fusion_loco.yaml +# TESSERA + tabular fusion model evaluated with leave-one-country-out split. +# Default held-out country: KEN (largest, most representative test set). +# +# Requires: +# 1. TESSERA tiles pre-fetched: +# python src/data_preprocessing/yield_africa_tessera_preprocess.py --data_dir +# 2. LOCO split files pre-generated: +# python src/data_preprocessing/yield_africa_loco_splits.py --data_dir +# +# To evaluate on a different held-out country: +# python src/train.py experiment=yield_africa_tessera_fusion_loco \ +# data.saved_split_file_name=split_loco_RWA.pth + +defaults: + - override /model: yield_tessera_fusion_reg + - override /data: yield_africa_tessera_loco + - override /metrics: yield_africa_regression + +tags: ["yield_africa", "tessera_fusion", "regression", "loco"] +seed: 12345 + +trainer: + min_epochs: 1 + max_epochs: 150 + +data: + batch_size: 64 + dataset: + use_features: true + +logger: + wandb: + tags: ${tags} + group: "yield_africa" + aim: + experiment: "yield_africa" diff --git a/configs/experiment/yield_africa_tessera_fusion_spatial.yaml b/configs/experiment/yield_africa_tessera_fusion_spatial.yaml new file mode 100644 index 0000000..b0eaf9d --- /dev/null +++ b/configs/experiment/yield_africa_tessera_fusion_spatial.yaml @@ -0,0 +1,38 @@ +# @package _global_ +# configs/experiment/yield_africa_tessera_fusion_spatial.yaml +# TESSERA + tabular fusion model evaluated with a spatial-cluster split. +# Default cluster distance: 25 km (split_spatial_25km.pth). +# +# Requires: +# 1. TESSERA tiles pre-fetched: +# python src/data_preprocessing/yield_africa_tessera_preprocess.py --data_dir +# 2. Spatial split files pre-generated: +# python src/data_preprocessing/yield_africa_spatial_splits.py --data_dir +# +# To evaluate at a different cluster distance: +# python src/train.py experiment=yield_africa_tessera_fusion_spatial \ +# data.saved_split_file_name=split_spatial_10km.pth + +defaults: + - override /model: yield_tessera_fusion_reg + - override /data: yield_africa_tessera_spatial + - override /metrics: yield_africa_regression + +tags: ["yield_africa", "tessera_fusion", "regression", "spatial"] +seed: 12345 + +trainer: + min_epochs: 1 + max_epochs: 150 + +data: + batch_size: 64 + dataset: + use_features: true + +logger: + wandb: + tags: ${tags} + group: "yield_africa" + aim: + experiment: "yield_africa" From f3186893d7c58a8c35f6afc93463027e4c888026 Mon Sep 17 00:00:00 2001 From: Rob Knapen Date: Wed, 11 Mar 2026 11:31:04 +0100 Subject: [PATCH 26/78] Adds RRMSE loss function for crop yield error comparison --- configs/metrics/yield_africa_regression.yaml | 3 +- src/models/components/loss_fns/rrmse_loss.py | 44 ++++++++++++++++++++ 2 files changed, 46 insertions(+), 1 deletion(-) create mode 100644 src/models/components/loss_fns/rrmse_loss.py diff --git a/configs/metrics/yield_africa_regression.yaml b/configs/metrics/yield_africa_regression.yaml index 79c441d..7960283 100644 --- a/configs/metrics/yield_africa_regression.yaml +++ b/configs/metrics/yield_africa_regression.yaml @@ -1,7 +1,8 @@ _target_: src.models.components.metrics.metrics_wrapper.MetricsWrapper metrics: - - _target_: src.models.components.loss_fns.mse_loss.MSELoss + - _target_: src.models.components.loss_fns.huber_loss.HuberLoss - _target_: src.models.components.loss_fns.rmse_loss.RMSELoss - _target_: src.models.components.loss_fns.mae_loss.MAELoss + - _target_: src.models.components.loss_fns.rrmse_loss.RRMSELoss - _target_: src.models.components.metrics.r2.RSquared diff --git a/src/models/components/loss_fns/rrmse_loss.py b/src/models/components/loss_fns/rrmse_loss.py new file mode 100644 index 0000000..1720f1b --- /dev/null +++ b/src/models/components/loss_fns/rrmse_loss.py @@ -0,0 +1,44 @@ +from typing import Dict, override + +import torch + +from src.models.components.loss_fns.base_loss_fn import BaseLossFn + + +class RRMSELoss(BaseLossFn): + """Relative Root Mean Squared Error (RRMSE). + + RRMSE = RMSE / mean(|labels|) + + Normalises RMSE by the mean absolute value of the target, giving a + unit-free percentage error. This makes results comparable across crops + and regions with different absolute yield scales (e.g. t/ha ranges + differ significantly between maize in Zambia and rice in Rwanda). + + Returns a fraction (e.g. 0.15 = 15 % error). Multiply by 100 for + percentage when reporting. + """ + + def __init__(self) -> None: + super().__init__() + self.criterion = torch.nn.MSELoss() + self.name = "rrmse_loss" + + @override + def forward( + self, + pred: torch.Tensor, + labels: torch.Tensor | None = None, + batch: Dict[str, torch.Tensor] | None = None, + **kwargs, + ) -> torch.Tensor | Dict[str, torch.Tensor]: + + labels = labels if labels is not None else batch.get("target") + rmse = torch.sqrt(self.criterion(pred, labels)) + mean_abs = torch.mean(torch.abs(labels)) + loss = rmse / (mean_abs + 1e-8) + + if "return_label" in kwargs: + return {self.name: loss} + else: + return loss From da121c4dda4edff69b7e02c617999697bcb0d76b Mon Sep 17 00:00:00 2001 From: GabrieleTi Date: Tue, 7 Apr 2026 13:33:27 +0200 Subject: [PATCH 27/78] Crop Yield use case: Adds Fourier harmonics as engineered location features. # Conflicts: # tests/test_yield_africa.py --- src/data/yield_africa_dataset.py | 52 +++++++++++++++++++++++++++++--- 1 file changed, 47 insertions(+), 5 deletions(-) diff --git a/src/data/yield_africa_dataset.py b/src/data/yield_africa_dataset.py index 8c09801..5456bf2 100644 --- a/src/data/yield_africa_dataset.py +++ b/src/data/yield_africa_dataset.py @@ -12,6 +12,7 @@ import os from typing import Any, Dict, List, override +import numpy as np import pandas as pd import torch @@ -27,6 +28,16 @@ # countries are present after filtering. _ALL_COUNTRIES = ["BF", "BUR", "ETH", "KEN", "MAL", "RWA", "TAN", "ZAM"] +# Study-area bounds used to normalise coordinates before computing Fourier +# harmonics. Normalising to the actual data extent (rather than ±90°/±180°) +# makes the harmonics maximally discriminative within the dataset. +# Latitude : 30°S – 15°N → centre −7.5°, half-range 22.5° +# Longitude : 10°E – 45°E → centre 27.5°, half-range 17.5° +_LAT_CENTER = -7.5 +_LAT_HALF_RANGE = 22.5 +_LON_CENTER = 27.5 +_LON_HALF_RANGE = 17.5 + class YieldAfricaDataset(BaseDataset): """Dataset for the crop yield regression use case (East/Southern Africa). @@ -45,11 +56,20 @@ class YieldAfricaDataset(BaseDataset): the model-ready CSV and are picked up via the `feat_` column prefix. They do NOT need to be listed in `modalities`. - In addition to the CSV feat_* columns, `year` and one-hot `country` - encodings are injected as `feat_year` and `feat_country_{CODE}` so that - the model can condition on inter-annual and cross-country variation. - The one-hot set always covers `_ALL_COUNTRIES` (8 countries) so that - `tabular_dim` is stable regardless of the country filter applied. + In addition to the CSV feat_* columns, the following features are injected: + - ``feat_year`` : normalised year (zero-mean, unit-std) + - ``feat_country_{CODE}`` : one-hot country encoding (always 8 columns, + stable across country filters) + - ``feat_lat_sin1/cos1`` : fundamental latitude harmonic, normalised to + the study-area extent (30°S–15°N) + - ``feat_lat_sin2/cos2`` : second latitude harmonic (captures bimodal vs. + unimodal rainfall boundary near the equator) + - ``feat_lon_sin1/cos1`` : fundamental longitude harmonic, normalised to + the study-area extent (10°E–45°E) + + The Fourier harmonics encode the ITCZ-driven latitudinal climate gradient at + interpretable frequencies, complementing GeoCLIP's photo-derived coordinate + embedding and enabling richer text captions for the explainability component. """ def __init__( @@ -93,6 +113,28 @@ def __init__( } for code in _ALL_COUNTRIES: new_cols[f"feat_country_{code}"] = (self.df["country"] == code).astype(float) + + # Fourier harmonics of coordinates, normalised to the study-area extent. + # + # Africa's agricultural patterns follow the ITCZ-driven latitudinal climate + # gradient: rainfall regime (uni- vs. bimodal), growing-season length, and + # temperature vary sinusoidally with latitude. Explicit harmonics give the + # model these signals directly and at interpretable frequencies, complementing + # GeoCLIP's learned (but photo-derived) coordinate embedding. + # + # lat_norm / lon_norm ∈ [-1, 1] within the study area; π * norm ∈ [-π, π]. + # Two harmonics for latitude (captures both the broad N-S gradient and the + # equatorial-bimodal / southern-unimodal boundary); one for longitude + # (east-west Indian Ocean moisture gradient). + lat_norm = (self.df["lat"].astype(float) - _LAT_CENTER) / _LAT_HALF_RANGE + lon_norm = (self.df["lon"].astype(float) - _LON_CENTER) / _LON_HALF_RANGE + new_cols["feat_lat_sin1"] = np.sin(np.pi * lat_norm) + new_cols["feat_lat_cos1"] = np.cos(np.pi * lat_norm) + new_cols["feat_lat_sin2"] = np.sin(2.0 * np.pi * lat_norm) + new_cols["feat_lat_cos2"] = np.cos(2.0 * np.pi * lat_norm) + new_cols["feat_lon_sin1"] = np.sin(np.pi * lon_norm) + new_cols["feat_lon_cos1"] = np.cos(np.pi * lon_norm) + self.df = pd.concat([self.df, pd.DataFrame(new_cols, index=self.df.index)], axis=1) # Apply country/year filters to self.df and rebuild records. From 9a62cebc7ec65b94f0a865921cd895cab62f9cea Mon Sep 17 00:00:00 2001 From: GabrieleTi Date: Thu, 12 Mar 2026 10:51:49 +0100 Subject: [PATCH 28/78] Fix encoder tests --- .../components/pred_heads/base_pred_head.py | 4 ++-- src/models/text_alignment_model.py | 1 - tests/test_geo_encoders.py | 16 ++++++++++++---- 3 files changed, 14 insertions(+), 7 deletions(-) diff --git a/src/models/components/pred_heads/base_pred_head.py b/src/models/components/pred_heads/base_pred_head.py index 3615fa5..5ed63c4 100644 --- a/src/models/components/pred_heads/base_pred_head.py +++ b/src/models/components/pred_heads/base_pred_head.py @@ -25,10 +25,10 @@ def set_dim(self, input_dim: int, output_dim: int) -> None: :param input_dim: input dimension :param output_dim: output dimension """ - assert isinstance(self.input_dim, int), TypeError( + assert isinstance(input_dim, int), TypeError( "Input dimension must be specified as integer" ) - assert isinstance(self.output_dim, int), TypeError( + assert isinstance(output_dim, int), TypeError( "Output dimension must be specified as integer" ) self.input_dim = input_dim diff --git a/src/models/text_alignment_model.py b/src/models/text_alignment_model.py index 2630ff3..9cffc5e 100644 --- a/src/models/text_alignment_model.py +++ b/src/models/text_alignment_model.py @@ -6,7 +6,6 @@ from src.models.base_model import BaseModel from src.models.components.geo_encoders.base_geo_encoder import BaseGeoEncoder -from src.models.components.geo_encoders.multimodal_encoder import MultiModalEncoder from src.models.components.loss_fns.base_loss_fn import BaseLossFn from src.models.components.metrics.contrastive_validation import ( RetrievalContrastiveValidation, diff --git a/tests/test_geo_encoders.py b/tests/test_geo_encoders.py index 9a54732..57ff921 100644 --- a/tests/test_geo_encoders.py +++ b/tests/test_geo_encoders.py @@ -6,10 +6,10 @@ import torch from src.models.components.geo_encoders.average_encoder import AverageEncoder -from src.models.components.geo_encoders.base_geo_encoder import BaseGeoEncoder from src.models.components.geo_encoders.cnn_encoder import CNNEncoder from src.models.components.geo_encoders.geoclip import GeoClipCoordinateEncoder -from src.models.components.geo_encoders.multimodal_encoder import MultiModalEncoder +from src.models.components.geo_encoders.mlp_projector import MLPProjector +from src.models.components.geo_encoders.tabular_encoder import TabularEncoder # @pytest.mark.slow @@ -19,13 +19,21 @@ def test_geo_encoder_generic_properties(create_butterfly_dataset): "geoclip_coords": GeoClipCoordinateEncoder, "cnn": CNNEncoder, "average": AverageEncoder, - "multimodal_coords": MultiModalEncoder, + "tabular": TabularEncoder, + "mlp_projector": MLPProjector, } ds, dm = create_butterfly_dataset batch = next(iter(dm.train_dataloader())) for geo_encoder_name, geo_encoder_class in dict_geo_encoders.items(): - geo_encoder = geo_encoder_class() + if geo_encoder_class is MLPProjector: + geo_encoder = geo_encoder_class(output_dim=64, input_dim=128) + elif geo_encoder_class is TabularEncoder: + geo_encoder = geo_encoder_class(output_dim=64, input_dim=128, hidden_dim=128) + else: + geo_encoder = geo_encoder_class() + + geo_encoder.setup() assert hasattr( geo_encoder, "geo_encoder" From ec9338bbec17e803c6db9eb6b0f163a3adbcd4f3 Mon Sep 17 00:00:00 2001 From: GabrieleTi Date: Thu, 12 Mar 2026 12:02:45 +0100 Subject: [PATCH 29/78] fix tests --- .../components/metrics/contrastive_similarities.py | 4 ++-- tests/test_pred_heads.py | 12 +++++++----- 2 files changed, 9 insertions(+), 7 deletions(-) diff --git a/src/models/components/metrics/contrastive_similarities.py b/src/models/components/metrics/contrastive_similarities.py index c3aae7e..4a72635 100644 --- a/src/models/components/metrics/contrastive_similarities.py +++ b/src/models/components/metrics/contrastive_similarities.py @@ -15,7 +15,7 @@ def __init__(self, k_list=None) -> None: def forward( self, mode: str, - eo_feats: torch.Tensor, + geo_feats: torch.Tensor, text_feats: torch.Tensor, local_batch_size: int, **kwargs, @@ -23,7 +23,7 @@ def forward( """Calculate cosine similarity between eo and text embeddings and logs it.""" # Similarity matrix - cos_sim_matrix = F.cosine_similarity(eo_feats[:, None, :], text_feats[None, :, :], dim=-1) + cos_sim_matrix = F.cosine_similarity(geo_feats[:, None, :], text_feats[None, :, :], dim=-1) # Average for positive and negative pairs # TODO change label option if we change what gets treated to be pos/neg diff --git a/tests/test_pred_heads.py b/tests/test_pred_heads.py index 15553ea..abc99b4 100644 --- a/tests/test_pred_heads.py +++ b/tests/test_pred_heads.py @@ -19,11 +19,13 @@ def test_pred_head_generic_properties(create_butterfly_dataset): ds, dm = create_butterfly_dataset batch = next(iter(dm.train_dataloader())) eo_encoder = GeoClipCoordinateEncoder() + eo_encoder.setup() feats = eo_encoder.forward(batch) list_pred_heads = [LinearPredictionHead, MLPPredictionHead, MLPRegressionPredictionHead] for pred_head_class in list_pred_heads: - pred_head = pred_head_class() + pred_head = pred_head_class(input_dim=64, output_dim=64) + pred_head.setup() assert hasattr( pred_head, "set_dim" ), f"'set_dim' method missing in {pred_head_class.__name__}." @@ -38,11 +40,11 @@ def test_pred_head_generic_properties(create_butterfly_dataset): pred_head, "output_dim" ), f"'output_dim' attribute missing in {pred_head_class.__name__}." assert hasattr( - pred_head, "configure_nn" - ), f"'configure_nn' method missing in {pred_head_class.__name__}." + pred_head, "setup" + ), f"'setup' method missing in {pred_head_class.__name__}." assert callable( - getattr(pred_head, "configure_nn") - ), f"'configure_nn' is not callable in {pred_head_class.__name__}." + getattr(pred_head, "setup") + ), f"'setup' is not callable in {pred_head_class.__name__}." pred_head.setup() assert hasattr(pred_head, "net"), f"'net' attribute missing in {pred_head_class.__name__}." assert hasattr( From 444fa0c1ba2bc5ce2647d86225834ff7cd0f7f89 Mon Sep 17 00:00:00 2001 From: GabrieleTi Date: Thu, 12 Mar 2026 12:17:49 +0100 Subject: [PATCH 30/78] fix tests --- src/models/base_model.py | 8 ++++---- src/models/predictive_model.py | 4 ++++ 2 files changed, 8 insertions(+), 4 deletions(-) diff --git a/src/models/base_model.py b/src/models/base_model.py index a28c085..89fbaf6 100644 --- a/src/models/base_model.py +++ b/src/models/base_model.py @@ -45,7 +45,7 @@ def setup(self, stage: str) -> None: @final def freezer(self) -> None: """Freezes modules based on provided trainable modules.""" - self.trainable_modules = tuple(self.trainable_modules) or tuple() + trainable_modules = tuple(self.trainable_modules) or tuple() # Store higher level module names for printing of trainable parts trainable = set() @@ -53,7 +53,7 @@ def freezer(self) -> None: # Freeze modules for name, param in self.named_parameters(): # Enable exceptions - if name.startswith(self.trainable_modules): + if name.startswith(trainable_modules): param.requires_grad = True trainable.add(name) else: @@ -69,8 +69,8 @@ def freezer(self) -> None: # - it is the root module (""), which must be train when any child is. def _in_train_scope(name: str) -> bool: if not name: # root module - return bool(self.trainable_modules) - for t in self.trainable_modules: + return bool(trainable_modules) + for t in trainable_modules: if name == t or name.startswith(t + ".") or t.startswith(name + "."): return True return False diff --git a/src/models/predictive_model.py b/src/models/predictive_model.py index 16bae50..b46a3f4 100644 --- a/src/models/predictive_model.py +++ b/src/models/predictive_model.py @@ -58,6 +58,10 @@ def setup(self, stage: str) -> None: self.num_classes = self.trainer.datamodule.num_classes self.tabular_dim = self.trainer.datamodule.tabular_dim + if stage != "fit": + if isinstance(self.trainable_modules, tuple): + self.trainable_modules = list(self.trainable_modules) + self.setup_encoders_adapters() # Freezing requested parts From d3af3e42c64699ef9acd44d1bc883ebd95c6dccd Mon Sep 17 00:00:00 2001 From: GabrieleTi Date: Thu, 12 Mar 2026 14:52:30 +0100 Subject: [PATCH 31/78] Fix depth of summary report for modules --- configs/callbacks/default.yaml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/configs/callbacks/default.yaml b/configs/callbacks/default.yaml index 149a92f..e161020 100644 --- a/configs/callbacks/default.yaml +++ b/configs/callbacks/default.yaml @@ -19,4 +19,4 @@ early_stopping: mode: "min" model_summary: - max_depth: 2 + max_depth: 1 From 1852834d7fd769eb7b925ea11240939b27d82f19 Mon Sep 17 00:00:00 2001 From: GabrieleTi Date: Sun, 15 Mar 2026 13:22:22 +0100 Subject: [PATCH 32/78] fix value 0 being ignored --- src/models/components/metrics/contrastive_validation.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/models/components/metrics/contrastive_validation.py b/src/models/components/metrics/contrastive_validation.py index ce2385a..76d216a 100644 --- a/src/models/components/metrics/contrastive_validation.py +++ b/src/models/components/metrics/contrastive_validation.py @@ -39,7 +39,7 @@ def forward( k_threshold = configs.get("theta_k") aux_val = aux_vals[idx] - if k_threshold: + if k_threshold is not None: dynamic_k = ( sum(aux_val >= k_threshold).item() if is_max From e725767e8979a56f3d56973ec1d1d9204c154818 Mon Sep 17 00:00:00 2001 From: GabrieleTi Date: Sun, 15 Mar 2026 13:24:19 +0100 Subject: [PATCH 33/78] Add model set up print statements --- src/models/components/geo_encoders/average_encoder.py | 2 ++ src/models/components/geo_encoders/cnn_encoder.py | 1 + src/models/components/geo_encoders/geoclip.py | 1 + src/models/components/geo_encoders/mlp_projector.py | 1 + src/models/components/geo_encoders/tabular_encoder.py | 1 + src/models/components/text_encoders/clip_text_encoder.py | 5 +++++ src/models/predictive_model.py | 2 ++ src/models/text_alignment_model.py | 2 ++ 8 files changed, 15 insertions(+) diff --git a/src/models/components/geo_encoders/average_encoder.py b/src/models/components/geo_encoders/average_encoder.py index 885b175..0c0eaf6 100644 --- a/src/models/components/geo_encoders/average_encoder.py +++ b/src/models/components/geo_encoders/average_encoder.py @@ -34,6 +34,8 @@ def setup(self) -> List[str]: """ self.output_dim = self.dict_n_bands_default[self.geo_data_name] self.geo_encoder = nn.Identity() + print(f"Model set up with average geo-encoder for {self.geo_data_name}") + return [] @override diff --git a/src/models/components/geo_encoders/cnn_encoder.py b/src/models/components/geo_encoders/cnn_encoder.py index d73a7b3..34f8f48 100644 --- a/src/models/components/geo_encoders/cnn_encoder.py +++ b/src/models/components/geo_encoders/cnn_encoder.py @@ -135,6 +135,7 @@ def get_backbone(self): def setup(self) -> List[str]: # TODO: could you make sure new layers are returned here to be added to trainable parts? # Maybe move the get_backbone method in here? + print(f"Model setup with cnn geo-encoder for {self.geo_data_name}") return [] @override diff --git a/src/models/components/geo_encoders/geoclip.py b/src/models/components/geo_encoders/geoclip.py index e530dd1..38dffdb 100644 --- a/src/models/components/geo_encoders/geoclip.py +++ b/src/models/components/geo_encoders/geoclip.py @@ -24,6 +24,7 @@ def __init__( def setup(self) -> List[str]: self.geo_encoder = LocationEncoder() self.output_dim = self.geo_encoder.LocEnc0.head[0].out_features + print("Model setup with GeoClip coordinate encoder") return [] @override diff --git a/src/models/components/geo_encoders/mlp_projector.py b/src/models/components/geo_encoders/mlp_projector.py index f4c1137..e622216 100644 --- a/src/models/components/geo_encoders/mlp_projector.py +++ b/src/models/components/geo_encoders/mlp_projector.py @@ -27,6 +27,7 @@ def __init__( @override def setup(self) -> List[str]: self.configure_nn() + print("Model setup with MLP projector") return ["net"] def set_input_dim(self, input_dim: int) -> None: diff --git a/src/models/components/geo_encoders/tabular_encoder.py b/src/models/components/geo_encoders/tabular_encoder.py index 09af55c..1ae4b1d 100644 --- a/src/models/components/geo_encoders/tabular_encoder.py +++ b/src/models/components/geo_encoders/tabular_encoder.py @@ -35,6 +35,7 @@ def __init__( @override def setup(self, input_dim: int = None) -> list[str]: self.configure_nn(input_dim) + print("Model setup with Tabular geo-encoder") return ["tabular_encoder"] def set_tabular_input_dim(self, input_dim: int) -> None: diff --git a/src/models/components/text_encoders/clip_text_encoder.py b/src/models/components/text_encoders/clip_text_encoder.py index 7642865..075a390 100644 --- a/src/models/components/text_encoders/clip_text_encoder.py +++ b/src/models/components/text_encoders/clip_text_encoder.py @@ -24,8 +24,13 @@ def __init__(self, hf_cache_dir: str = "../.cache", output_normalization="l2") - self.projector = GeoCLIP().image_encoder.mlp + self.model.vision_model = None + self.model.visual_projection = None + self.output_dim = 512 + print("Model set up with CLIP text encoder") + @override def forward(self, batch: Dict[str, torch.Tensor], mode: str) -> torch.Tensor: # Get text inputs diff --git a/src/models/predictive_model.py b/src/models/predictive_model.py index b46a3f4..7b951c2 100644 --- a/src/models/predictive_model.py +++ b/src/models/predictive_model.py @@ -62,7 +62,9 @@ def setup(self, stage: str) -> None: if isinstance(self.trainable_modules, tuple): self.trainable_modules = list(self.trainable_modules) + print("-------Model------------") self.setup_encoders_adapters() + print("------------------------") # Freezing requested parts self.freezer() diff --git a/src/models/text_alignment_model.py b/src/models/text_alignment_model.py index 9cffc5e..49b8a91 100644 --- a/src/models/text_alignment_model.py +++ b/src/models/text_alignment_model.py @@ -69,7 +69,9 @@ def setup(self, stage: str) -> None: self.tabular_dim = self.trainer.datamodule.tabular_dim # Set up encoders and missing adapters/projectors + print("-------Model------------") self.setup_encoders_adapters() + print("------------------------") # Freeze requested parts self.freezer() From ec0ba03b685ffc47f8bdf9dc74cccdb93285ade0 Mon Sep 17 00:00:00 2001 From: Rob Knapen Date: Thu, 12 Mar 2026 15:44:45 +0100 Subject: [PATCH 34/78] Crop Yield use case: Reduced MLP projector, equal contribution of spatial and tabular encoders (for now). --- configs/model/yield_tessera_fusion_reg.yaml | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/configs/model/yield_tessera_fusion_reg.yaml b/configs/model/yield_tessera_fusion_reg.yaml index c30b6af..1e8c175 100644 --- a/configs/model/yield_tessera_fusion_reg.yaml +++ b/configs/model/yield_tessera_fusion_reg.yaml @@ -14,9 +14,8 @@ geo_encoder: geo_data_name: tessera projector: _target_: src.models.components.geo_encoders.mlp_projector.MLPProjector - nn_layers: 2 - hidden_dim: 512 - output_dim: 512 + nn_layers: 1 + output_dim: 256 - encoder: _target_: src.models.components.geo_encoders.tabular_encoder.TabularEncoder output_dim: 256 From 1d9953d8c07a87113f5428e061b473137a612be5 Mon Sep 17 00:00:00 2001 From: GabrieleTi Date: Sun, 15 Mar 2026 13:46:00 +0100 Subject: [PATCH 35/78] Guatemala UC tessera --- pyproject.toml | 1 + src/data/heat_guatemala_dataset.py | 10 +++++++++- src/data_preprocessing/tessera_embeds.py | 12 +++++++++++- 3 files changed, 21 insertions(+), 2 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 4c51d60..910e924 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -33,6 +33,7 @@ dependencies = [ "peft>=0.18.1", "llm2vec", "setuptools<81", + "geotessera>=0.7.3", ] [project.optional-dependencies] diff --git a/src/data/heat_guatemala_dataset.py b/src/data/heat_guatemala_dataset.py index c20b686..f4fabac 100644 --- a/src/data/heat_guatemala_dataset.py +++ b/src/data/heat_guatemala_dataset.py @@ -56,7 +56,7 @@ def __init__( dataset_name="heat_guatemala", seed=seed, cache_dir=cache_dir, - implemented_mod={"coords"}, + implemented_mod={"coords", "tessera"}, mock=mock, use_features=use_features, ) @@ -67,6 +67,14 @@ def __init__( def setup(self) -> None: """No files to download / prepare for this dataset.""" + # Set up each requested modality + for mod in self.modalities.keys(): + if mod == "coords" and len(self.modalities.keys()) == 1: + return + elif mod == "tessera": + self.setup_tessera() + # elif mod == "aef": + # self.setup_aef() return @override diff --git a/src/data_preprocessing/tessera_embeds.py b/src/data_preprocessing/tessera_embeds.py index fcc6c32..c2414b8 100644 --- a/src/data_preprocessing/tessera_embeds.py +++ b/src/data_preprocessing/tessera_embeds.py @@ -199,7 +199,7 @@ def tessera_from_df( # Tessera connection cache_dir = os.path.join(cache_dir, "tessera") - gt = GeoTessera(cache_dir=cache_dir) + gt = GeoTessera(cache_dir=cache_dir, embeddings_dir=cache_dir, dataset_version="v1") # Iter each coord n = len(model_ready_df) @@ -267,3 +267,13 @@ def inspect_np_arr_as_tiff( dst.write(arr_to_write[i], i + 1) print(f"Tiff version of np array saved to {file_path}") + + +if __name__ == "__main__": + os.chdir("../..") + + df = pd.read_csv("data/heat_guatemala/model_ready_heat_guatemala.csv") + + tessera_from_df( + df, "data/heat_guatemala/eo/tessera_2024", year=2024, tile_size=10, cache_dir="data/cache" + ) From 62d063d2cc7823d217ac83ae2d27e0ad5d2d9d1d Mon Sep 17 00:00:00 2001 From: GabrieleTi Date: Sun, 15 Mar 2026 13:48:23 +0100 Subject: [PATCH 36/78] Alignment training --- configs/paths/shared.yaml | 4 ++-- scripts/schedule.sh | 20 ++++++++++++-------- 2 files changed, 14 insertions(+), 10 deletions(-) diff --git a/configs/paths/shared.yaml b/configs/paths/shared.yaml index 76e6046..a4702b0 100644 --- a/configs/paths/shared.yaml +++ b/configs/paths/shared.yaml @@ -5,8 +5,8 @@ root_dir: ${oc.env:PROJECT_ROOT,./} # path to data directory -data_dir: ${oc.env:DATA_DIR,oc.env:SHARED_ROOT/data/,${paths.root_dir}/data/} -cache_dir: ${oc.env:CACHE_DIR,${paths.data_dir}/cache} +data_dir: ${oc.env:DATA_DIR,${oc.env:SHARED_ROOT,${paths.root_dir}}/data} +cache_dir: ${oc.env:SHARED_CACHE,${paths.data_dir}/cache} # path to logging directory log_dir: ${oc.env:SHARED_ROOT}/logs/ diff --git a/scripts/schedule.sh b/scripts/schedule.sh index 8054b63..8bab438 100644 --- a/scripts/schedule.sh +++ b/scripts/schedule.sh @@ -1,23 +1,27 @@ #!/bin/bash -#SBATCH--cpus-per-task=8 -#SBATCH--partition=gpu -#SBATCH--gpus=1 -#SBATCH--job-name=aether -#SBATCH--mem=100G -#SBATCH--time=100 +#SBATCH --cpus-per-task=8 +#SBATCH --partition=gpu +#SBATCH --gpus=1 +#SBATCH --job-name=aether +#SBATCH --mem=100G +#SBATCH --time=100 +#SBATCH --output=logs/out_%j.out +#SBATCH --error=logs/err_%j.err # Schedule execution of many runs # Run from root folder with: bash scripts/schedule.sh # Variables +# shellcheck disable=SC1091 source .env -# Environment +#Environment +# shellcheck disable=SC1091 source .venv/bin/activate # Runs #srun python src/train.py experiment=alignment trainer=$TRAINER_PROFILE logger=$LOGGER -#srun python src/train.py experiment=prediction logger=wandb +srun python -u src/train.py experiment=alignment_v1 # example runs with overwritten configs #srun python src/train.py experiment=alignment trainer=ddp_sim trainer.max_epochs=10 data.pin_memory=false From a0ec8172d54c58a33a9d4563ce71cff4a5bc37eb Mon Sep 17 00:00:00 2001 From: gabriele Date: Thu, 19 Mar 2026 13:45:08 +0100 Subject: [PATCH 37/78] De-duplicate geotessera requirements Remove from optional --- pyproject.toml | 3 --- 1 file changed, 3 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 910e924..f49df39 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -41,9 +41,6 @@ create-data = [ "geemap>=0.36.6", "pipreqs>=0.5.0", ] -geotessera = [ - "geotessera>=0.7.3", -] [tool.pytest.ini_options] addopts = [ From 12ce6a8addab548b5f31b8d6f9bcdb5283989ad6 Mon Sep 17 00:00:00 2001 From: gabriele Date: Thu, 19 Mar 2026 13:47:06 +0100 Subject: [PATCH 38/78] Create input and output dimensions as attributes --- src/models/components/pred_heads/linear_pred_head.py | 3 +-- src/models/components/pred_heads/mlp_pred_head.py | 3 +-- src/models/components/pred_heads/mlp_regression_head.py | 3 +-- 3 files changed, 3 insertions(+), 6 deletions(-) diff --git a/src/models/components/pred_heads/linear_pred_head.py b/src/models/components/pred_heads/linear_pred_head.py index 94338a7..61efc7d 100644 --- a/src/models/components/pred_heads/linear_pred_head.py +++ b/src/models/components/pred_heads/linear_pred_head.py @@ -18,8 +18,7 @@ def __init__( :param output_dim: the size of output dimension """ super().__init__() - if input_dim and output_dim: - self.set_dim(input_dim, output_dim) + self.set_dim(input_dim, output_dim) @override def forward(self, feats: torch.Tensor) -> torch.Tensor: diff --git a/src/models/components/pred_heads/mlp_pred_head.py b/src/models/components/pred_heads/mlp_pred_head.py index 144d602..282b3bf 100644 --- a/src/models/components/pred_heads/mlp_pred_head.py +++ b/src/models/components/pred_heads/mlp_pred_head.py @@ -24,8 +24,7 @@ def __init__( super().__init__() self.nn_layers = nn_layers self.hidden_dim = hidden_dim - if input_dim and output_dim: - self.set_dim(input_dim, output_dim) + self.set_dim(input_dim, output_dim) @override def forward(self, feats: torch.Tensor) -> torch.Tensor: diff --git a/src/models/components/pred_heads/mlp_regression_head.py b/src/models/components/pred_heads/mlp_regression_head.py index a179efe..6894bb4 100644 --- a/src/models/components/pred_heads/mlp_regression_head.py +++ b/src/models/components/pred_heads/mlp_regression_head.py @@ -39,8 +39,7 @@ def __init__( self.nn_layers = nn_layers self.hidden_dim = hidden_dim self.dropout = dropout - if input_dim and output_dim: - self.set_dim(input_dim, output_dim) + self.set_dim(input_dim, output_dim) @override def forward(self, feats: torch.Tensor) -> torch.Tensor: From f0e4a1fdaa140dcb601b9f70cd88d2693eeaddfc Mon Sep 17 00:00:00 2001 From: GabrieleTi Date: Mon, 23 Mar 2026 16:19:30 +0100 Subject: [PATCH 39/78] Fix broken tests --- src/models/components/pred_heads/linear_pred_head.py | 3 ++- src/models/components/pred_heads/mlp_pred_head.py | 4 +++- src/models/components/pred_heads/mlp_regression_head.py | 4 +++- 3 files changed, 8 insertions(+), 3 deletions(-) diff --git a/src/models/components/pred_heads/linear_pred_head.py b/src/models/components/pred_heads/linear_pred_head.py index 61efc7d..94338a7 100644 --- a/src/models/components/pred_heads/linear_pred_head.py +++ b/src/models/components/pred_heads/linear_pred_head.py @@ -18,7 +18,8 @@ def __init__( :param output_dim: the size of output dimension """ super().__init__() - self.set_dim(input_dim, output_dim) + if input_dim and output_dim: + self.set_dim(input_dim, output_dim) @override def forward(self, feats: torch.Tensor) -> torch.Tensor: diff --git a/src/models/components/pred_heads/mlp_pred_head.py b/src/models/components/pred_heads/mlp_pred_head.py index 282b3bf..6d4c124 100644 --- a/src/models/components/pred_heads/mlp_pred_head.py +++ b/src/models/components/pred_heads/mlp_pred_head.py @@ -24,7 +24,9 @@ def __init__( super().__init__() self.nn_layers = nn_layers self.hidden_dim = hidden_dim - self.set_dim(input_dim, output_dim) + + if input_dim and output_dim: + self.set_dim(input_dim, output_dim) @override def forward(self, feats: torch.Tensor) -> torch.Tensor: diff --git a/src/models/components/pred_heads/mlp_regression_head.py b/src/models/components/pred_heads/mlp_regression_head.py index 6894bb4..d9553ec 100644 --- a/src/models/components/pred_heads/mlp_regression_head.py +++ b/src/models/components/pred_heads/mlp_regression_head.py @@ -39,7 +39,9 @@ def __init__( self.nn_layers = nn_layers self.hidden_dim = hidden_dim self.dropout = dropout - self.set_dim(input_dim, output_dim) + + if input_dim and output_dim: + self.set_dim(input_dim, output_dim) @override def forward(self, feats: torch.Tensor) -> torch.Tensor: From 5ca33e6e456af83e317370e3440accfd13737251 Mon Sep 17 00:00:00 2001 From: GabrieleTi Date: Wed, 25 Mar 2026 21:43:13 +0100 Subject: [PATCH 40/78] Log dynamic k per concept --- src/models/text_alignment_model.py | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/src/models/text_alignment_model.py b/src/models/text_alignment_model.py index 49b8a91..ad94fd5 100644 --- a/src/models/text_alignment_model.py +++ b/src/models/text_alignment_model.py @@ -117,6 +117,10 @@ def setup_encoders_adapters(self): def setup_retrieval_evaluation(self): self.concept_configs = self.trainer.datamodule.concept_configs self.concepts = [c["concept_caption"] for c in self.concept_configs] + self.concept_names = [ + f"{c['col'].replace('aux_', '')}_{'max' if c['is_max'] else 'min'}" + for c in self.concept_configs + ] self.contrastive_val = RetrievalContrastiveValidation(self.ks, self.concept_configs) self.outputs_epoch_memory = [] @@ -214,6 +218,8 @@ def _on_epoch_end(self, mode: str): for i, result in concept_scores.items(): print(f'\nConcept "{self.concepts[i]}" average top-k accuracies in {mode} split:') for k, v in result.items(): + if k == "dynamic_k": + self.log(f"dyn_k_{self.self.concept_names[i]}", v, **self.log_kwargs) print(f"Top-{k}: {v:.1f}%") avr_scores[f"{mode}_avr_top-{k}"].append(v) From ae532cbeefbe509a518f85622c4566e61dba99c2 Mon Sep 17 00:00:00 2001 From: GabrieleTi Date: Wed, 25 Mar 2026 21:43:23 +0100 Subject: [PATCH 41/78] valuation configs --- configs/eval.yaml | 37 +++++++++++++++++++++++++++++++------ 1 file changed, 31 insertions(+), 6 deletions(-) diff --git a/configs/eval.yaml b/configs/eval.yaml index be31299..80436cd 100644 --- a/configs/eval.yaml +++ b/configs/eval.yaml @@ -1,18 +1,43 @@ # @package _global_ +# specify here default configuration +# order of defaults determines the order in which configs override each other defaults: - _self_ - - data: mnist # choose datamodule with `test_dataloader()` for evaluation - - model: mnist - - logger: null - - trainer: default - - paths: default + - data: butterfly_coords + - model: predictive_geoclip + - callbacks: default + - logger: ${oc.env:LOGGER,wandb} + - trainer: ${oc.env:TRAINER_PROFILE,default} + - paths: ${oc.env:STORAGE_MODE,local} - extras: default - hydra: default + - metrics: butterfly_predictive + # experiment configs allow for version control of specific hyperparameters + # e.g. best hyperparameters for given model and datamodule + - experiment: null + + # config for hyperparameter optimization + - hparams_search: null + + # optional local config for machine/user specific settings + # it's optional since it doesn't need to exist and is excluded from version control + - optional local: default + + # debugging config (enable through command line, e.g. `python train.py debug=default) + - debug: null + +# task name, determines output directory path task_name: "eval" +# tags to help you identify your experiments +# you can overwrite this in experiment configs +# overwrite from command line with `python train.py tags="[first_tag, second_tag]"` tags: ["dev"] +# seed for random number generators in pytorch, numpy and python.random +seed: 12345 + # passing checkpoint path is necessary for evaluation -ckpt_path: ??? +ckpt_path: aether/logs/train/runs/2026-03-15_13-15-08/checkpoints/epoch_098.ckpt From eff9ce531d158b8f9724a0bde4f648fcae223de6 Mon Sep 17 00:00:00 2001 From: GabrieleTi Date: Wed, 25 Mar 2026 21:44:25 +0100 Subject: [PATCH 42/78] alignment poster inspection code --- notebooks/08-GT-aligment-visualisation.ipynb | 330 +++++++++++++++++++ 1 file changed, 330 insertions(+) create mode 100644 notebooks/08-GT-aligment-visualisation.ipynb diff --git a/notebooks/08-GT-aligment-visualisation.ipynb b/notebooks/08-GT-aligment-visualisation.ipynb new file mode 100644 index 0000000..f8b4e9d --- /dev/null +++ b/notebooks/08-GT-aligment-visualisation.ipynb @@ -0,0 +1,330 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": null, + "id": "0", + "metadata": {}, + "outputs": [], + "source": [ + "import os\n", + "from typing import Any, Dict, List, Tuple\n", + "\n", + "import hydra\n", + "import torch\n", + "from lightning import LightningDataModule, LightningModule, Trainer\n", + "from lightning.pytorch.loggers import Logger\n", + "\n", + "if os.environ.get(\"TOKENIZERS_PARALLELISM\") is None:\n", + " os.environ[\"TOKENIZERS_PARALLELISM\"] = \"false\"\n", + "\n", + "from src.utils import (\n", + " instantiate_loggers,\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "1", + "metadata": {}, + "outputs": [], + "source": [ + "os.chdir(\"..\")\n", + "os.getcwd()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "2", + "metadata": {}, + "outputs": [], + "source": [ + "from hydra import compose, initialize\n", + "from hydra.core.hydra_config import HydraConfig\n", + "\n", + "with initialize(version_base=\"1.3\", config_path=\"../configs\"):\n", + " cfg = compose(\n", + " config_name=\"eval.yaml\",\n", + " return_hydra_config=True,\n", + " overrides=[\"experiment=alignment_v1\"], # 👈 loads configs/experiment/my_exp.yaml\n", + " )\n", + "\n", + "HydraConfig.instance().set_config(cfg)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "3", + "metadata": {}, + "outputs": [], + "source": [ + "assert cfg.ckpt_path" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "4", + "metadata": {}, + "outputs": [], + "source": [ + "datamodule: LightningDataModule = hydra.utils.instantiate(cfg.data)\n", + "\n", + "model: LightningModule = hydra.utils.instantiate(cfg.model)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "5", + "metadata": {}, + "outputs": [], + "source": [ + "model" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "6", + "metadata": {}, + "outputs": [], + "source": [ + "logger: List[Logger] = instantiate_loggers(cfg.get(\"logger\"))\n", + "\n", + "cfg[\"trainer\"][\"max_epochs\"] = 0\n", + "cfg[\"trainer\"][\"accelerator\"] = \"mps\"\n", + "trainer: Trainer = hydra.utils.instantiate(cfg.trainer, logger=logger)\n", + "\n", + "object_dict = {\n", + " \"cfg\": cfg,\n", + " \"datamodule\": datamodule,\n", + " \"model\": model,\n", + " \"logger\": logger,\n", + " \"trainer\": trainer,\n", + "}" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "7", + "metadata": {}, + "outputs": [], + "source": [ + "trainer.strategy.connect(model)\n", + "trainer._data_connector.attach_data(model, datamodule=datamodule)\n", + "\n", + "model.setup(stage=\"fit\")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "8", + "metadata": {}, + "outputs": [], + "source": [ + "path = \"logs/train/runs/2026-03-15_13-15-08/checkpoints/epoch_098.ckpt\"\n", + "ckpt = torch.load(path, map_location=\"cpu\", weights_only=False)\n", + "model.load_state_dict(ckpt[\"state_dict\"])" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "9", + "metadata": {}, + "outputs": [], + "source": [ + "data = datamodule.test_dataloader().dataset[2]\n", + "data" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "10", + "metadata": {}, + "outputs": [], + "source": [ + "import matplotlib.pyplot as plt\n", + "import numpy as np\n", + "\n", + "\n", + "def show_rgb(tensor, stretch=True):\n", + " \"\"\"\n", + " tensor: numpy array or torch tensor of shape (C, H, W)\n", + " bands: tuple of 3 indices (R, G, B)\n", + " stretch: whether to normalize values for display\n", + " \"\"\"\n", + " # Convert to numpy if it's a torch tensor\n", + " if \"torch\" in str(type(tensor)):\n", + " tensor = tensor.detach().cpu().numpy()\n", + "\n", + " # Select bands\n", + " rgb = tensor[:3, :, :]\n", + "\n", + " # Move to (H, W, 3)\n", + " rgb = np.transpose(rgb, (1, 2, 0))\n", + "\n", + " # Normalize for visualization\n", + " if stretch:\n", + " rgb = rgb.astype(np.float32)\n", + " for i in range(3):\n", + " band = rgb[:, :, i]\n", + " band_min, band_max = band.min(), band.max()\n", + " if band_max > band_min:\n", + " rgb[:, :, i] = (band - band_min) / (band_max - band_min)\n", + "\n", + " plt.imshow(rgb)\n", + " plt.axis(\"off\")\n", + " plt.show()\n", + "\n", + "\n", + "show_rgb(data[\"eo\"][\"aef\"])" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "11", + "metadata": {}, + "outputs": [], + "source": [ + "batch = data\n", + "batch[\"eo\"][\"aef\"] = batch[\"eo\"][\"aef\"].unsqueeze(0)\n", + "batch[\"eo\"][\"aef\"].shape" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "12", + "metadata": {}, + "outputs": [], + "source": [ + "geo_feats = model.geo_encoder(data).squeeze(1)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "13", + "metadata": {}, + "outputs": [], + "source": [ + "concepts = [\n", + " \"Densely populated area with many houses\",\n", + " \"Very sparsely populated area with few houses\",\n", + " \"Arable land with crops for agriculture\",\n", + " \"Pasture fields with grass for grazing animals\",\n", + " \"Forested area with many trees\",\n", + " \"Water bodies such as lakes, rivers and sea\",\n", + "]" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "14", + "metadata": {}, + "outputs": [], + "source": [ + "text_tokens = model.text_encoder.processor(\n", + " text=concepts,\n", + " return_tensors=\"pt\",\n", + " padding=True,\n", + " truncation=True,\n", + " max_length=77,\n", + ")\n", + "\n", + "device = model.text_encoder.device\n", + "text_tokens = {k: v.to(device) for k, v in text_tokens.items()}\n", + "\n", + "text_embeds = model.text_encoder.model.get_text_features(**text_tokens)\n", + "\n", + "# Project\n", + "if model.text_encoder.projector is not None:\n", + " text_embeds = model.text_encoder.projector(text_embeds)\n", + "\n", + "if model.text_encoder.extra_projector is not None:\n", + " text_embeds = model.text_encoder.extra_projector(text_embeds)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "15", + "metadata": {}, + "outputs": [], + "source": [ + "text_embeds.shape" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "16", + "metadata": {}, + "outputs": [], + "source": [ + "import torch\n", + "import torch.nn.functional as F\n", + "\n", + "\n", + "def cosine_scores(x, Y):\n", + " \"\"\"\n", + " x: tensor of shape (1, 512)\n", + " Y: tensor of shape (6, 512)\n", + " returns: tensor of shape (6,)\n", + " \"\"\"\n", + " # Normalize\n", + " x = F.normalize(x, dim=1) # (1, 512)\n", + " Y = F.normalize(Y, dim=1) # (6, 512)\n", + "\n", + " # Compute cosine similarity\n", + " scores = torch.matmul(Y, x.T) # (6, 1)\n", + "\n", + " return scores.squeeze(1)\n", + "\n", + "\n", + "cosine_scores(geo_feats, text_embeds.cpu())" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "17", + "metadata": {}, + "outputs": [], + "source": [ + "scores = {\n", + " k: round(v.detach().item(), 2)\n", + " for k, v in zip(concepts, cosine_scores(geo_feats, text_embeds.cpu()))\n", + "}\n", + "scores" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "18", + "metadata": {}, + "outputs": [], + "source": [] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3 (ipykernel)", + "language": "python", + "name": "python3" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} From 8b0b1853919b69d8dbfecbddf7b73df811abe0a5 Mon Sep 17 00:00:00 2001 From: GabrieleTi Date: Tue, 31 Mar 2026 13:31:40 +0200 Subject: [PATCH 43/78] fix self.self --- src/models/text_alignment_model.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/models/text_alignment_model.py b/src/models/text_alignment_model.py index ad94fd5..2cc1048 100644 --- a/src/models/text_alignment_model.py +++ b/src/models/text_alignment_model.py @@ -219,7 +219,7 @@ def _on_epoch_end(self, mode: str): print(f'\nConcept "{self.concepts[i]}" average top-k accuracies in {mode} split:') for k, v in result.items(): if k == "dynamic_k": - self.log(f"dyn_k_{self.self.concept_names[i]}", v, **self.log_kwargs) + self.log(f"dyn_k_{self.concept_names[i]}", v, **self.log_kwargs) print(f"Top-{k}: {v:.1f}%") avr_scores[f"{mode}_avr_top-{k}"].append(v) From 1948ae26bc4e0980ed14237ee00c1af287be2838 Mon Sep 17 00:00:00 2001 From: GabrieleTi Date: Tue, 31 Mar 2026 13:32:09 +0200 Subject: [PATCH 44/78] Remove redundant val_val in logging the metric names --- src/models/components/metrics/metrics_wrapper.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/models/components/metrics/metrics_wrapper.py b/src/models/components/metrics/metrics_wrapper.py index 33d86de..0d40212 100644 --- a/src/models/components/metrics/metrics_wrapper.py +++ b/src/models/components/metrics/metrics_wrapper.py @@ -19,6 +19,6 @@ def forward(self, mode="train", **kwargs) -> Dict[str, torch.float]: for metric in self.metrics: metric_results = metric(mode=mode, return_label=True, **kwargs) for k, v in metric_results.items(): - compiled_dict[f"{mode}_{k}"] = v + compiled_dict[k] = v return compiled_dict From c89bfd875c4b90e8f6cbc31597018f17298f8e6f Mon Sep 17 00:00:00 2001 From: GabrieleTi Date: Thu, 2 Apr 2026 13:21:58 +0200 Subject: [PATCH 45/78] Add missing pass through the extra projector --- src/models/components/geo_encoders/average_encoder.py | 2 ++ src/models/components/geo_encoders/cnn_encoder.py | 3 +++ src/models/components/geo_encoders/geoclip.py | 2 ++ src/models/components/geo_encoders/mlp_projector.py | 5 ++++- src/models/components/geo_encoders/tabular_encoder.py | 3 +++ 5 files changed, 14 insertions(+), 1 deletion(-) diff --git a/src/models/components/geo_encoders/average_encoder.py b/src/models/components/geo_encoders/average_encoder.py index 0c0eaf6..b11e3ce 100644 --- a/src/models/components/geo_encoders/average_encoder.py +++ b/src/models/components/geo_encoders/average_encoder.py @@ -43,4 +43,6 @@ def forward(self, batch: Dict[str, torch.Tensor]) -> torch.Tensor: """Data forward pass through the encoder.""" tile = batch.get("eo", {}).get(self.geo_data_name) feats = self.geo_encoder(tile.mean(dim=(-2, -1))) + if self.extra_projector: + feats = self.extra_projector(feats) return feats diff --git a/src/models/components/geo_encoders/cnn_encoder.py b/src/models/components/geo_encoders/cnn_encoder.py index 34f8f48..6ec3609 100644 --- a/src/models/components/geo_encoders/cnn_encoder.py +++ b/src/models/components/geo_encoders/cnn_encoder.py @@ -159,6 +159,9 @@ def forward( # n_nans == 0 # ), f"CNNEncoder output contains {n_nans}/{feats.numel()} NaNs PRIOR to normalization with data min {eo_data[self.geo_data_name].min()} and max {eo_data[self.geo_data_name].max()}." + if self.extra_projector: + feats = self.extra_projector(feats) + return feats.to(dtype) diff --git a/src/models/components/geo_encoders/geoclip.py b/src/models/components/geo_encoders/geoclip.py index 38dffdb..cd246ef 100644 --- a/src/models/components/geo_encoders/geoclip.py +++ b/src/models/components/geo_encoders/geoclip.py @@ -39,6 +39,8 @@ def forward( if coords.dtype != dtype: coords = coords.to(dtype) feats = self.geo_encoder(coords) + if self.extra_projector: + feats = self.extra_projector(feats) return feats.to(dtype) diff --git a/src/models/components/geo_encoders/mlp_projector.py b/src/models/components/geo_encoders/mlp_projector.py index e622216..6c27345 100644 --- a/src/models/components/geo_encoders/mlp_projector.py +++ b/src/models/components/geo_encoders/mlp_projector.py @@ -49,4 +49,7 @@ def configure_nn(self) -> None: self.net = nn.Sequential(*layers) def forward(self, x: torch.Tensor) -> torch.Tensor: - return self.net(x) + feats = self.net(x) + if self.extra_projector: + feats = self.extra_projector(feats) + return feats diff --git a/src/models/components/geo_encoders/tabular_encoder.py b/src/models/components/geo_encoders/tabular_encoder.py index 1ae4b1d..314cf29 100644 --- a/src/models/components/geo_encoders/tabular_encoder.py +++ b/src/models/components/geo_encoders/tabular_encoder.py @@ -68,4 +68,7 @@ def forward(self, batch: Dict[str, torch.Tensor]) -> torch.Tensor: tab_data = tab_data.to(dtype) feats = self.geo_encoder(tab_data) + if self.extra_projector: + feats = self.extra_projector(feats) + return feats.to(dtype) From f6f6c342e29b47bb381540f055587b3ecfd105dc Mon Sep 17 00:00:00 2001 From: GabrieleTi Date: Thu, 2 Apr 2026 13:23:09 +0200 Subject: [PATCH 46/78] Add new type of error for missing file specification --- src/utils/errors.py | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/src/utils/errors.py b/src/utils/errors.py index 910dea5..aa88ce0 100644 --- a/src/utils/errors.py +++ b/src/utils/errors.py @@ -1,2 +1,10 @@ class IllegalArgumentCombination(ValueError): + """Error for illogical argument configuration.""" + + pass + + +class FileNotSpecified(ValueError): + """Error for missing file path specification.""" + pass From be81881ec534e3f55781a74f1652b1e00f38e8d3 Mon Sep 17 00:00:00 2001 From: GabrieleTi Date: Thu, 2 Apr 2026 13:23:40 +0200 Subject: [PATCH 47/78] Add missing pass through the extra projector --- src/models/components/geo_encoders/encoder_wrapper.py | 10 ++++++++-- 1 file changed, 8 insertions(+), 2 deletions(-) diff --git a/src/models/components/geo_encoders/encoder_wrapper.py b/src/models/components/geo_encoders/encoder_wrapper.py index 9cdf8ad..0523619 100644 --- a/src/models/components/geo_encoders/encoder_wrapper.py +++ b/src/models/components/geo_encoders/encoder_wrapper.py @@ -122,8 +122,14 @@ def forward(self, batch: Dict[str, torch.Tensor]) -> torch.Tensor: branch_feats.append(feats) if self.fusion_strategy == "concat": - return torch.cat(branch_feats, dim=1) - return torch.mean(branch_feats, dim=1) + feats = torch.cat(branch_feats, dim=1) + if self.extra_projector: + feats = self.extra_projector(feats) + else: + feats = torch.cat(branch_feats, dim=1) + if self.extra_projector: + feats = self.extra_projector(feats) + return feats @property def device(self): From b1595f9d59a677c3a0b9c2b55e3243a5ed08c189 Mon Sep 17 00:00:00 2001 From: GabrieleTi Date: Thu, 2 Apr 2026 15:25:42 +0200 Subject: [PATCH 48/78] fix if self.extra_projector missing --- src/models/components/geo_encoders/base_geo_encoder.py | 1 + 1 file changed, 1 insertion(+) diff --git a/src/models/components/geo_encoders/base_geo_encoder.py b/src/models/components/geo_encoders/base_geo_encoder.py index f79f3f4..6b65c75 100644 --- a/src/models/components/geo_encoders/base_geo_encoder.py +++ b/src/models/components/geo_encoders/base_geo_encoder.py @@ -14,6 +14,7 @@ def __init__(self) -> None: # placeholders self.allowed_geo_data_names: list[str] | None = None self.geo_data_name: str | None = None + self.extra_projector: nn.Module | None = None @abstractmethod def forward(self, batch: Dict[str, torch.Tensor]) -> torch.Tensor: From e7469cd84e5d3495ad7f7a49a67e447094f5fad3 Mon Sep 17 00:00:00 2001 From: GabrieleTi Date: Tue, 7 Apr 2026 13:09:43 +0200 Subject: [PATCH 49/78] Formating hook changes --- .../yield_africa_spatial_splits.py | 13 +++++-------- src/models/components/loss_fns/rrmse_loss.py | 12 +++++------- 2 files changed, 10 insertions(+), 15 deletions(-) diff --git a/src/data_preprocessing/yield_africa_spatial_splits.py b/src/data_preprocessing/yield_africa_spatial_splits.py index 4c9aa98..d71b15b 100644 --- a/src/data_preprocessing/yield_africa_spatial_splits.py +++ b/src/data_preprocessing/yield_africa_spatial_splits.py @@ -79,13 +79,12 @@ def make_spatial_split( """Return a split-indices dict using DBSCAN spatial clustering. :param df: full model-ready dataframe (must contain 'lat', 'lon', 'name_loc') - :param distance_m: DBSCAN eps in metres — pairs of fields closer than this - value are assigned to the same cluster + :param distance_m: DBSCAN eps in metres — pairs of fields closer than this value are assigned + to the same cluster :param train_val_test_split: (train, val, test) proportions, must sum to 1.0 :param seed: random seed for GroupShuffleSplit - :return: dict with 'train_indices', 'val_indices', 'test_indices' as - pd.Series of name_loc strings, plus 'clusters' as a numpy array of - cluster labels (same length as df) + :return: dict with 'train_indices', 'val_indices', 'test_indices' as pd.Series of name_loc + strings, plus 'clusters' as a numpy array of cluster labels (same length as df) """ # Deduplicate to unique (lat, lon) locations before clustering. # yield_africa has ~9 rows per location (one per year); running DBSCAN on all @@ -271,9 +270,7 @@ def generate_splits( f"(train={n_train}, val={n_val}, test={n_test}, " f"total={n_train + n_val + n_test}/{len(df)})" ) - log.info( - f" {dist_km}km: train={n_train}, val={n_val}, test={n_test} -> {out_name}" - ) + log.info(f" {dist_km}km: train={n_train}, val={n_val}, test={n_test} -> {out_name}") def main() -> None: diff --git a/src/models/components/loss_fns/rrmse_loss.py b/src/models/components/loss_fns/rrmse_loss.py index 1720f1b..8c62f7e 100644 --- a/src/models/components/loss_fns/rrmse_loss.py +++ b/src/models/components/loss_fns/rrmse_loss.py @@ -8,15 +8,13 @@ class RRMSELoss(BaseLossFn): """Relative Root Mean Squared Error (RRMSE). - RRMSE = RMSE / mean(|labels|) + RRMSE = RMSE / mean(labels) - Normalises RMSE by the mean absolute value of the target, giving a - unit-free percentage error. This makes results comparable across crops - and regions with different absolute yield scales (e.g. t/ha ranges - differ significantly between maize in Zambia and rice in Rwanda). + Normalises RMSE by the mean absolute value of the target, giving a unit-free percentage error. + This makes results comparable across crops and regions with different absolute yield scales + (e.g. t/ha ranges differ significantly between maize in Zambia and rice in Rwanda). - Returns a fraction (e.g. 0.15 = 15 % error). Multiply by 100 for - percentage when reporting. + Returns a fraction (e.g. 0.15 = 15 % error). Multiply by 100 for percentage when reporting. """ def __init__(self) -> None: From 6fdc4742cf05b53c457ed4339ceaa304ce3eb3a5 Mon Sep 17 00:00:00 2001 From: GabrieleTi Date: Tue, 7 Apr 2026 13:09:57 +0200 Subject: [PATCH 50/78] Fix hugging face dir --- configs/paths/shared.yaml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/configs/paths/shared.yaml b/configs/paths/shared.yaml index a4702b0..44218a8 100644 --- a/configs/paths/shared.yaml +++ b/configs/paths/shared.yaml @@ -21,4 +21,4 @@ work_dir: ${hydra:runtime.cwd} # huggingface cache directory # can be overridden via HF_HOME environment variable -huggingface_cache: ${oc.env:HF_HOME,oc.env:SHARED_CACHE/huggingface} +huggingface_cache: ${oc.env:HF_HOME,${oc.env:SHARED_CACHE}/huggingface} From 79800e92a63c9b206a37213c4794457b22325069 Mon Sep 17 00:00:00 2001 From: GabrieleTi Date: Tue, 7 Apr 2026 13:11:28 +0200 Subject: [PATCH 51/78] Prediction head initialisation print statements --- src/models/components/geo_encoders/mlp_projector.py | 2 +- src/models/components/pred_heads/linear_pred_head.py | 1 + src/models/components/pred_heads/mlp_pred_head.py | 1 + src/models/components/pred_heads/mlp_regression_head.py | 2 ++ 4 files changed, 5 insertions(+), 1 deletion(-) diff --git a/src/models/components/geo_encoders/mlp_projector.py b/src/models/components/geo_encoders/mlp_projector.py index 6c27345..4ab6e05 100644 --- a/src/models/components/geo_encoders/mlp_projector.py +++ b/src/models/components/geo_encoders/mlp_projector.py @@ -27,7 +27,7 @@ def __init__( @override def setup(self) -> List[str]: self.configure_nn() - print("Model setup with MLP projector") + print("Model set up with MLP projector") return ["net"] def set_input_dim(self, input_dim: int) -> None: diff --git a/src/models/components/pred_heads/linear_pred_head.py b/src/models/components/pred_heads/linear_pred_head.py index 94338a7..1dfa094 100644 --- a/src/models/components/pred_heads/linear_pred_head.py +++ b/src/models/components/pred_heads/linear_pred_head.py @@ -37,6 +37,7 @@ def setup(self) -> None: assert type(self.input_dim) is int, self.input_dim assert type(self.output_dim) is int, self.output_dim self.net = nn.Linear(self.input_dim, self.output_dim) + print("Model set up with linear prediction head") return diff --git a/src/models/components/pred_heads/mlp_pred_head.py b/src/models/components/pred_heads/mlp_pred_head.py index 6d4c124..65610a5 100644 --- a/src/models/components/pred_heads/mlp_pred_head.py +++ b/src/models/components/pred_heads/mlp_pred_head.py @@ -50,6 +50,7 @@ def setup(self) -> None: input_dim = self.hidden_dim layers.append(nn.Linear(input_dim, self.output_dim)) self.net = nn.Sequential(*layers) + print("Model set up with MLP prediction head") return diff --git a/src/models/components/pred_heads/mlp_regression_head.py b/src/models/components/pred_heads/mlp_regression_head.py index d9553ec..9679fa2 100644 --- a/src/models/components/pred_heads/mlp_regression_head.py +++ b/src/models/components/pred_heads/mlp_regression_head.py @@ -70,3 +70,5 @@ def setup(self) -> None: layers.append(nn.Linear(in_dim, self.output_dim)) self.net = nn.Sequential(*layers) + print("Model set up with MLP regression prediction head") + return From 2f56cc9cd19605c4433f4441c53b3519b26d759c Mon Sep 17 00:00:00 2001 From: GabrieleTi Date: Tue, 7 Apr 2026 13:11:45 +0200 Subject: [PATCH 52/78] Add full freezer --- src/models/base_model.py | 12 ++++++++++++ 1 file changed, 12 insertions(+) diff --git a/src/models/base_model.py b/src/models/base_model.py index 89fbaf6..7cb7296 100644 --- a/src/models/base_model.py +++ b/src/models/base_model.py @@ -42,6 +42,18 @@ def setup(self, stage: str) -> None: called after trainer is initialized and datamodule is available.""" pass + @final + def full_freezer(self): + """Freeze the whole network.""" + # Freeze the whole network + for name, param in self.named_parameters(): + param.requires_grad = False + + for name, module in self.named_modules(): + module.eval() + + return + @final def freezer(self) -> None: """Freezes modules based on provided trainable modules.""" From 047c7f66b0717ebbfd208e161f0282a9010cb753 Mon Sep 17 00:00:00 2001 From: GabrieleTi Date: Tue, 7 Apr 2026 13:12:53 +0200 Subject: [PATCH 53/78] Add configuration saving into ckpg This is needed for the inference model re-configuration --- src/models/base_model.py | 37 +++++++++++++++++++---- src/models/predictive_model.py | 26 ++++++++++++---- src/models/text_alignment_model.py | 48 +++++++++++++++--------------- 3 files changed, 75 insertions(+), 36 deletions(-) diff --git a/src/models/base_model.py b/src/models/base_model.py index 7cb7296..1a6a2b5 100644 --- a/src/models/base_model.py +++ b/src/models/base_model.py @@ -16,6 +16,8 @@ def __init__( scheduler: torch.optim.lr_scheduler, loss_fn: BaseLossFn, metrics: MetricsWrapper, + num_classes: int | None = None, + tabular_dim: int | None = None, ) -> None: """Interface for any model. @@ -24,17 +26,29 @@ def __init__( :param scheduler: scheduler for the model weight update :param loss_fn: loss function :param metrics: metrics to track for model performance estimation + :param num_classes: number of target classes + :param tabular_dim: number of tabular features """ super().__init__() self.save_hyperparameters( - ignore=["loss_fn", "geo_encoder", "prediction_head", "text_encoder", "metrics"] + ignore=[ + "loss_fn", + "geo_encoder", + "prediction_head", + "text_encoder", + "metrics", + "optimizer", + "scheduler", + ] ) self.trainable_modules = trainable_modules - self.num_classes: int | None = None - self.tabular_dim: int | None = None + self.num_classes = num_classes + self.tabular_dim = tabular_dim self.loss_fn = loss_fn self.metrics = metrics + self.optimizer = optimizer + self.scheduler = scheduler @abstractmethod def setup(self, stage: str) -> None: @@ -139,10 +153,10 @@ def test_step(self, batch: Dict[str, torch.Tensor], batch_idx: int) -> torch.Ten def configure_optimizers(self) -> Dict[str, Any]: """Configure optimizer and learning rate scheduler.""" - optimizer = self.hparams.optimizer(params=self.trainer.model.parameters()) + optimizer = self.optimizer(params=self.trainer.model.parameters()) - if self.hparams.scheduler is not None: - scheduler = self.hparams.scheduler(optimizer=optimizer) + if self.scheduler is not None: + scheduler = self.scheduler(optimizer=optimizer) return { "optimizer": optimizer, "lr_scheduler": { @@ -162,6 +176,14 @@ def on_save_checkpoint(self, checkpoint): if any(k.startswith(part) for part in self.trainable_modules) } + checkpoint["hyper_parameters"].update( + { + "num_classes": self.num_classes, + "tabular_dim": self.tabular_dim, + "trainable_modules": self.trainable_modules, + } + ) + def on_load_checkpoint(self, checkpoint): """Load only trainable parts of the model.""" missing_keys, unexpected_keys = self.load_state_dict( @@ -169,6 +191,9 @@ def on_load_checkpoint(self, checkpoint): ) print("Model loaded from a checkpoint.") + # self.tabular_dim = checkpoint['hyper_parameters']["tabular_dim"] + # self.num_classes = checkpoint["hyper_parameters"]["num_classes"] + if missing_keys: missing_keys = {".".join(i.split(".")[:3]) for i in missing_keys} print(f"The following keys are missing from the pretrained model: {missing_keys}") diff --git a/src/models/predictive_model.py b/src/models/predictive_model.py index 7b951c2..6e5cb90 100644 --- a/src/models/predictive_model.py +++ b/src/models/predictive_model.py @@ -25,6 +25,8 @@ def __init__( loss_fn: BaseLossFn, metrics: MetricsWrapper, normalize_features: bool = True, + num_classes: int | None = None, + tabular_dim: int | None = None, ) -> None: """Implementation of the predictive model with replaceable GEO encoder, and prediction head. @@ -36,13 +38,15 @@ def __init__( :param scheduler: scheduler to use for training :param loss_fn: loss function to use :param metrics: metrics to use for model performance evaluation - :param num_classes: number of target classes - :param tabular_dim: number of tabular features :param normalize_features: if True, apply L2 normalisation to encoder output before the prediction head (default: True) + :param num_classes: number of target classes + :param tabular_dim: number of tabular features """ - super().__init__(trainable_modules, optimizer, scheduler, loss_fn, metrics) + super().__init__( + trainable_modules, optimizer, scheduler, loss_fn, metrics, num_classes, tabular_dim + ) # Geo encoder configuration self.geo_encoder = geo_encoder @@ -55,8 +59,15 @@ def __init__( @override def setup(self, stage: str) -> None: - self.num_classes = self.trainer.datamodule.num_classes - self.tabular_dim = self.trainer.datamodule.tabular_dim + """Updates model based data-bound configurations (through datamodule), This method is + called after trainer is initialized and datamodule is available. + + Otherwise, some configuration variables must be made available + """ + + if self._trainer is not None: + self.num_classes = self.trainer.datamodule.num_classes + self.tabular_dim = self.trainer.datamodule.tabular_dim if stage != "fit": if isinstance(self.trainable_modules, tuple): @@ -67,7 +78,10 @@ def setup(self, stage: str) -> None: print("------------------------") # Freezing requested parts - self.freezer() + if stage in ["inference"]: + self.full_freezer() + else: + self.freezer() def setup_encoders_adapters(self): """Set up encoders and missing adapters/projectors.""" diff --git a/src/models/text_alignment_model.py b/src/models/text_alignment_model.py index 2cc1048..71737a7 100644 --- a/src/models/text_alignment_model.py +++ b/src/models/text_alignment_model.py @@ -1,4 +1,3 @@ -from io import text_encoding from typing import Dict, Tuple, override import torch @@ -32,6 +31,8 @@ def __init__( prediction_head: BasePredictionHead | None = None, ks: list[int] | None = [5, 10, 15], match_to_geo: bool = True, + num_classes: int | None = None, + tabular_dim: int | None = None, ) -> None: """Implementation of contrastive text-eo modality alignment model. @@ -42,14 +43,16 @@ def __init__( :param loss_fn: loss function to use (contrastive) :param trainable_modules: list of modules to train (parts/modules or modules, modules) :param metrics: metrics to use for model performance evaluation - :param num_classes: number of target classes - :param tabular_dim: number of tabular features :param prediction_head: prediction head :param ks: list of ks :param match_to_geo: whether to match dimensions of text encoder to geo_encoder or visa- versa + :param num_classes: number of target classes + :param tabular_dim: number of tabular features """ - super().__init__(trainable_modules, optimizer, scheduler, loss_fn, metrics) + super().__init__( + trainable_modules, optimizer, scheduler, loss_fn, metrics, num_classes, tabular_dim + ) # Metrics self.ks = ks @@ -64,31 +67,33 @@ def __init__( self.prediction_head = prediction_head @override - def setup(self, stage: str) -> None: - self.num_classes = self.trainer.datamodule.num_classes - self.tabular_dim = self.trainer.datamodule.tabular_dim + def setup(self, stage: str = "fit") -> None: + """Updates model based data-bound configurations (through datamodule), This method is + called after trainer is initialized and datamodule is available. + + Otherwise, some configuration variables must be made available + """ + + if self._trainer is not None: + self.num_classes = self.trainer.datamodule.num_classes + self.tabular_dim = self.trainer.datamodule.tabular_dim # Set up encoders and missing adapters/projectors print("-------Model------------") self.setup_encoders_adapters() print("------------------------") - # Freeze requested parts - self.freezer() + # Freeze not requested parts + if stage in ["inference"]: + self.full_freezer() + else: + self.freezer() - # Configure contrastive retrieval evaluation - self.setup_retrieval_evaluation() + # Configure contrastive retrieval evaluation + self.setup_retrieval_evaluation() def setup_encoders_adapters(self): """Set up encoders and missing adapters/projectors.""" - # We don't use tabular encoders for wrapping - # if ( - # isinstance(self.geo_encoder, MultiModalEncoder) - # and self.geo_encoder.use_tabular - # and not self.geo_encoder._tabular_ready - # ): - # self.geo_encoder.build_tabular_branch(self.tabular_dim) - # Setup encoders that need data-depended configurations new_modules = [f"geo_encoder.{i}" for i in self.geo_encoder.setup()] self.trainable_modules.extend(new_modules) @@ -109,11 +114,6 @@ def setup_encoders_adapters(self): ) self.prediction_head.setup() - # # Unify dtypes -> moving to data part, rather than changing parameter type - # if self.geo_encoder.dtype != self.text_encoder.dtype: - # self.geo_encoder = self.geo_encoder.to(self.text_encoder.dtype) - # print(f"Geo encoder dtype changed to {self.geo_encoder.dtype}") - def setup_retrieval_evaluation(self): self.concept_configs = self.trainer.datamodule.concept_configs self.concepts = [c["concept_caption"] for c in self.concept_configs] From 32ea0d943b54952fd454b8b0bb5fb1a2e27a46ef Mon Sep 17 00:00:00 2001 From: GabrieleTi Date: Tue, 7 Apr 2026 13:13:13 +0200 Subject: [PATCH 54/78] Introduce inference model --- configs/inference.yaml | 25 ++++ src/inference.py | 46 ++++++ src/models/inference_model.py | 262 ++++++++++++++++++++++++++++++++++ 3 files changed, 333 insertions(+) create mode 100644 configs/inference.yaml create mode 100644 src/inference.py create mode 100644 src/models/inference_model.py diff --git a/configs/inference.yaml b/configs/inference.yaml new file mode 100644 index 0000000..45f5bcc --- /dev/null +++ b/configs/inference.yaml @@ -0,0 +1,25 @@ +defaults: + - _self_ + - paths: ${oc.env:STORAGE_MODE,local} + - extras: default + - hydra: default + +# task name, determines output directory path +task_name: "inference" +tags: ["dev"] + +# If set, inference.py will load this merged checkpoint directly. +#inference_ckpt_path: ${paths.log_dir}${task_name}/2026-04-07_12-56-16/inference_model.ckpt + +# If `inference_ckpt_path` is not set, stitch the inference model from: +# - predictive ckpt (provides prediction_head weights) +# - alignment ckpt (provides text_encoder weights + geo_encoder) +predictive_ckpt_path: ${paths.log_dir}train/runs/2026-04-02_15-54-53/checkpoints/epoch_000.ckpt +alignment_ckpt_path: ${paths.log_dir}train/runs/2026-04-02_15-40-03/checkpoints/epoch_000.ckpt + +# If set, inference.py will save a merged inference checkpoint you can reload with +# `inference_ckpt_path`. +save_inference_ckpt_path: ${paths.log_dir}${task_name}/${now:%Y-%m-%d}_${now:%H-%M-%S}/inference_model.ckpt +#save_inference_ckpt_path: null + +training_order: ["alignment_model", "prediction_model"] diff --git a/src/inference.py b/src/inference.py new file mode 100644 index 0000000..dc15b50 --- /dev/null +++ b/src/inference.py @@ -0,0 +1,46 @@ +from typing import Optional + +import hydra +import rootutils +from dotenv import load_dotenv +from omegaconf import DictConfig + +from src.models.inference_model import load_inference_model, merge_inference_model +from src.utils import extras + +rootutils.setup_root(__file__, indicator=".project-root", pythonpath=True) +load_dotenv() + +# Disable tokenizers parallelism to avoid warnings when using multiprocessing +import os + +if os.environ.get("TOKENIZERS_PARALLELISM") is None: + os.environ["TOKENIZERS_PARALLELISM"] = "false" + + +@hydra.main(version_base="1.3", config_path="../configs", config_name="inference.yaml") +def main(cfg: DictConfig) -> Optional[float]: + """Main entry point for training. + + :param cfg: DictConfig configuration composed by Hydra. + :return: Optional[float] with optimized metric value. + """ + # apply extra utilities + # (e.g. ask for tags if none are provided in cfg, print cfg tree, etc.) + extras(cfg) + + # If a merged inference ckpt is provided, just load it. + inference_ckpt_path = cfg.get("inference_ckpt_path") + if inference_ckpt_path: + model = load_inference_model(inference_ckpt_path) + # Otherwise merge model from two checkpoints + else: + model = merge_inference_model(cfg, save_ckpt=True) + + # TODO: do what you need with the inference model + + return + + +if __name__ == "__main__": + main() diff --git a/src/models/inference_model.py b/src/models/inference_model.py new file mode 100644 index 0000000..93447be --- /dev/null +++ b/src/models/inference_model.py @@ -0,0 +1,262 @@ +import os +from typing import Dict, Tuple, override + +import hydra +import torch +import torch.nn.functional as F + +from src.models.base_model import BaseModel +from src.models.components.geo_encoders.base_geo_encoder import BaseGeoEncoder +from src.models.components.pred_heads.linear_pred_head import BasePredictionHead +from src.models.components.text_encoders.base_text_encoder import ( + BaseTextEncoder, +) +from src.utils import RankedLogger +from src.utils.errors import FileNotSpecified + +log = RankedLogger(__name__, rank_zero_only=True) + + +class InferenceModel(BaseModel): + def __init__( + self, + geo_encoder: BaseGeoEncoder, + text_encoder: BaseTextEncoder, + prediction_head: BasePredictionHead, + num_classes: int | None = None, + tabular_dim: int | None = None, + match_to_geo: bool = True, + **kwargs, + ) -> None: + super().__init__( + [], None, None, None, None, num_classes=num_classes, tabular_dim=tabular_dim + ) + + # Encoders configuration + self.geo_encoder = geo_encoder + self.text_encoder = text_encoder + + self.match_to_geo = match_to_geo + + # Prediction head + self.prediction_head = prediction_head + + @override + def setup(self, stage: str) -> None: + # During inference we need to ensure: + # - geo_encoder is fully initialized (sets geo_encoder.output_dim) + # - text/geo dims match (possibly via text_encoder.extra_projector) + # - prediction_head.net is created with correct (input_dim, output_dim) + if stage != "inference": + return + + # Configure geo encoder and its output_dim only if it wasn't already set up. + # (In the normal "stitch-from-ckpts" flow, the modules are already initialized.) + if getattr(self.geo_encoder, "output_dim", None) is None or ( + hasattr(self.geo_encoder, "geo_encoder") + and getattr(self.geo_encoder, "geo_encoder") is None + ): + self.geo_encoder.setup() + + # Configure optional extra projection so text embeddings match geo embeddings. + # Note: current codebase applies extra projector on text encoders (not on geo encoders) + # during forward, so `match_to_geo` is expected to be True. + if self.text_encoder.output_dim != self.geo_encoder.output_dim: + if not self.match_to_geo: + raise ValueError( + "match_to_geo=False is not supported for inference: geo extra projector " + "is not applied in geo encoder forward passes in this codebase." + ) + # If extra_projector already exists but output dims still mismatch, we recreate it. + # Otherwise, avoid overwriting weights unnecessarily. + if ( + getattr(self.text_encoder, "extra_projector", None) is None + or self.text_encoder.output_dim != self.geo_encoder.output_dim + ): + self.text_encoder.add_projector(projected_dim=self.geo_encoder.output_dim) + + # Configure prediction head (it creates `prediction_head.net` in setup()). + if self.prediction_head.net is None: + if self.num_classes is None: + raise ValueError( + "InferenceModel requires `num_classes` to build the prediction head." + ) + self.prediction_head.set_dim( + input_dim=self.geo_encoder.output_dim, output_dim=self.num_classes + ) + self.prediction_head.setup() + + # Freeze everything for pure inference. + self.full_freezer() + + @override + def _step( + self, + batch: Dict[str, torch.Tensor], + mode: str = "train", + ) -> torch.Tensor: + """Step forward computation of the model.""" + pass + + @override + def forward( + self, + batch: Dict[str, torch.Tensor], + mode: str = "train", + ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + """Model forward logic.""" + + # Embed modalities + geo_feats = self.geo_encoder(batch) + text_feats = self.text_encoder(batch, mode) + pred_feats = self.prediction_head(geo_feats) + + # Change dtype of geo data if it does not match text dtype + if geo_feats.dtype != text_feats.dtype: + geo_feats = geo_feats.to(text_feats.dtype) + return pred_feats, geo_feats, text_feats + + def concept_similarities(self, geo_embeds, concept=None) -> torch.Tensor: + # Get concept embeddings + if concept is not None: + # If only one concept is provided + if isinstance(concept, str): + concept = [concept] + with torch.no_grad(): + concept_embeds = self.text_encoder({"text": concept}, mode="train") + + elif self.concept_embeds is not None: + concept_embeds = self.concept_embeds + else: + with torch.no_grad(): + concept_embeds = self.text_encoder({"text": self.concepts}, mode="train") + + # Similarity + geo_embeds = F.normalize(geo_embeds, dim=1) + concept_embeds = F.normalize(concept_embeds, dim=1) + similarity_matrix = concept_embeds @ geo_embeds.T + + return similarity_matrix + + +def _is_prefix_trained(trainable_modules: list[str], prefix: str) -> bool: + """True if any trainable module starts with `prefix` (before dot).""" + return any(m.split(".")[0] == prefix for m in trainable_modules) + + +def _group_keys(keys: list[str]) -> dict[str, list[str]]: + """Groups module names (keys)""" + grouped: dict[str, list[str]] = {} + for k in keys: + top = k.split(".", 1)[0] if "." in k else k + grouped.setdefault(top, []).append(k) + return grouped + + +def _log_load_result(tag: str, result) -> None: + """Log missing/unexpected keys from `load_state_dict`.""" + missing_keys, unexpected_keys = result + if missing_keys: + grouped = _group_keys(list(missing_keys)) + summary = {k: len(v) for k, v in grouped.items()} + log.warning(f"[{tag}] Missing keys: {summary}") + if unexpected_keys: + grouped = _group_keys(list(unexpected_keys)) + summary = {k: len(v) for k, v in grouped.items()} + log.warning(f"[{tag}] Unexpected keys: {summary}") + + +def load_inference_model(inference_ckpt_path: str) -> InferenceModel: + """Loads inference model from a merged checkpoint. + + :param inference_ckpt_path: path to inference model weights + :return: an InferenceModel with pre-trained weights + """ + inference_ckpt = torch.load(inference_ckpt_path, map_location="cpu", weights_only=False) + model = hydra.utils.instantiate(inference_ckpt["hyper_parameters"]) + model.setup("inference") + res = model.load_state_dict(inference_ckpt["state_dict"], strict=False) + _log_load_result("inference_ckpt", res) + return model + + +def merge_inference_model(cfg, save_ckpt=False) -> InferenceModel | None: + """Configures the inference model from the predictive + alignment checkpoints. + + :param cfg: A DictConfig configuration composed by Hydra. + :param save_ckpt: Whether to save the model or not. + :return: an InferenceModel with pre-trained weights + """ + + # Stitch the inference model from the predictive + alignment checkpoints. + pred_ckpt_path = cfg.get("predictive_ckpt_path") or FileNotSpecified( + 'You must specify predictive model weight path as "predictive_ckpt_path"' + ) + align_ckpt_path = cfg.get("alignment_ckpt_path") or FileNotSpecified( + 'You must specify alignment model weight path as "alignment_ckpt_path"' + ) + # TODO: remove dataset saving into the checkpoint + pred_ckpt = torch.load(pred_ckpt_path, map_location="cpu", weights_only=False) + align_ckpt = torch.load(align_ckpt_path, map_location="cpu", weights_only=False) + + # Sanity check: ensure geo encoder configs match. + if pred_ckpt["hyper_parameters"].get("geo_encoder") != align_ckpt["hyper_parameters"].get( + "geo_encoder" + ): + log.warning("Geo encoder configs differ between checkpoints; results may be invalid.") + if input("Do you want to proceed? y/n").lower() == "n": + return None + + pred_trainable_modules = [ + pred_ckpt["hyper_parameters"].get("trainable_modules", [])[0] + ] # TODO fix + align_trainable_modules = align_ckpt["hyper_parameters"].get("trainable_modules", []) + + geo_pred_encoder_trained = _is_prefix_trained(pred_trainable_modules, "geo_encoder") + geo_align_encoder_trained = _is_prefix_trained(align_trainable_modules, "geo_encoder") + + if geo_pred_encoder_trained and geo_align_encoder_trained: + raise ValueError("Models are not aligned: both checkpoints trained geo_encoder.") + + # Instantiate InferenceModel via hydra, using alignment encoder configs with prediction model head configs + inference_hparams = align_ckpt["hyper_parameters"] + inference_hparams.update( + { + "_target_": "src.models.inference_model.InferenceModel", + "prediction_head": pred_ckpt["hyper_parameters"].get("prediction_head"), + "num_classes": pred_ckpt["hyper_parameters"].get("num_classes"), + } + ) + inference_hparams["text_encoder"]["hf_cache_dir"] = os.path.join( + cfg.paths.cache_dir, "huggingface" + ) + + model: InferenceModel = hydra.utils.instantiate(inference_hparams) + model.setup("inference") + + # Load alignment weights first (text encoder + geo encoder). + res = model.load_state_dict(align_ckpt["state_dict"], strict=False) + _log_load_result("alignment_ckpt", res) + + # Load prediction head weights from predictive ckpt. + head_state = { + k: v for k, v in pred_ckpt["state_dict"].items() if k.startswith("prediction_head.") + } + res = model.load_state_dict(head_state, strict=False) + _log_load_result("predictive_prediction_head_only", res) + + # Save model + if save_ckpt: + save_path = cfg.get("save_inference_ckpt_path") + if not save_path: + print("Model could not be saved as save_path was not provided") + + # Get `state_dict` + state_dict = model.state_dict() + + # Save + os.makedirs(os.path.dirname(save_path), exist_ok=True) + torch.save({"state_dict": state_dict, "hyper_parameters": inference_hparams}, save_path) + log.info(f"Saved merged inference checkpoint to: {save_path}") + + return model From 883df2e68d7d7bbe83b92597cc671ce8751ea6f5 Mon Sep 17 00:00:00 2001 From: GabrieleTi Date: Tue, 7 Apr 2026 13:13:37 +0200 Subject: [PATCH 55/78] Add configuration saving into ckpg --- src/train.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/src/train.py b/src/train.py index 8348fbb..7a4992c 100644 --- a/src/train.py +++ b/src/train.py @@ -6,7 +6,7 @@ from dotenv import load_dotenv from lightning import Callback, LightningModule, Trainer from lightning.pytorch.loggers import Logger -from omegaconf import DictConfig +from omegaconf import DictConfig, OmegaConf from src.data.base_datamodule import BaseDataModule @@ -53,6 +53,10 @@ def train(cfg: DictConfig) -> Tuple[Dict[str, Any], Dict[str, Any]]: log.info(f"Instantiating model <{cfg.model._target_}>") model: LightningModule = hydra.utils.instantiate(cfg.model) + # Append model hparams from config to be saved in ckpg + raw_model_cfg = OmegaConf.to_container(cfg.model, resolve=True) + model.hparams.update(raw_model_cfg) + log.info("Instantiating callbacks...") callbacks: List[Callback] = instantiate_callbacks(cfg.get("callbacks")) From f86df5c915d2ff8db05bead81dd0630ccb4bbd8c Mon Sep 17 00:00:00 2001 From: GabrieleTi Date: Tue, 7 Apr 2026 13:17:32 +0200 Subject: [PATCH 56/78] tessera fixes --- src/data_preprocessing/tessera_embeds.py | 86 ++++++++++++++++++------ 1 file changed, 64 insertions(+), 22 deletions(-) diff --git a/src/data_preprocessing/tessera_embeds.py b/src/data_preprocessing/tessera_embeds.py index c2414b8..1e1830d 100644 --- a/src/data_preprocessing/tessera_embeds.py +++ b/src/data_preprocessing/tessera_embeds.py @@ -1,3 +1,4 @@ +import concurrent.futures import math import os import threading @@ -23,6 +24,18 @@ ) +class PartialTileError(Exception): + def __init__(self, message): + self.message = message + super().__init__(self.message) + + +class NoTileError(Exception): + def __init__(self, message): + self.message = message + super().__init__(self.message) + + def reproject_dataset(src_raster: MemoryFile, dst_crs: str) -> MemoryFile: """Reprojects Memory file if it's not in dst_crs. @@ -70,6 +83,7 @@ def get_tessera_embeds( save_dir: str, tile_size: int, tessera_con: GeoTessera | None, + padding: int = 100, ) -> None: """Obtain tessera embedding tile with specified size for a given coordinates. @@ -80,9 +94,11 @@ def get_tessera_embeds( :param save_dir: data directory to save embeddings :param tile_size: tile size in pixels :param tessera_con: GeoTessera instance + :param padding: how many meters to pad initial bbox, fixes some inconsistencies when mosaicing :return: None """ + # Skip if tile exists embed_tile_name = os.path.join(save_dir, f"tessera_{name_loc}.npy") if os.path.exists(embed_tile_name): return @@ -92,7 +108,7 @@ def get_tessera_embeds( lon_utm, lat_utm = point_reprojection(lon, lat, "EPSG:4326", utm_crs) # Bounding box - radius = math.ceil(tile_size / 2) + 10 + radius = math.ceil(tile_size / 2) + padding bbox = create_bbox_with_radius(lon, lat, radius=radius, utm_crs=utm_crs, return_wgs=True) # Request to tessera @@ -126,13 +142,10 @@ def get_tessera_embeds( if reproject_memfile: memfiles.append(reproject_memfile) - if not tiles: - print( - f"No TESSERA tiles found for {name_loc} at ({lon:.4f}, {lat:.4f}) year={year}. Skipping." - ) + if len(tiles) == 0: for mf in memfiles: mf.close() - return + raise NoTileError(f"No tiles found for {name_loc}") # if no tiles, add to skipped.txt mosaic, mosaic_transform = merge(tiles) mosaic = mosaic.transpose(1, 2, 0) @@ -143,20 +156,29 @@ def get_tessera_embeds( mf.close() # Crop patch tile - col, row = crs_to_pixel_coords(lon_utm, lat_utm, mosaic_transform) + c, r = crs_to_pixel_coords(lon_utm, lat_utm, mosaic_transform) half = tile_size // 2 - row_min = row - half - row_max = row + tile_size - half # tile_size - half ensures correct size for odd tile_size - col_min = col - half - col_max = col + tile_size - half + row_min = r - half + row_max = r + half + col_min = c - half + col_max = c + half + + if row_min < 0 or row_max < 0 or col_min < 0 or col_max < 0: + # retry with bigger padding + if padding > 500: + raise NoTileError(f"Padding {padding} > 500") + get_tessera_embeds( + lon, lat, name_loc, year, save_dir, tile_size, tessera_con, padding=padding + 100 + ) + crop = mosaic[row_min:row_max, col_min:col_max, :] + if not crop.shape == (tile_size, tile_size, 128): + if crop.min() == 0.0 and crop.max() == 0.0: + raise NoTileError(f"No tiles found for {name_loc}") + raise PartialTileError(f"Crop {name_loc}, size is {crop.shape}") - if crop.shape[0] != tile_size or crop.shape[1] != tile_size: - print( - f"Unexpected crop shape {crop.shape} for {name_loc} " - f"(expected {tile_size}x{tile_size}). Skipping." - ) - return + if crop.min() == 0.0 and crop.max() == 0.0: + raise NoTileError(f"Crop {name_loc} has embeddings of 0.0s with tiles: {tiles_to_fetch}") # Save array os.makedirs(save_dir, exist_ok=True) @@ -186,6 +208,7 @@ def tessera_from_df( year: int, tile_size: int = 256, cache_dir: str = "temp/", + logs_dir: str = "logs", ) -> None: """Obtains Tessera embeddings from a CSV file for each (lon, lat). @@ -205,8 +228,17 @@ def tessera_from_df( n = len(model_ready_df) for i, row in model_ready_df.iterrows(): print(f"{i}/{n}") - # Get tessera embeds - get_tessera_embeds(row.lon, row.lat, row.name_loc, year, f"{data_dir}/", tile_size, gt) + if i < 238 or i in [1319]: + continue + try: + get_tessera_embeds(row.lon, row.lat, row.name_loc, year, f"{data_dir}/", tile_size, gt) + except Exception as e: + if isinstance(e, NoTileError): + path = os.path.join(logs_dir, "tessera_skipped.txt") + with open(path, "a") as f: + f.write(f"{row.name_loc}\n") + else: + print(f"{row.name_loc} did not get embedded because: {e}") def inspect_np_arr_as_tiff( @@ -270,10 +302,20 @@ def inspect_np_arr_as_tiff( if __name__ == "__main__": - os.chdir("../..") + print(os.getcwd()) + + # df = pd.read_csv("data/heat_guatemala/model_ready_heat_guatemala.csv") + df = pd.read_csv("/lustre/backup/SHARED/AIN/aether/data/s2bms/model_ready_s2bms.csv") + # df.sort_values(by="name_loc", inplace=True, ascending=False) + with open(os.path.join("logs", "tessera_skipped.txt")) as f: + skipped = set(f.read().splitlines()) - df = pd.read_csv("data/heat_guatemala/model_ready_heat_guatemala.csv") + df = df[~df.name_loc.isin(skipped)] tessera_from_df( - df, "data/heat_guatemala/eo/tessera_2024", year=2024, tile_size=10, cache_dir="data/cache" + df, + "/lustre/backup/SHARED/AIN/aether/data/s2bms/eo/tessera", + year=2024, + tile_size=256, + cache_dir="/lustre/backup/SHARED/AIN/aether/data/cache", ) From 19c3fcf2079caf92954616f7061933c626fa26f1 Mon Sep 17 00:00:00 2001 From: Thijs van der Plas Date: Tue, 31 Mar 2026 22:00:16 +0200 Subject: [PATCH 57/78] Index dynamic top k -- not tested yet --- src/models/text_alignment_model.py | 30 +++++++++++++++++++++++++++++- 1 file changed, 29 insertions(+), 1 deletion(-) diff --git a/src/models/text_alignment_model.py b/src/models/text_alignment_model.py index 71737a7..2dd4a2c 100644 --- a/src/models/text_alignment_model.py +++ b/src/models/text_alignment_model.py @@ -1,3 +1,4 @@ +from io import text_encoding from typing import Dict, Tuple, override import torch @@ -122,6 +123,29 @@ def setup_retrieval_evaluation(self): for c in self.concept_configs ] + dataset_names = ["train", "val", "test"] + self.dynamic_k_baselines = {} + for dataset_name in dataset_names: + if not hasattr(self.trainer.datamodule, f"{dataset_name}_dataloader"): + continue + tmp_ds = getattr(self.trainer.datamodule, f"{dataset_name}_dataloader")().dataset + n_ds = len(tmp_ds) + self.dynamic_k_baselines[dataset_name] = {} + for i_c, c in enumerate(self.concept_configs): + c_name = self.concept_names[i_c] + aux_col_id = c["id"] + if c["is_max"]: + n_baseline = sum( + tmp_ds[ii]["aux"][aux_col_id] >= c.get("theta_k", float("inf")) + for ii in range(len(tmp_ds)) + ) + else: + n_baseline = sum( + tmp_ds[ii]["aux"][aux_col_id] <= c.get("theta_k", float("inf")) + for ii in range(len(tmp_ds)) + ) + self.dynamic_k_baselines[dataset_name][c_name] = n_baseline / n_ds * 100 + self.contrastive_val = RetrievalContrastiveValidation(self.ks, self.concept_configs) self.outputs_epoch_memory = [] @@ -219,7 +243,11 @@ def _on_epoch_end(self, mode: str): print(f'\nConcept "{self.concepts[i]}" average top-k accuracies in {mode} split:') for k, v in result.items(): if k == "dynamic_k": - self.log(f"dyn_k_{self.concept_names[i]}", v, **self.log_kwargs) + self.log(f"dyn_k_{self.self.concept_names[i]}", v, **self.log_kwargs) + indexed_v = (v - self.dynamic_k_baselines[mode][self.concept_names[i]]) / ( + 100 - self.dynamic_k_baselines[mode][self.concept_names[i]] + ) + self.log(f"dyn_k_index_{self.concept_names[i]}", indexed_v, **self.log_kwargs) print(f"Top-{k}: {v:.1f}%") avr_scores[f"{mode}_avr_top-{k}"].append(v) From fa128b12b7447502c25d729f062957cc7805c30d Mon Sep 17 00:00:00 2001 From: Thijs van der Plas Date: Wed, 1 Apr 2026 08:47:54 +0200 Subject: [PATCH 58/78] Elbow method for theta k -- not tested yet --- src/models/text_alignment_model.py | 33 ++++++++++++++++++++++-------- 1 file changed, 25 insertions(+), 8 deletions(-) diff --git a/src/models/text_alignment_model.py b/src/models/text_alignment_model.py index 2dd4a2c..d713a38 100644 --- a/src/models/text_alignment_model.py +++ b/src/models/text_alignment_model.py @@ -1,6 +1,7 @@ from io import text_encoding from typing import Dict, Tuple, override +import numpy as np import torch import torch.nn.functional as F @@ -134,16 +135,13 @@ def setup_retrieval_evaluation(self): for i_c, c in enumerate(self.concept_configs): c_name = self.concept_names[i_c] aux_col_id = c["id"] + aux_vals_current_ds = [tmp_ds[ii]["aux"][aux_col_id] for ii in range(len(tmp_ds))] + # theta_k = c['theta_k'] + theta_k = self.find_elbow_point(aux_vals_current_ds) if c["is_max"]: - n_baseline = sum( - tmp_ds[ii]["aux"][aux_col_id] >= c.get("theta_k", float("inf")) - for ii in range(len(tmp_ds)) - ) + n_baseline = sum(aux_val >= theta_k for aux_val in aux_vals_current_ds) else: - n_baseline = sum( - tmp_ds[ii]["aux"][aux_col_id] <= c.get("theta_k", float("inf")) - for ii in range(len(tmp_ds)) - ) + n_baseline = sum(aux_val <= theta_k for aux_val in aux_vals_current_ds) self.dynamic_k_baselines[dataset_name][c_name] = n_baseline / n_ds * 100 self.contrastive_val = RetrievalContrastiveValidation(self.ks, self.concept_configs) @@ -288,3 +286,22 @@ def concept_similarities(self, geo_embeds, concept=None) -> torch.Tensor: similarity_matrix = concept_embeds @ geo_embeds.T return similarity_matrix + + def find_elbow_point(vals): + vals = np.sort(vals) + x = np.arange(len(vals)) / len(vals) + y = vals + slope = (y[-1] - y[0]) / (x[-1] - x[0]) # diagonal from first to last point + intercept = y[0] - slope * x[0] + orthogonal_slope = -1 / slope + + intercepts_orthogonal = y - orthogonal_slope * x + intersection_diagonal_orthogonal = (intercepts_orthogonal - intercept) / ( + slope - orthogonal_slope + ) + distances = np.sqrt( + (x - intersection_diagonal_orthogonal) ** 2 + (y - (slope * x + intercept)) ** 2 + ) # distance to diagonal + elbow_index = np.argmax(distances) + elbow_point = y[elbow_index] + return elbow_point From bf511842821b534ca181cd26f2ed866ba1145109 Mon Sep 17 00:00:00 2001 From: Thijs van der Plas Date: Wed, 1 Apr 2026 18:48:54 +0200 Subject: [PATCH 59/78] update theta_k with calculated --- src/models/text_alignment_model.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/src/models/text_alignment_model.py b/src/models/text_alignment_model.py index d713a38..354c632 100644 --- a/src/models/text_alignment_model.py +++ b/src/models/text_alignment_model.py @@ -138,6 +138,9 @@ def setup_retrieval_evaluation(self): aux_vals_current_ds = [tmp_ds[ii]["aux"][aux_col_id] for ii in range(len(tmp_ds))] # theta_k = c['theta_k'] theta_k = self.find_elbow_point(aux_vals_current_ds) + self.concept_configs[i_c][ + "theta_k" + ] = theta_k # assign new theta_k to concept_configs for later use in validation if c["is_max"]: n_baseline = sum(aux_val >= theta_k for aux_val in aux_vals_current_ds) else: From 67b0221d5508a632f80b5ec85b78e20543ba6e35 Mon Sep 17 00:00:00 2001 From: GabrieleTi Date: Tue, 7 Apr 2026 14:58:21 +0200 Subject: [PATCH 60/78] Cleaner print --- src/models/base_model.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/src/models/base_model.py b/src/models/base_model.py index 1a6a2b5..1acf647 100644 --- a/src/models/base_model.py +++ b/src/models/base_model.py @@ -107,11 +107,10 @@ def _in_train_scope(name: str) -> bool: else: module.eval() - print("----------------------------") - print("Set to train") + print("------Set to train------") for m in sorted(trainable): print(f" {m}") - print("----------------------------") + print("------------------------") @abstractmethod def forward( From d4fce8b7a6a17cf527216f3879f47294d537692f Mon Sep 17 00:00:00 2001 From: GabrieleTi Date: Tue, 7 Apr 2026 14:58:44 +0200 Subject: [PATCH 61/78] Move logging --- src/models/inference_model.py | 29 ++++------------------------- src/utils/logging_utils.py | 24 ++++++++++++++++++++++++ 2 files changed, 28 insertions(+), 25 deletions(-) diff --git a/src/models/inference_model.py b/src/models/inference_model.py index 93447be..0fab831 100644 --- a/src/models/inference_model.py +++ b/src/models/inference_model.py @@ -13,6 +13,7 @@ ) from src.utils import RankedLogger from src.utils.errors import FileNotSpecified +from src.utils.logging_utils import log_model_loading log = RankedLogger(__name__, rank_zero_only=True) @@ -144,28 +145,6 @@ def _is_prefix_trained(trainable_modules: list[str], prefix: str) -> bool: return any(m.split(".")[0] == prefix for m in trainable_modules) -def _group_keys(keys: list[str]) -> dict[str, list[str]]: - """Groups module names (keys)""" - grouped: dict[str, list[str]] = {} - for k in keys: - top = k.split(".", 1)[0] if "." in k else k - grouped.setdefault(top, []).append(k) - return grouped - - -def _log_load_result(tag: str, result) -> None: - """Log missing/unexpected keys from `load_state_dict`.""" - missing_keys, unexpected_keys = result - if missing_keys: - grouped = _group_keys(list(missing_keys)) - summary = {k: len(v) for k, v in grouped.items()} - log.warning(f"[{tag}] Missing keys: {summary}") - if unexpected_keys: - grouped = _group_keys(list(unexpected_keys)) - summary = {k: len(v) for k, v in grouped.items()} - log.warning(f"[{tag}] Unexpected keys: {summary}") - - def load_inference_model(inference_ckpt_path: str) -> InferenceModel: """Loads inference model from a merged checkpoint. @@ -176,7 +155,7 @@ def load_inference_model(inference_ckpt_path: str) -> InferenceModel: model = hydra.utils.instantiate(inference_ckpt["hyper_parameters"]) model.setup("inference") res = model.load_state_dict(inference_ckpt["state_dict"], strict=False) - _log_load_result("inference_ckpt", res) + log_model_loading("inference_ckpt", res) return model @@ -236,14 +215,14 @@ def merge_inference_model(cfg, save_ckpt=False) -> InferenceModel | None: # Load alignment weights first (text encoder + geo encoder). res = model.load_state_dict(align_ckpt["state_dict"], strict=False) - _log_load_result("alignment_ckpt", res) + log_model_loading("alignment_ckpt", res) # Load prediction head weights from predictive ckpt. head_state = { k: v for k, v in pred_ckpt["state_dict"].items() if k.startswith("prediction_head.") } res = model.load_state_dict(head_state, strict=False) - _log_load_result("predictive_prediction_head_only", res) + log_model_loading("predictive_prediction_head_only", res) # Save model if save_ckpt: diff --git a/src/utils/logging_utils.py b/src/utils/logging_utils.py index 4895a96..f23f4df 100644 --- a/src/utils/logging_utils.py +++ b/src/utils/logging_utils.py @@ -59,3 +59,27 @@ def log_hyperparameters(object_dict: Dict[str, Any]) -> None: # send hparams to all loggers for logger in trainer.loggers: logger.log_hyperparams(hparams) + + +def _group_keys(keys: list[str]) -> dict[str, list[str]]: + """Groups module names (keys)""" + grouped: dict[str, list[str]] = {} + for k in keys: + top = k.split(".", 1)[0] if "." in k else k + grouped.setdefault(top, []).append(k) + return grouped + + +def log_model_loading(tag: str, result) -> None: + """Log missing/unexpected keys from `load_state_dict`.""" + missing_keys, unexpected_keys = result + if missing_keys: + grouped = _group_keys(list(missing_keys)) + summary = {k: len(v) for k, v in grouped.items()} + log.warning(f"[{tag}] Missing keys: {summary}") + if unexpected_keys: + grouped = _group_keys(list(unexpected_keys)) + summary = {k: len(v) for k, v in grouped.items()} + log.warning(f"[{tag}] Unexpected keys: {summary}") + if not missing_keys and not unexpected_keys: + log.info(f"[{tag}] Module weights loaded successfully") From 0e9c285982cdba466855df9955c6e8ffb993871d Mon Sep 17 00:00:00 2001 From: GabrieleTi Date: Tue, 7 Apr 2026 14:59:34 +0200 Subject: [PATCH 62/78] Fix trainable module typo, add to freeze all model upon testing --- src/models/predictive_model.py | 4 ++-- src/models/text_alignment_model.py | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/src/models/predictive_model.py b/src/models/predictive_model.py index 6e5cb90..5663afb 100644 --- a/src/models/predictive_model.py +++ b/src/models/predictive_model.py @@ -78,7 +78,7 @@ def setup(self, stage: str) -> None: print("------------------------") # Freezing requested parts - if stage in ["inference"]: + if stage in ["inference", "test"]: self.full_freezer() else: self.freezer() @@ -94,7 +94,7 @@ def setup_encoders_adapters(self): self.geo_encoder.set_tabular_input_dim(self.tabular_dim) # Setup encoders that need data-depended configurations - new_modules = [f"geo_encoder.{i}]" for i in self.geo_encoder.setup()] + new_modules = [f"geo_encoder.{i}" for i in self.geo_encoder.setup() or []] self.trainable_modules.extend(new_modules) # Configure prediction head based on geo-encoder output_dim diff --git a/src/models/text_alignment_model.py b/src/models/text_alignment_model.py index 354c632..01eb7ca 100644 --- a/src/models/text_alignment_model.py +++ b/src/models/text_alignment_model.py @@ -86,7 +86,7 @@ def setup(self, stage: str = "fit") -> None: print("------------------------") # Freeze not requested parts - if stage in ["inference"]: + if stage in ["inference", "test"]: self.full_freezer() else: self.freezer() From 331aba2ae98ba32f0023519824f0b0c6dfc9554c Mon Sep 17 00:00:00 2001 From: GabrieleTi Date: Tue, 7 Apr 2026 15:00:00 +0200 Subject: [PATCH 63/78] Introduce adopted encoder Geo encoder from a checkpoint --- .../components/geo_encoders/adopt_encoder.py | 37 +++++++++++++++++++ 1 file changed, 37 insertions(+) create mode 100644 src/models/components/geo_encoders/adopt_encoder.py diff --git a/src/models/components/geo_encoders/adopt_encoder.py b/src/models/components/geo_encoders/adopt_encoder.py new file mode 100644 index 0000000..f6e9f76 --- /dev/null +++ b/src/models/components/geo_encoders/adopt_encoder.py @@ -0,0 +1,37 @@ +from typing import Dict, List, override + +import hydra +import torch + +from src.models.components.geo_encoders.base_geo_encoder import BaseGeoEncoder +from src.utils.logging_utils import log_model_loading + + +def adopt_encoder(ckpt_path: str) -> BaseGeoEncoder: + """Return geo_encoder from a provided checkpoint. + + :param ckpt_path: path to checkpoint file + :return: trained geo encoder + """ + # Get checkpoint + ckpt = torch.load(ckpt_path, map_location="cpu", weights_only=False) + + # Get skeleton + geo_config = ckpt["hyper_parameters"].get("geo_encoder") + encoder: BaseGeoEncoder = hydra.utils.instantiate(geo_config) + print("---Adopted encoder------") + encoder.setup() + print("------------------------") + + # Load in the weights + state_dict = { + k.replace("geo_encoder.", ""): v + for k, v in ckpt["state_dict"].items() + if "geo_encoder." in k + } + res = encoder.load_state_dict(state_dict, strict=False) + log_model_loading("geo_encoder_ckpt", res) + + encoder.setup = lambda *args, **kwargs: None # TODO: switch to maybe self.setup flag + + return encoder From 048cef03dd169c97962315bd82069fd769f51cf3 Mon Sep 17 00:00:00 2001 From: GabrieleTi Date: Tue, 7 Apr 2026 15:41:49 +0200 Subject: [PATCH 64/78] optimise iterating through dataset one single time --- src/models/text_alignment_model.py | 22 +++++++++++++++++----- 1 file changed, 17 insertions(+), 5 deletions(-) diff --git a/src/models/text_alignment_model.py b/src/models/text_alignment_model.py index 01eb7ca..ec8fbd5 100644 --- a/src/models/text_alignment_model.py +++ b/src/models/text_alignment_model.py @@ -127,16 +127,27 @@ def setup_retrieval_evaluation(self): dataset_names = ["train", "val", "test"] self.dynamic_k_baselines = {} for dataset_name in dataset_names: - if not hasattr(self.trainer.datamodule, f"{dataset_name}_dataloader"): + if not hasattr(self.trainer.datamodule, f"data_{dataset_name}"): continue - tmp_ds = getattr(self.trainer.datamodule, f"{dataset_name}_dataloader")().dataset + + tmp_ds = getattr(self.trainer.datamodule, f"data_{dataset_name}") n_ds = len(tmp_ds) self.dynamic_k_baselines[dataset_name] = {} + + # Placeholder for all concepts + aux_vals_per_concept = {i: [] for i in range(len(self.concept_configs))} + + for item in tmp_ds: + aux_data = item["aux"]["aux"] + for i_c, c in enumerate(self.concept_configs): + aux_col_id = c["id"] + aux_vals_per_concept[i_c].append(aux_data[aux_col_id]) + + # Compute per concept for i_c, c in enumerate(self.concept_configs): c_name = self.concept_names[i_c] - aux_col_id = c["id"] - aux_vals_current_ds = [tmp_ds[ii]["aux"][aux_col_id] for ii in range(len(tmp_ds))] - # theta_k = c['theta_k'] + aux_vals_current_ds = aux_vals_per_concept[i_c] + theta_k = self.find_elbow_point(aux_vals_current_ds) self.concept_configs[i_c][ "theta_k" @@ -290,6 +301,7 @@ def concept_similarities(self, geo_embeds, concept=None) -> torch.Tensor: return similarity_matrix + @staticmethod def find_elbow_point(vals): vals = np.sort(vals) x = np.arange(len(vals)) / len(vals) From 30bdb828a1d39e67bc47ffc1738c184c57e192ef Mon Sep 17 00:00:00 2001 From: GabrieleTi Date: Tue, 7 Apr 2026 15:57:19 +0200 Subject: [PATCH 65/78] self.self fix --- src/models/text_alignment_model.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/models/text_alignment_model.py b/src/models/text_alignment_model.py index ec8fbd5..397da7d 100644 --- a/src/models/text_alignment_model.py +++ b/src/models/text_alignment_model.py @@ -255,7 +255,7 @@ def _on_epoch_end(self, mode: str): print(f'\nConcept "{self.concepts[i]}" average top-k accuracies in {mode} split:') for k, v in result.items(): if k == "dynamic_k": - self.log(f"dyn_k_{self.self.concept_names[i]}", v, **self.log_kwargs) + self.log(f"dyn_k_{self.concept_names[i]}", v, **self.log_kwargs) indexed_v = (v - self.dynamic_k_baselines[mode][self.concept_names[i]]) / ( 100 - self.dynamic_k_baselines[mode][self.concept_names[i]] ) From 3cef11c6bdf44c062573c1ea0cfc5894e60c5177 Mon Sep 17 00:00:00 2001 From: GabrieleTi Date: Wed, 8 Apr 2026 16:57:28 +0200 Subject: [PATCH 66/78] introduce identity encoder issue #65 --- .../geo_encoders/identity_encoder.py | 50 +++++++++++++++++++ 1 file changed, 50 insertions(+) create mode 100644 src/models/components/geo_encoders/identity_encoder.py diff --git a/src/models/components/geo_encoders/identity_encoder.py b/src/models/components/geo_encoders/identity_encoder.py new file mode 100644 index 0000000..29f9bfe --- /dev/null +++ b/src/models/components/geo_encoders/identity_encoder.py @@ -0,0 +1,50 @@ +from typing import Dict, List, override + +import torch +from torch import nn + +from src.models.components.geo_encoders.base_geo_encoder import BaseGeoEncoder + + +class IdentityEncoder(BaseGeoEncoder): + def __init__( + self, + geo_data_name="aef", + ) -> None: + """Encoder to avreage tile values into a 1D vector. + + :param geo_data_name: modality name + """ + super().__init__() + + self.dict_n_bands_default = {"s2": 4, "aef": 64, "tessera": 128} + self.allowed_geo_data_names: list[str] = list(self.dict_n_bands_default.keys()) + assert ( + geo_data_name in self.allowed_geo_data_names + ), f"geo_data_name must be one of {self.allowed_geo_data_names}, got {geo_data_name}" + self.geo_data_name = geo_data_name + + @override + def _setup(self) -> List[str]: + """Configures modules and returns newly initialised, trainable module names.""" + + self.output_dim = None + self.geo_encoder = nn.Identity() + return [] + + @override + def forward(self, batch: Dict[str, torch.Tensor]) -> torch.Tensor: + """Data forward pass through the encoder.""" + tile = batch.get("eo", {}).get(self.geo_data_name) + + if self.output_dim is None: + self.output_dim = tile.shape + + feats = self.geo_encoder(tile) + if self.extra_projector: + feats = self.extra_projector(feats) + return feats + + @property + def device(self): + return From c7fc5f9b3d590c3073d95c9400171162fd0956f9 Mon Sep 17 00:00:00 2001 From: GabrieleTi Date: Wed, 8 Apr 2026 16:58:03 +0200 Subject: [PATCH 67/78] Add setup for text encoder --- .../text_encoders/base_text_encoder.py | 29 +++++++++++++++++-- .../text_encoders/clip_text_encoder.py | 2 -- 2 files changed, 27 insertions(+), 4 deletions(-) diff --git a/src/models/components/text_encoders/base_text_encoder.py b/src/models/components/text_encoders/base_text_encoder.py index d934aab..4f54a2f 100644 --- a/src/models/components/text_encoders/base_text_encoder.py +++ b/src/models/components/text_encoders/base_text_encoder.py @@ -1,5 +1,5 @@ from abc import ABC, abstractmethod -from typing import Dict +from typing import Dict, List, final import torch from torch import nn @@ -8,12 +8,37 @@ class BaseTextEncoder(nn.Module, ABC): def __init__(self) -> None: super().__init__() + + # modules self.processor: nn.Module | None = None self.model: nn.Module = None self.projector: nn.Module | None = None - self.output_dim: int | None = None self.extra_projector: nn.Module | None = None + self.output_dim: int | None = None + self.setup_flag: bool = False + self.cfg_dict: Dict = {} + + @final + def setup(self): + """Configures modules. + + Gets called in model.setup() method. Returns names of any new module configured to be added + to the trainable modules list. + """ + if self.setup_flag: + print(f"Module {self.__str__()} is already set up.") + return [] + else: + trainable_modules = self._setup() + print(f"Model set up with {self.__str__()}") + self.setup_flag = True + return trainable_modules + + def _setup(self) -> List[str]: + """Configures modules and returns newly initialised, trainable module names.""" + return [] + @abstractmethod def forward(self, batch: Dict[str, torch.Tensor], mode: str) -> torch.Tensor: pass diff --git a/src/models/components/text_encoders/clip_text_encoder.py b/src/models/components/text_encoders/clip_text_encoder.py index 075a390..5052ca5 100644 --- a/src/models/components/text_encoders/clip_text_encoder.py +++ b/src/models/components/text_encoders/clip_text_encoder.py @@ -29,8 +29,6 @@ def __init__(self, hf_cache_dir: str = "../.cache", output_normalization="l2") - self.output_dim = 512 - print("Model set up with CLIP text encoder") - @override def forward(self, batch: Dict[str, torch.Tensor], mode: str) -> torch.Tensor: # Get text inputs From 3e78fae054325ab2f48b3c2b8e0f608dece7707d Mon Sep 17 00:00:00 2001 From: GabrieleTi Date: Wed, 8 Apr 2026 17:00:05 +0200 Subject: [PATCH 68/78] Extract _setup method, add setup_flag, add update_configs method --- .../geo_encoders/average_encoder.py | 14 +++--- .../geo_encoders/base_geo_encoder.py | 47 ++++++++++++++----- .../components/geo_encoders/cnn_encoder.py | 11 +---- .../geo_encoders/encoder_wrapper.py | 27 ++++++++++- src/models/components/geo_encoders/geoclip.py | 3 +- .../components/geo_encoders/mlp_projector.py | 3 +- .../geo_encoders/tabular_encoder.py | 3 +- 7 files changed, 72 insertions(+), 36 deletions(-) diff --git a/src/models/components/geo_encoders/average_encoder.py b/src/models/components/geo_encoders/average_encoder.py index b11e3ce..7e31b44 100644 --- a/src/models/components/geo_encoders/average_encoder.py +++ b/src/models/components/geo_encoders/average_encoder.py @@ -19,23 +19,17 @@ def __init__( self.dict_n_bands_default = {"s2": 4, "aef": 64, "tessera": 128} self.allowed_geo_data_names: list[str] = list(self.dict_n_bands_default.keys()) - assert ( geo_data_name in self.allowed_geo_data_names ), f"geo_data_name must be one of {self.allowed_geo_data_names}, got {geo_data_name}" self.geo_data_name = geo_data_name @override - def setup(self) -> List[str]: - """Configures networks, data-dependent parts. + def _setup(self) -> List[str]: + """Configures modules and returns newly initialised, trainable module names.""" - Gets called in model.setup() method. Returns names of any new module configured to be added - to the trainable modules list. - """ self.output_dim = self.dict_n_bands_default[self.geo_data_name] self.geo_encoder = nn.Identity() - print(f"Model set up with average geo-encoder for {self.geo_data_name}") - return [] @override @@ -46,3 +40,7 @@ def forward(self, batch: Dict[str, torch.Tensor]) -> torch.Tensor: if self.extra_projector: feats = self.extra_projector(feats) return feats + + @property + def device(self): + return diff --git a/src/models/components/geo_encoders/base_geo_encoder.py b/src/models/components/geo_encoders/base_geo_encoder.py index 6b65c75..167e53f 100644 --- a/src/models/components/geo_encoders/base_geo_encoder.py +++ b/src/models/components/geo_encoders/base_geo_encoder.py @@ -1,5 +1,5 @@ from abc import ABC, abstractmethod -from typing import Dict +from typing import Dict, List, final import torch from torch import nn @@ -8,13 +8,44 @@ class BaseGeoEncoder(nn.Module, ABC): def __init__(self) -> None: super().__init__() + + # Modules self.geo_encoder: nn.Module | None = None + self.extra_projector: nn.Module | None = None + self.output_dim: int | None = None + self.setup_flag: bool = False + self.cfg_dict: Dict = {} - # placeholders self.allowed_geo_data_names: list[str] | None = None self.geo_data_name: str | None = None - self.extra_projector: nn.Module | None = None + + def update_configs(self, cfg): + if len(self.cfg_dict) == 0: + self.cfg_dict = cfg + else: + print("Configs for geo encoder not updated") + + @final + def setup(self) -> List[str]: + """Configures modules. + + Gets called in model.setup() method. Returns names of any new module configured to be added + to the trainable modules list. + """ + if self.setup_flag: + print(f"Module {self.__str__()} is already set up.") + return [] + else: + trainable_modules = self._setup() + print(f"Model set up with {self.__str__()}") + self.setup_flag = True + return trainable_modules + + @abstractmethod + def _setup(self) -> List[str]: + """Configures modules and returns newly initialised, trainable module names.""" + pass @abstractmethod def forward(self, batch: Dict[str, torch.Tensor]) -> torch.Tensor: @@ -38,15 +69,7 @@ def dtype(self) -> torch.dtype | None: return None return dtypes.pop() - @abstractmethod - def setup(self) -> list[str]: - """Configures networks, data-dependent parts. - - Gets called in model.setup() method. Returns names of any new module configured to be added - to the trainable modules list. - """ - pass - + @final def add_projector(self, projected_dim: int) -> None: """Adds an extra linear projection layer to the geo encoder. diff --git a/src/models/components/geo_encoders/cnn_encoder.py b/src/models/components/geo_encoders/cnn_encoder.py index 6ec3609..4ece11d 100644 --- a/src/models/components/geo_encoders/cnn_encoder.py +++ b/src/models/components/geo_encoders/cnn_encoder.py @@ -45,8 +45,6 @@ def __init__( ), f"input_n_bands must be int >=3, got {self.input_n_bands}" self.output_dim = output_dim - self.geo_encoder = self.get_backbone() - def set_n_input_bands(self, n_bands: int | None = None) -> None: """Sets number of input bands based on geo_data_name if n_bands is None. @@ -132,10 +130,9 @@ def get_backbone(self): raise ValueError(f"Unsupported backbone: {self.backbone}") @override - def setup(self) -> List[str]: + def _setup(self) -> List[str]: # TODO: could you make sure new layers are returned here to be added to trainable parts? - # Maybe move the get_backbone method in here? - print(f"Model setup with cnn geo-encoder for {self.geo_data_name}") + self.geo_encoder = self.get_backbone() return [] @override @@ -163,7 +160,3 @@ def forward( feats = self.extra_projector(feats) return feats.to(dtype) - - -if __name__ == "__main__": - _ = CNNEncoder(None, None, None, None, None, None, None, None) diff --git a/src/models/components/geo_encoders/encoder_wrapper.py b/src/models/components/geo_encoders/encoder_wrapper.py index 0523619..e9d7fff 100644 --- a/src/models/components/geo_encoders/encoder_wrapper.py +++ b/src/models/components/geo_encoders/encoder_wrapper.py @@ -4,7 +4,9 @@ import torch.nn as nn from src.models.components.geo_encoders.base_geo_encoder import BaseGeoEncoder +from src.models.components.geo_encoders.identity_encoder import IdentityEncoder from src.models.components.geo_encoders.tabular_encoder import TabularEncoder +from src.utils.errors import IllegalArgumentCombination class EncoderWrapper(BaseGeoEncoder): @@ -39,7 +41,26 @@ def _reformat_set_branches(self, encoder_branches: List[Dict[str, Any]]): self.encoder_branches.append(module_dict) @override - def setup(self) -> List[str]: + def update_configs(self, cfg): + """Update model configurations.""" + # If adopted encoder -> it should already saved the configs + if ( + cfg["_target_"] == "src.models.components.geo_encoders.adopt_encoder.adopt_encoder" + and len(self.cfg_dict) != 0 + ): + return + + for i, branch in enumerate(cfg["encoder_branches"]): + if ( + branch["encoder"]["_target_"] + == "src.models.components.geo_encoders.adopt_encoder.adopt_encoder" + ): + branch["encoder"] = self.encoder_branches[i]["encoder"].cfg_dict + + self.cfg_dict = cfg + + @override + def _setup(self) -> List[str]: new_modules = [] # Configure/initialise missing/conditional parts @@ -62,6 +83,10 @@ def setup(self) -> List[str]: # Configure adapter/projector if requested if "projector" in branch: + if isinstance(encoder, IdentityEncoder): + raise IllegalArgumentCombination( + "Identity encoder cannot have linear projector" + ) projector = branch["projector"] intermediate_dim = encoder.output_dim diff --git a/src/models/components/geo_encoders/geoclip.py b/src/models/components/geo_encoders/geoclip.py index cd246ef..9224d60 100644 --- a/src/models/components/geo_encoders/geoclip.py +++ b/src/models/components/geo_encoders/geoclip.py @@ -21,10 +21,9 @@ def __init__( self.geo_data_name = geo_data_name @override - def setup(self) -> List[str]: + def _setup(self) -> List[str]: self.geo_encoder = LocationEncoder() self.output_dim = self.geo_encoder.LocEnc0.head[0].out_features - print("Model setup with GeoClip coordinate encoder") return [] @override diff --git a/src/models/components/geo_encoders/mlp_projector.py b/src/models/components/geo_encoders/mlp_projector.py index 4ab6e05..a43316e 100644 --- a/src/models/components/geo_encoders/mlp_projector.py +++ b/src/models/components/geo_encoders/mlp_projector.py @@ -25,9 +25,8 @@ def __init__( self.net: nn.Module | None = None @override - def setup(self) -> List[str]: + def _setup(self) -> List[str]: self.configure_nn() - print("Model set up with MLP projector") return ["net"] def set_input_dim(self, input_dim: int) -> None: diff --git a/src/models/components/geo_encoders/tabular_encoder.py b/src/models/components/geo_encoders/tabular_encoder.py index 314cf29..faff1a4 100644 --- a/src/models/components/geo_encoders/tabular_encoder.py +++ b/src/models/components/geo_encoders/tabular_encoder.py @@ -33,9 +33,8 @@ def __init__( self.geo_data_name = geo_data_name @override - def setup(self, input_dim: int = None) -> list[str]: + def _setup(self, input_dim: int = None) -> list[str]: self.configure_nn(input_dim) - print("Model setup with Tabular geo-encoder") return ["tabular_encoder"] def set_tabular_input_dim(self, input_dim: int) -> None: From bf7394fea414a6a773bb4e99a1ece66c07f914ca Mon Sep 17 00:00:00 2001 From: GabrieleTi Date: Wed, 8 Apr 2026 17:00:38 +0200 Subject: [PATCH 69/78] Update configs from the checkpoint --- src/models/components/geo_encoders/adopt_encoder.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/src/models/components/geo_encoders/adopt_encoder.py b/src/models/components/geo_encoders/adopt_encoder.py index f6e9f76..2a5bd88 100644 --- a/src/models/components/geo_encoders/adopt_encoder.py +++ b/src/models/components/geo_encoders/adopt_encoder.py @@ -21,6 +21,7 @@ def adopt_encoder(ckpt_path: str) -> BaseGeoEncoder: encoder: BaseGeoEncoder = hydra.utils.instantiate(geo_config) print("---Adopted encoder------") encoder.setup() + encoder.cfg_dict = geo_config print("------------------------") # Load in the weights @@ -32,6 +33,4 @@ def adopt_encoder(ckpt_path: str) -> BaseGeoEncoder: res = encoder.load_state_dict(state_dict, strict=False) log_model_loading("geo_encoder_ckpt", res) - encoder.setup = lambda *args, **kwargs: None # TODO: switch to maybe self.setup flag - return encoder From a504c6df5ccfd2de61bff2a3765c9ea1b335716e Mon Sep 17 00:00:00 2001 From: GabrieleTi Date: Wed, 8 Apr 2026 17:01:00 +0200 Subject: [PATCH 70/78] Extract _setup method, add setup_flag --- .../components/pred_heads/base_pred_head.py | 25 ++++++++++++++++--- .../components/pred_heads/linear_pred_head.py | 12 +++------ .../components/pred_heads/mlp_pred_head.py | 9 ++----- .../pred_heads/mlp_regression_head.py | 9 ++----- 4 files changed, 28 insertions(+), 27 deletions(-) diff --git a/src/models/components/pred_heads/base_pred_head.py b/src/models/components/pred_heads/base_pred_head.py index 5ed63c4..b3acd0c 100644 --- a/src/models/components/pred_heads/base_pred_head.py +++ b/src/models/components/pred_heads/base_pred_head.py @@ -1,5 +1,5 @@ from abc import ABC, abstractmethod -from typing import final +from typing import List, final import torch from torch import nn @@ -9,9 +9,14 @@ class BasePredictionHead(nn.Module, ABC): def __init__(self) -> None: """Base prediction head interface class.""" super().__init__() + + # Modules self.net: nn.Module | None = None + self.input_dim: int | None = None self.output_dim: int | None = None + self.setup_flag: bool = False + self.cfg_dict = {} @abstractmethod def forward(self, feats: torch.Tensor) -> torch.Tensor: @@ -34,11 +39,23 @@ def set_dim(self, input_dim: int, output_dim: int) -> None: self.input_dim = input_dim self.output_dim = output_dim - @abstractmethod - def setup(self) -> None: - """Configures networks, data-dependent parts. + @final + def setup(self) -> List[str]: + """Configures modules. Gets called in model.setup() method. Returns names of any new module configured to be added to the trainable modules list. """ + if self.setup_flag: + print(f"Module {self.__str__()} is already set up.") + return [] + else: + self._setup() + print(f"Model set up with {self.__str__()}") + self.setup_flag = True + return ["prediction_head"] + + @abstractmethod + def _setup(self) -> None: + """Configures specific prediction head.""" pass diff --git a/src/models/components/pred_heads/linear_pred_head.py b/src/models/components/pred_heads/linear_pred_head.py index 1dfa094..92c7818 100644 --- a/src/models/components/pred_heads/linear_pred_head.py +++ b/src/models/components/pred_heads/linear_pred_head.py @@ -1,4 +1,4 @@ -from typing import override +from typing import List, override import torch from torch import nn @@ -24,20 +24,14 @@ def __init__( @override def forward(self, feats: torch.Tensor) -> torch.Tensor: """Forward pass through the prediction head.""" - return torch.sigmoid(self.net(feats)) @override - def setup(self) -> None: - """Configures networks, data-dependent parts. - - Gets called in model.setup() method. Returns names of any new module configured to be added - to the trainable modules list. - """ + def _setup(self) -> None: + """Configures specific prediction head.""" assert type(self.input_dim) is int, self.input_dim assert type(self.output_dim) is int, self.output_dim self.net = nn.Linear(self.input_dim, self.output_dim) - print("Model set up with linear prediction head") return diff --git a/src/models/components/pred_heads/mlp_pred_head.py b/src/models/components/pred_heads/mlp_pred_head.py index 65610a5..7d4cd48 100644 --- a/src/models/components/pred_heads/mlp_pred_head.py +++ b/src/models/components/pred_heads/mlp_pred_head.py @@ -34,12 +34,8 @@ def forward(self, feats: torch.Tensor) -> torch.Tensor: return torch.sigmoid(self.net(feats)) @override - def setup(self) -> None: - """Configures networks, data-dependent parts. - - Gets called in model.setup() method. Returns names of any new module configured to be added - to the trainable modules list. - """ + def _setup(self) -> None: + """Configures specific prediction head.""" assert type(self.input_dim) is int, self.input_dim assert type(self.output_dim) is int, self.output_dim layers = [] @@ -50,7 +46,6 @@ def setup(self) -> None: input_dim = self.hidden_dim layers.append(nn.Linear(input_dim, self.output_dim)) self.net = nn.Sequential(*layers) - print("Model set up with MLP prediction head") return diff --git a/src/models/components/pred_heads/mlp_regression_head.py b/src/models/components/pred_heads/mlp_regression_head.py index 9679fa2..10088a1 100644 --- a/src/models/components/pred_heads/mlp_regression_head.py +++ b/src/models/components/pred_heads/mlp_regression_head.py @@ -49,12 +49,8 @@ def forward(self, feats: torch.Tensor) -> torch.Tensor: return self.net(feats) @override - def setup(self) -> None: - """Configures networks, data-dependent parts. - - Gets called in model.setup() method. Returns names of any new module configured to be added - to the trainable modules list. - """ + def _setup(self) -> None: + """Configures specific prediction head.""" assert isinstance(self.input_dim, int), self.input_dim assert isinstance(self.output_dim, int), self.output_dim @@ -70,5 +66,4 @@ def setup(self) -> None: layers.append(nn.Linear(in_dim, self.output_dim)) self.net = nn.Sequential(*layers) - print("Model set up with MLP regression prediction head") return From 16f495df3caa5ec320095ce2e02212fe961b8975 Mon Sep 17 00:00:00 2001 From: GabrieleTi Date: Wed, 8 Apr 2026 17:01:44 +0200 Subject: [PATCH 71/78] Clean up models and add setup flags, configurations into the model checkpointing --- src/models/base_model.py | 153 +++++++++++++++++++++-------- src/models/inference_model.py | 109 +++++++++++--------- src/models/predictive_model.py | 91 +++++++---------- src/models/text_alignment_model.py | 86 ++++++---------- src/train.py | 2 +- 5 files changed, 244 insertions(+), 197 deletions(-) diff --git a/src/models/base_model.py b/src/models/base_model.py index 1acf647..e97b53c 100644 --- a/src/models/base_model.py +++ b/src/models/base_model.py @@ -4,24 +4,34 @@ import torch from lightning import LightningModule +from src.models.components.geo_encoders.base_geo_encoder import BaseGeoEncoder from src.models.components.loss_fns.base_loss_fn import BaseLossFn from src.models.components.metrics.metrics_wrapper import MetricsWrapper +from src.models.components.pred_heads.base_pred_head import BasePredictionHead +from src.models.components.text_encoders.base_text_encoder import BaseTextEncoder +from src.utils.logging_utils import log_model_loading class BaseModel(LightningModule, ABC): def __init__( self, - trainable_modules: list[str] | None, - optimizer: torch.optim.Optimizer, - scheduler: torch.optim.lr_scheduler, - loss_fn: BaseLossFn, - metrics: MetricsWrapper, - num_classes: int | None = None, - tabular_dim: int | None = None, + trainable_modules: list[str], + geo_encoder: BaseGeoEncoder, + text_encoder: BaseTextEncoder | None, + prediction_head: BasePredictionHead | None, + optimizer: torch.optim.Optimizer | None, + scheduler: torch.optim.lr_scheduler.LRScheduler | None, + loss_fn: BaseLossFn | None, + metrics: MetricsWrapper | None, + num_classes: int | None, + tabular_dim: int | None, ) -> None: """Interface for any model. :param trainable_modules: which modules to train + :param geo_encoder: module for encoding geo data + :param text_encoder: module for encoding text data + :param prediction_head: module for making prediction from geo features :param optimizer: optimizer for the model weight update :param scheduler: scheduler for the model weight update :param loss_fn: loss function @@ -30,36 +40,69 @@ def __init__( :param tabular_dim: number of tabular features """ super().__init__() + + # Ignore objects self.save_hyperparameters( ignore=[ - "loss_fn", "geo_encoder", - "prediction_head", "text_encoder", - "metrics", + "prediction_head", "optimizer", "scheduler", + "loss_fn", + "metrics", ] ) self.trainable_modules = trainable_modules - self.num_classes = num_classes - self.tabular_dim = tabular_dim - self.loss_fn = loss_fn - self.metrics = metrics + self.geo_encoder = geo_encoder + if text_encoder: + self.text_encoder = text_encoder + if prediction_head: + self.prediction_head = prediction_head + self.optimizer = optimizer self.scheduler = scheduler - @abstractmethod + self.loss_fn = loss_fn + self.metrics = metrics + + self.num_classes = num_classes + self.tabular_dim = tabular_dim + + self.setup_flag = False + + @final def setup(self, stage: str) -> None: """Updates model based data-bound configurations (through datamodule), This method is called after trainer is initialized and datamodule is available.""" + if self.setup_flag: + print(f"Model {self.__str__()} is already set up!") + return + + # If trainer is attached get num_classes and tabular_dim from datamodule (data-dependent) + if self._trainer is not None: + self.num_classes = self.trainer.datamodule.num_classes + self.tabular_dim = self.trainer.datamodule.tabular_dim + + # Per model logic of setting up + self._setup(stage) + self.setup_flag = True + + # Freezing requested parts + if stage in ["inference", "test"]: + self.full_freezer() + self.trainable_modules = [] + else: + self.freezer() + + @abstractmethod + def _setup(self, stage: str) -> None: pass @final def full_freezer(self): """Freeze the whole network.""" - # Freeze the whole network for name, param in self.named_parameters(): param.requires_grad = False @@ -71,17 +114,16 @@ def full_freezer(self): @final def freezer(self) -> None: """Freezes modules based on provided trainable modules.""" - trainable_modules = tuple(self.trainable_modules) or tuple() - - # Store higher level module names for printing of trainable parts - trainable = set() + # Convert for checking with .startswith() + trainable_set = tuple(set(self.trainable_modules)) or tuple() + expanded_trainable = set() # Freeze modules for name, param in self.named_parameters(): - # Enable exceptions - if name.startswith(trainable_modules): + # Unfreeze trainable parts + if name.startswith(trainable_set): param.requires_grad = True - trainable.add(name) + expanded_trainable.add(name) else: # Freeze the rest param.requires_grad = False @@ -95,8 +137,8 @@ def freezer(self) -> None: # - it is the root module (""), which must be train when any child is. def _in_train_scope(name: str) -> bool: if not name: # root module - return bool(trainable_modules) - for t in trainable_modules: + return bool(trainable_set) + for t in trainable_set: if name == t or name.startswith(t + ".") or t.startswith(name + "."): return True return False @@ -108,9 +150,10 @@ def _in_train_scope(name: str) -> bool: module.eval() print("------Set to train------") - for m in sorted(trainable): + for m in sorted(expanded_trainable): print(f" {m}") print("------------------------") + self.trainable_modules = list(expanded_trainable) @abstractmethod def forward( @@ -167,14 +210,44 @@ def configure_optimizers(self) -> Dict[str, Any]: } return {"optimizer": optimizer} + def update_configs(self, cfg): + """Update hyper-parameters from the model.""" + self.geo_encoder.update_configs(cfg["geo_encoder"]) + + if hasattr(self, "text_encoder"): + self.text_encoder.cfg_dict = cfg["text_encoder"] + + if hasattr(self, "prediction_head"): + self.prediction_head.cfg_dict = cfg["prediction_head"] + def on_save_checkpoint(self, checkpoint): - """Save only trainable parts of the model.""" + """Save checkpoint. + + - Save only trainable parts of the model. + - Append configurations of the model + """ + if not self.setup_flag: + raise ValueError("Model cannot be saved as it was not set up.") + + # Remove unnecessary keys + pop_list = [ + "state_dict", + "loops", + "hparams_name", + "datamodule_hyper_parameters", + "datamodule_hparams_name", + ] + for i in pop_list: + checkpoint.pop(i) + + # Save only trainable parts checkpoint["state_dict"] = { k: v for k, v in self.state_dict().items() if any(k.startswith(part) for part in self.trainable_modules) } + # Update model configurations checkpoint["hyper_parameters"].update( { "num_classes": self.num_classes, @@ -183,22 +256,20 @@ def on_save_checkpoint(self, checkpoint): } ) - def on_load_checkpoint(self, checkpoint): - """Load only trainable parts of the model.""" - missing_keys, unexpected_keys = self.load_state_dict( - checkpoint["state_dict"], strict=False - ) - print("Model loaded from a checkpoint.") + checkpoint["hyper_parameters"]["geo_encoder"] = self.geo_encoder.cfg_dict + + if hasattr(self, "prediction_head"): + checkpoint["hyper_parameters"]["prediction_head"] = self.prediction_head.cfg_dict + if hasattr(self, "text_encoder"): + checkpoint["hyper_parameters"]["text_encoder"] = self.text_encoder.cfg_dict - # self.tabular_dim = checkpoint['hyper_parameters']["tabular_dim"] - # self.num_classes = checkpoint["hyper_parameters"]["num_classes"] + return - if missing_keys: - missing_keys = {".".join(i.split(".")[:3]) for i in missing_keys} - print(f"The following keys are missing from the pretrained model: {missing_keys}") - if unexpected_keys: - unexpected_keys = {".".join(i.split(".")[:3]) for i in unexpected_keys} - print(f"The following keys are unexpected from the pretrained model:{unexpected_keys}") + def on_load_checkpoint(self, checkpoint): + """Load pre-trained parts of the model.""" + res = self.load_state_dict(checkpoint["state_dict"], strict=False) + print("Model loaded from a checkpoint.") + log_model_loading("Model from checkpoint", res) # TODO feels illegal def load_state_dict(self, state_dict, strict=True): diff --git a/src/models/inference_model.py b/src/models/inference_model.py index 0fab831..bd47a6b 100644 --- a/src/models/inference_model.py +++ b/src/models/inference_model.py @@ -7,6 +7,7 @@ from src.models.base_model import BaseModel from src.models.components.geo_encoders.base_geo_encoder import BaseGeoEncoder +from src.models.components.metrics.metrics_wrapper import MetricsWrapper from src.models.components.pred_heads.linear_pred_head import BasePredictionHead from src.models.components.text_encoders.base_text_encoder import ( BaseTextEncoder, @@ -24,59 +25,60 @@ def __init__( geo_encoder: BaseGeoEncoder, text_encoder: BaseTextEncoder, prediction_head: BasePredictionHead, - num_classes: int | None = None, - tabular_dim: int | None = None, + num_classes: int | None, + metrics: MetricsWrapper | None = None, + ks: list[int] | None = [5, 10, 15], match_to_geo: bool = True, **kwargs, ) -> None: + """Inference model. + + :param geo_encoder: module for encoding geo data + :param text_encoder: module for encoding text data + :param prediction_head: module for making prediction from geo features + :param num_classes: number of target classes + :param metrics: metrics to track for model performance estimation + :param ks: list of ks + :param match_to_geo: whether to match dimensions of text encoder to geo_encoder or visa- + versa + """ + super().__init__( - [], None, None, None, None, num_classes=num_classes, tabular_dim=tabular_dim + trainable_modules=[], + geo_encoder=geo_encoder, + text_encoder=text_encoder, + prediction_head=prediction_head, + optimizer=None, + scheduler=None, + loss_fn=None, + metrics=metrics, + num_classes=num_classes, + tabular_dim=None, ) - # Encoders configuration - self.geo_encoder = geo_encoder - self.text_encoder = text_encoder - + # Params from alignment model self.match_to_geo = match_to_geo - - # Prediction head - self.prediction_head = prediction_head + self.ks = ks @override - def setup(self, stage: str) -> None: - # During inference we need to ensure: - # - geo_encoder is fully initialized (sets geo_encoder.output_dim) - # - text/geo dims match (possibly via text_encoder.extra_projector) - # - prediction_head.net is created with correct (input_dim, output_dim) + def _setup(self, stage: str) -> None: + """Set up the network.""" if stage != "inference": - return + raise ValueError(f"Trying to {stage} inference model") - # Configure geo encoder and its output_dim only if it wasn't already set up. - # (In the normal "stitch-from-ckpts" flow, the modules are already initialized.) - if getattr(self.geo_encoder, "output_dim", None) is None or ( - hasattr(self.geo_encoder, "geo_encoder") - and getattr(self.geo_encoder, "geo_encoder") is None - ): - self.geo_encoder.setup() + print("-------Model------------") + # Configure encoders + self.geo_encoder.setup() + self.text_encoder.setup() # Configure optional extra projection so text embeddings match geo embeddings. - # Note: current codebase applies extra projector on text encoders (not on geo encoders) - # during forward, so `match_to_geo` is expected to be True. if self.text_encoder.output_dim != self.geo_encoder.output_dim: - if not self.match_to_geo: - raise ValueError( - "match_to_geo=False is not supported for inference: geo extra projector " - "is not applied in geo encoder forward passes in this codebase." - ) - # If extra_projector already exists but output dims still mismatch, we recreate it. - # Otherwise, avoid overwriting weights unnecessarily. - if ( - getattr(self.text_encoder, "extra_projector", None) is None - or self.text_encoder.output_dim != self.geo_encoder.output_dim - ): + if self.match_to_geo: self.text_encoder.add_projector(projected_dim=self.geo_encoder.output_dim) + else: + self.geo_encoder.add_projector(projected_dim=self.text_encoder.output_dim) - # Configure prediction head (it creates `prediction_head.net` in setup()). + # Configure prediction head if self.prediction_head.net is None: if self.num_classes is None: raise ValueError( @@ -86,9 +88,7 @@ def setup(self, stage: str) -> None: input_dim=self.geo_encoder.output_dim, output_dim=self.num_classes ) self.prediction_head.setup() - - # Freeze everything for pure inference. - self.full_freezer() + print("------------------------") @override def _step( @@ -179,6 +179,9 @@ def merge_inference_model(cfg, save_ckpt=False) -> InferenceModel | None: align_ckpt = torch.load(align_ckpt_path, map_location="cpu", weights_only=False) # Sanity check: ensure geo encoder configs match. + align_ckpt["hyper_parameters"]["geo_encoder"] = pred_ckpt["hyper_parameters"].get( + "geo_encoder" + ) if pred_ckpt["hyper_parameters"].get("geo_encoder") != align_ckpt["hyper_parameters"].get( "geo_encoder" ): @@ -186,9 +189,7 @@ def merge_inference_model(cfg, save_ckpt=False) -> InferenceModel | None: if input("Do you want to proceed? y/n").lower() == "n": return None - pred_trainable_modules = [ - pred_ckpt["hyper_parameters"].get("trainable_modules", [])[0] - ] # TODO fix + pred_trainable_modules = pred_ckpt["hyper_parameters"].get("trainable_modules", []) align_trainable_modules = align_ckpt["hyper_parameters"].get("trainable_modules", []) geo_pred_encoder_trained = _is_prefix_trained(pred_trainable_modules, "geo_encoder") @@ -213,16 +214,30 @@ def merge_inference_model(cfg, save_ckpt=False) -> InferenceModel | None: model: InferenceModel = hydra.utils.instantiate(inference_hparams) model.setup("inference") - # Load alignment weights first (text encoder + geo encoder). - res = model.load_state_dict(align_ckpt["state_dict"], strict=False) - log_model_loading("alignment_ckpt", res) + # Load alignment weights first (text encoder). + text_state = { + k: v for k, v in align_ckpt["state_dict"].items() if k.startswith("text_encoder.") + } + res = model.load_state_dict(text_state, strict=False) + log_model_loading("text_encoder", res) + + if cfg.training_order[0] == "prediction_model" and not geo_align_encoder_trained: + geo_state = { + k: v for k, v in pred_ckpt["state_dict"].items() if k.startswith("geo_encoder.") + } + else: + geo_state = { + k: v for k, v in align_ckpt["state_dict"].items() if k.startswith("geo_encoder.") + } + res = model.load_state_dict(geo_state, strict=False) + log_model_loading("geo_encoder", res) # Load prediction head weights from predictive ckpt. head_state = { k: v for k, v in pred_ckpt["state_dict"].items() if k.startswith("prediction_head.") } res = model.load_state_dict(head_state, strict=False) - log_model_loading("predictive_prediction_head_only", res) + log_model_loading("Predictive_head", res) # Save model if save_ckpt: diff --git a/src/models/predictive_model.py b/src/models/predictive_model.py index 5663afb..d71224f 100644 --- a/src/models/predictive_model.py +++ b/src/models/predictive_model.py @@ -21,79 +21,64 @@ def __init__( prediction_head: BasePredictionHead, trainable_modules: list[str], optimizer: torch.optim.Optimizer, - scheduler: torch.optim.lr_scheduler, + scheduler: torch.optim.lr_scheduler.LRScheduler | None, loss_fn: BaseLossFn, metrics: MetricsWrapper, - normalize_features: bool = True, num_classes: int | None = None, tabular_dim: int | None = None, + normalize_features: bool = True, ) -> None: """Implementation of the predictive model with replaceable GEO encoder, and prediction head. - :param geo_encoder: geo encoder module (replaceable) - :param prediction_head: prediction head module (replaceable) - :param trainable_modules: list of modules to train (parts/modules or modules, modules) - :param optimizer: optimizer to use for training - :param scheduler: scheduler to use for training - :param loss_fn: loss function to use - :param metrics: metrics to use for model performance evaluation - :param normalize_features: if True, apply L2 normalisation to encoder output before the - prediction head (default: True) + :param trainable_modules: which modules to train + :param geo_encoder: module for encoding geo data + :param prediction_head: module for making prediction from geo features + :param optimizer: optimizer for the model weight update + :param scheduler: scheduler for the model weight update + :param loss_fn: loss function + :param metrics: metrics to track for model performance estimation :param num_classes: number of target classes :param tabular_dim: number of tabular features + :param normalize_features: if True, apply L2 normalisation to encoder output before the + prediction head (default: True) """ super().__init__( - trainable_modules, optimizer, scheduler, loss_fn, metrics, num_classes, tabular_dim + trainable_modules=trainable_modules, + geo_encoder=geo_encoder, + text_encoder=None, + prediction_head=prediction_head, + optimizer=optimizer, + scheduler=scheduler, + loss_fn=loss_fn, + metrics=metrics, + num_classes=num_classes, + tabular_dim=tabular_dim, ) - # Geo encoder configuration - self.geo_encoder = geo_encoder - - # Prediction head - self.prediction_head = prediction_head - # Normalise features boolean self.normalize_features = normalize_features @override - def setup(self, stage: str) -> None: - """Updates model based data-bound configurations (through datamodule), This method is - called after trainer is initialized and datamodule is available. + def _setup(self, stage: str) -> None: + """Set up encoders and missing adapters/projectors based data-bound configurations (through + datamodule), This method is called after trainer is initialized and datamodule is + available. Otherwise, some configuration variables must be made available """ - - if self._trainer is not None: - self.num_classes = self.trainer.datamodule.num_classes - self.tabular_dim = self.trainer.datamodule.tabular_dim - - if stage != "fit": - if isinstance(self.trainable_modules, tuple): - self.trainable_modules = list(self.trainable_modules) + if stage != "fit" and isinstance(self.trainable_modules, tuple): + self.trainable_modules = list(self.trainable_modules) print("-------Model------------") - self.setup_encoders_adapters() - print("------------------------") - - # Freezing requested parts - if stage in ["inference", "test"]: - self.full_freezer() - else: - self.freezer() - - def setup_encoders_adapters(self): - """Set up encoders and missing adapters/projectors.""" - # TODO: move to multi-modal eo encoder - # If tabular encoder used, we need to specify tabular dim if isinstance(self.geo_encoder, TabularEncoder) or isinstance( self.geo_encoder, EncoderWrapper ): self.geo_encoder.set_tabular_input_dim(self.tabular_dim) - # Setup encoders that need data-depended configurations + # Setup encoders new_modules = [f"geo_encoder.{i}" for i in self.geo_encoder.setup() or []] self.trainable_modules.extend(new_modules) @@ -101,12 +86,12 @@ def setup_encoders_adapters(self): self.prediction_head.set_dim( input_dim=self.geo_encoder.output_dim, output_dim=self.num_classes ) - self.prediction_head.setup() - if "prediction_head" not in self.trainable_modules: - self.trainable_modules.append("prediction_head") + self.trainable_modules.extend(self.prediction_head.setup() or []) + print("------------------------") @override def forward(self, batch: Dict[str, torch.Tensor]) -> torch.Tensor: + """Forward pass of a batch through the model.""" feats = self.geo_encoder(batch) if self.normalize_features: feats = F.normalize(feats, dim=-1) @@ -114,19 +99,19 @@ def forward(self, batch: Dict[str, torch.Tensor]) -> torch.Tensor: @override def _step(self, batch: Dict[str, torch.Tensor], mode: str = "train") -> torch.Tensor: + """Step logic of forward pass, metric calculation.""" + + # Forward pass preds = self.forward(batch) + loss = self.loss_fn(preds, batch.get("target")) + metrics = self.metrics(pred=preds, batch=batch, mode=mode) + + # Logging log_kwargs = dict( on_step=False, on_epoch=True, prog_bar=True, sync_dist=True, batch_size=preds.size(0) ) - loss = self.loss_fn(preds, batch.get("target")) self.log(f"{mode}_loss", loss, **log_kwargs) - - metrics = self.metrics(pred=preds, batch=batch, mode=mode) self.log_dict(metrics, **log_kwargs) return loss - - -if __name__ == "__main__": - _ = PredictiveModel(None, None, None, None, None, None, None) diff --git a/src/models/text_alignment_model.py b/src/models/text_alignment_model.py index 397da7d..be3f072 100644 --- a/src/models/text_alignment_model.py +++ b/src/models/text_alignment_model.py @@ -1,4 +1,3 @@ -from io import text_encoding from typing import Dict, Tuple, override import numpy as np @@ -12,9 +11,6 @@ RetrievalContrastiveValidation, ) from src.models.components.metrics.metrics_wrapper import MetricsWrapper -from src.models.components.pred_heads.linear_pred_head import ( - BasePredictionHead, -) from src.models.components.text_encoders.base_text_encoder import ( BaseTextEncoder, ) @@ -23,81 +19,64 @@ class TextAlignmentModel(BaseModel): def __init__( self, + trainable_modules: list[str], geo_encoder: BaseGeoEncoder, text_encoder: BaseTextEncoder, optimizer: torch.optim.Optimizer, scheduler: torch.optim.lr_scheduler, loss_fn: BaseLossFn, - trainable_modules: list[str], metrics: MetricsWrapper, - prediction_head: BasePredictionHead | None = None, - ks: list[int] | None = [5, 10, 15], - match_to_geo: bool = True, num_classes: int | None = None, tabular_dim: int | None = None, + ks: list[int] | None = [5, 10, 15], + match_to_geo: bool = True, ) -> None: """Implementation of contrastive text-eo modality alignment model. - :param geo_encoder: geo encoder module (replaceable) - :param text_encoder: text encoder module (replaceable) - :param optimizer: optimizer to use for training - :param scheduler: scheduler to use for training - :param loss_fn: loss function to use (contrastive) - :param trainable_modules: list of modules to train (parts/modules or modules, modules) - :param metrics: metrics to use for model performance evaluation - :param prediction_head: prediction head + :param trainable_modules: which modules to train + :param geo_encoder: module for encoding geo data + :param text_encoder: module for encoding text data + :param optimizer: optimizer for the model weight update + :param scheduler: scheduler for the model weight update + :param loss_fn: loss function + :param metrics: metrics to track for model performance estimation + :param num_classes: number of target classes + :param tabular_dim: number of tabular features :param ks: list of ks :param match_to_geo: whether to match dimensions of text encoder to geo_encoder or visa- versa - :param num_classes: number of target classes - :param tabular_dim: number of tabular features """ super().__init__( - trainable_modules, optimizer, scheduler, loss_fn, metrics, num_classes, tabular_dim + trainable_modules=trainable_modules, + geo_encoder=geo_encoder, + text_encoder=text_encoder, + prediction_head=None, + optimizer=optimizer, + scheduler=scheduler, + loss_fn=loss_fn, + metrics=metrics, + num_classes=num_classes, + tabular_dim=tabular_dim, ) # Metrics self.ks = ks self.log_kwargs = dict(on_step=False, on_epoch=True, prog_bar=True, sync_dist=True) - # Encoders configuration - self.geo_encoder = geo_encoder - self.text_encoder = text_encoder self.match_to_geo = match_to_geo - # Prediction head - self.prediction_head = prediction_head - @override - def setup(self, stage: str = "fit") -> None: - """Updates model based data-bound configurations (through datamodule), This method is - called after trainer is initialized and datamodule is available. + def _setup(self, stage: str = "fit") -> None: + """Set up encoders and missing adapters/projectors based data-bound configurations (through + datamodule), This method is called after trainer is initialized and datamodule is + available. Otherwise, some configuration variables must be made available """ - - if self._trainer is not None: - self.num_classes = self.trainer.datamodule.num_classes - self.tabular_dim = self.trainer.datamodule.tabular_dim - # Set up encoders and missing adapters/projectors print("-------Model------------") - self.setup_encoders_adapters() - print("------------------------") - - # Freeze not requested parts - if stage in ["inference", "test"]: - self.full_freezer() - else: - self.freezer() - - # Configure contrastive retrieval evaluation - self.setup_retrieval_evaluation() - - def setup_encoders_adapters(self): - """Set up encoders and missing adapters/projectors.""" - # Setup encoders that need data-depended configurations - new_modules = [f"geo_encoder.{i}" for i in self.geo_encoder.setup()] + new_modules = [f"geo_encoder.{i}" for i in self.geo_encoder.setup() or []] + new_modules.extend([f"text_encoder.{i}" for i in self.text_encoder.setup() or []]) self.trainable_modules.extend(new_modules) # Extra projector for text encoder if eo and text dim not match @@ -109,12 +88,9 @@ def setup_encoders_adapters(self): self.geo_encoder.add_projector(projected_dim=self.text_encoder.output_dim) self.trainable_modules.append("geo_encoder.extra_projector") - # Configure prediction head based on geo-encoder output_dim - if self.prediction_head is not None: - self.prediction_head.set_dim( - input_dim=self.geo_encoder.output_dim, output_dim=self.num_classes - ) - self.prediction_head.setup() + # Configure contrastive retrieval evaluation + self.setup_retrieval_evaluation() + print("------------------------") def setup_retrieval_evaluation(self): self.concept_configs = self.trainer.datamodule.concept_configs diff --git a/src/train.py b/src/train.py index 7a4992c..af1686f 100644 --- a/src/train.py +++ b/src/train.py @@ -55,7 +55,7 @@ def train(cfg: DictConfig) -> Tuple[Dict[str, Any], Dict[str, Any]]: # Append model hparams from config to be saved in ckpg raw_model_cfg = OmegaConf.to_container(cfg.model, resolve=True) - model.hparams.update(raw_model_cfg) + model.update_configs(raw_model_cfg) log.info("Instantiating callbacks...") callbacks: List[Callback] = instantiate_callbacks(cfg.get("callbacks")) From 3542d1458e784618547514176c92589a4f35843d Mon Sep 17 00:00:00 2001 From: GabrieleTi Date: Wed, 8 Apr 2026 17:02:03 +0200 Subject: [PATCH 72/78] Add another option for log dir --- configs/paths/shared.yaml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/configs/paths/shared.yaml b/configs/paths/shared.yaml index 44218a8..26c1a8c 100644 --- a/configs/paths/shared.yaml +++ b/configs/paths/shared.yaml @@ -9,7 +9,7 @@ data_dir: ${oc.env:DATA_DIR,${oc.env:SHARED_ROOT,${paths.root_dir}}/data} cache_dir: ${oc.env:SHARED_CACHE,${paths.data_dir}/cache} # path to logging directory -log_dir: ${oc.env:SHARED_ROOT}/logs/ +log_dir: ${oc.env:SHARED_ROOT,${paths.root_dir}}/logs/ # path to output directory, created dynamically by hydra # path generation pattern is specified in `configs/hydra/local.yaml` From f94a3bb2955ecd51233bbbe16fd358f7b9c2f3f9 Mon Sep 17 00:00:00 2001 From: GabrieleTi Date: Thu, 9 Apr 2026 11:07:49 +0200 Subject: [PATCH 73/78] Fix tests about missmatch in dim --- tests/test_pred_heads.py | 36 +++++++++++++++--------------------- 1 file changed, 15 insertions(+), 21 deletions(-) diff --git a/tests/test_pred_heads.py b/tests/test_pred_heads.py index abc99b4..6f617c2 100644 --- a/tests/test_pred_heads.py +++ b/tests/test_pred_heads.py @@ -16,6 +16,7 @@ # @pytest.mark.slow def test_pred_head_generic_properties(create_butterfly_dataset): + """Test required properties of the prediction head class.""" ds, dm = create_butterfly_dataset batch = next(iter(dm.train_dataloader())) eo_encoder = GeoClipCoordinateEncoder() @@ -23,41 +24,34 @@ def test_pred_head_generic_properties(create_butterfly_dataset): feats = eo_encoder.forward(batch) list_pred_heads = [LinearPredictionHead, MLPPredictionHead, MLPRegressionPredictionHead] - for pred_head_class in list_pred_heads: - pred_head = pred_head_class(input_dim=64, output_dim=64) - pred_head.setup() - assert hasattr( - pred_head, "set_dim" - ), f"'set_dim' method missing in {pred_head_class.__name__}." + for pred_head in list_pred_heads: + pred_head = pred_head() + assert hasattr(pred_head, "set_dim"), f"'set_dim' method missing in {pred_head.__name__}." assert callable( getattr(pred_head, "set_dim") - ), f"'set_dim' is not callable in {pred_head_class.__name__}." - pred_head.set_dim(eo_encoder.output_dim, ds.num_classes) + ), f"'set_dim' is not callable in {pred_head.__name__}." + pred_head.set_dim(input_dim=eo_encoder.output_dim, output_dim=ds.num_classes) assert hasattr( pred_head, "input_dim" - ), f"'input_dim' attribute missing in {pred_head_class.__name__}." + ), f"'input_dim' attribute missing in {pred_head.__name__}." assert hasattr( pred_head, "output_dim" - ), f"'output_dim' attribute missing in {pred_head_class.__name__}." - assert hasattr( - pred_head, "setup" - ), f"'setup' method missing in {pred_head_class.__name__}." + ), f"'output_dim' attribute missing in {pred_head.__name__}." + assert hasattr(pred_head, "setup"), f"'setup' method missing in {pred_head.__name__}." assert callable( getattr(pred_head, "setup") - ), f"'setup' is not callable in {pred_head_class.__name__}." + ), f"'setup' is not callable in {pred_head.__name__}." pred_head.setup() - assert hasattr(pred_head, "net"), f"'net' attribute missing in {pred_head_class.__name__}." - assert hasattr( - pred_head, "forward" - ), f"'forward' method missing in {pred_head_class.__name__}." + assert hasattr(pred_head, "net"), f"'net' attribute missing in {pred_head.__name__}." + assert hasattr(pred_head, "forward"), f"'forward' method missing in {pred_head.__name__}." assert callable( getattr(pred_head, "forward") - ), f"'forward' is not callable in {pred_head_class.__name__}." + ), f"'forward' is not callable in {pred_head.__name__}." out = pred_head.forward(feats) assert isinstance( out, torch.Tensor - ), f"'forward' method of {pred_head_class.__name__} does not return a torch.Tensor." + ), f"'forward' method of {pred_head.__name__} does not return a torch.Tensor." assert out.shape == ( dm.batch_size_per_device, ds.num_classes, - ), f"Output shape mismatch in {pred_head_class.__name__}." + ), f"Output shape mismatch in {pred_head.__name__}." From 15229a618f1ea3d56610f6fc97caba734703a0ea Mon Sep 17 00:00:00 2001 From: GabrieleTi Date: Fri, 10 Apr 2026 14:36:17 +0200 Subject: [PATCH 74/78] Updates to tessera downloading (for larger areas, saves tiles with 0s for missing data) Includes data checks --- src/data_preprocessing/tessera_data_check.py | 57 ++++++++++++++ src/data_preprocessing/tessera_embeds.py | 81 ++++++++++++++++---- 2 files changed, 122 insertions(+), 16 deletions(-) create mode 100644 src/data_preprocessing/tessera_data_check.py diff --git a/src/data_preprocessing/tessera_data_check.py b/src/data_preprocessing/tessera_data_check.py new file mode 100644 index 0000000..a7c9801 --- /dev/null +++ b/src/data_preprocessing/tessera_data_check.py @@ -0,0 +1,57 @@ +import glob +import os + +import numpy as np + + +def center_crop(arr, target_shape): + slices = [] + for dim, target in zip(arr.shape, target_shape): + start = (dim - target) // 2 + end = start + target + slices.append(slice(start, end)) + return arr[tuple(slices)] + + +def main(paths): + + sizes = [256, 128] + + for p in paths: + img = np.load(p) + for s in sizes: + p_id = os.path.basename(p).split(".")[0].split("-")[-1] + crop = img + if s != img.shape[0]: + crop = center_crop(img, (s, s, 128)) + + if crop.shape[0:2] != (s, s): + with open(f"logs/tessera_size_mismatch_{s}.txt", "a") as f: + print(f"{p_id} has shape {crop.shape[0:2]}") + f.write(f"{p_id}\n") + + if np.isinf(crop.any()): + with open(f"logs/tessera_nans_{s}.txt", "a") as f: + f.write(f"{p_id}\n") + + nulls = np.sum(crop == 0) + if nulls > (s * s * 128) * 0.5: + with open(f"logs/tessera_50per_empty_{s}.txt", "a") as f: + print(f"50% of {p_id} is 0") + f.write(f"{p_id}\n") + + if nulls > (s * s * 128) * 0.25: + with open(f"logs/tessera_25per_empty_{s}.txt", "a") as f: + print(f"25% of {p_id} is 0") + f.write(f"{p_id}\n") + + +if __name__ == "__main__": + os.chdir("../..") + print(os.getcwd()) + + paths = glob.glob("/lustre/backup/SHARED/AIN/aether/data/s2bms/eo/tessera/*.npy") + paths = glob.glob("data/s2bms/eo/tessera/*.npy") + paths.sort() + + main(paths) diff --git a/src/data_preprocessing/tessera_embeds.py b/src/data_preprocessing/tessera_embeds.py index 1e1830d..2774687 100644 --- a/src/data_preprocessing/tessera_embeds.py +++ b/src/data_preprocessing/tessera_embeds.py @@ -4,6 +4,7 @@ import threading import numpy as np +from affine import Affine # Serialises concurrent reads/writes to the per-directory meta.csv log file. _meta_csv_lock = threading.Lock() @@ -17,7 +18,6 @@ from rasterio.warp import Resampling, calculate_default_transform, reproject from src.data_preprocessing.crs_utils import ( - create_bbox_with_radius, crs_to_pixel_coords, get_point_utm_crs, point_reprojection, @@ -75,6 +75,37 @@ def reproject_dataset(src_raster: MemoryFile, dst_crs: str) -> MemoryFile: return dst, memfile +def get_tiles(lat_center, lon_center, half_size_m=75, year=2024): + """Find all 0.1deg tiles (referenced at 0.05deg) that overlap AOI.""" + + # Convert half-size from meters to degrees (approximate) + half_lat_deg = half_size_m / 111320.0 + half_lon_deg = half_size_m / (111320.0 * np.cos(np.radians(lat_center))) + + # AOI bounds + lat_min = lat_center - half_lat_deg + lat_max = lat_center + half_lat_deg + lon_min = lon_center - half_lon_deg + lon_max = lon_center + half_lon_deg + + tile_size = 0.1 + + # Find tile indices overlapping the AOI + i_min = int(np.floor(lat_min / tile_size)) + i_max = int(np.floor(lat_max / tile_size)) + j_min = int(np.floor(lon_min / tile_size)) + j_max = int(np.floor(lon_max / tile_size)) + + tiles = [] + for i in range(i_min, i_max + 1): + for j in range(j_min, j_max + 1): + ref_lat = i * tile_size + 0.05 # tile reference (center) + ref_lon = j * tile_size + 0.05 + tiles.append((year, round(ref_lon, 10), round(ref_lat, 10))) + + return tiles + + def get_tessera_embeds( lon: float, lat: float, @@ -107,13 +138,10 @@ def get_tessera_embeds( utm_crs = get_point_utm_crs(lon, lat) lon_utm, lat_utm = point_reprojection(lon, lat, "EPSG:4326", utm_crs) - # Bounding box - radius = math.ceil(tile_size / 2) + padding - bbox = create_bbox_with_radius(lon, lat, radius=radius, utm_crs=utm_crs, return_wgs=True) - # Request to tessera - tiles_to_fetch = tessera_con.registry.load_blocks_for_region( - bounds=bbox.bounds, year=int(year) + radius = math.ceil(tile_size / 2) + padding + tiles_to_fetch = get_tiles( + lat_center=lat, lon_center=lon, half_size_m=radius * 10, year=int(year) ) # Mosaic returned tiles for the bbox @@ -185,6 +213,25 @@ def get_tessera_embeds( np.save(embed_tile_name, crop) print(f"Array saved as {embed_tile_name}") + # Temp save tif too + crop_transform = mosaic_transform * Affine.translation(col_min, row_min) + height, width, channels = crop.shape + + with rasterio.open( + embed_tile_name[:-4] + ".tif", + "w", + driver="GTiff", + height=height, + width=width, + count=channels, + dtype=crop.dtype, + crs=utm_crs, + transform=crop_transform, + ) as dst: + for i in range(channels): + dst.write(crop[:, :, i], i + 1) + print(f"tif saved to {embed_tile_name[:-4]}.tif") + # Log its metadata meta_df = pd.DataFrame( {"id": [name_loc], "year": [year], "lon": [lon], "lat": [lat], "crs": [utm_crs]} @@ -228,8 +275,6 @@ def tessera_from_df( n = len(model_ready_df) for i, row in model_ready_df.iterrows(): print(f"{i}/{n}") - if i < 238 or i in [1319]: - continue try: get_tessera_embeds(row.lon, row.lat, row.name_loc, year, f"{data_dir}/", tile_size, gt) except Exception as e: @@ -302,20 +347,24 @@ def inspect_np_arr_as_tiff( if __name__ == "__main__": + # os.chdir('../..') + print(os.getcwd()) # df = pd.read_csv("data/heat_guatemala/model_ready_heat_guatemala.csv") - df = pd.read_csv("/lustre/backup/SHARED/AIN/aether/data/s2bms/model_ready_s2bms.csv") + # df = pd.read_csv("/lustre/backup/SHARED/AIN/aether/data/s2bms/model_ready_s2bms.csv") + df = pd.read_csv("data/s2bms/model_ready_s2bms.csv") # df.sort_values(by="name_loc", inplace=True, ascending=False) - with open(os.path.join("logs", "tessera_skipped.txt")) as f: - skipped = set(f.read().splitlines()) - - df = df[~df.name_loc.isin(skipped)] + if os.path.exists("logs/tessera_skipped.txt"): + with open(os.path.join("logs", "tessera_skipped.txt")) as f: + skipped = set(f.read().splitlines()) + df = df[~df.name_loc.isin(skipped)] + # df.sort_values('name_loc', ascending=False, inplace=True) tessera_from_df( df, - "/lustre/backup/SHARED/AIN/aether/data/s2bms/eo/tessera", + "data/s2bms/eo/tessera", year=2024, tile_size=256, - cache_dir="/lustre/backup/SHARED/AIN/aether/data/cache", + cache_dir="data/cache", ) From 58835932eab29719466479c2c3eaca95a6bcd59d Mon Sep 17 00:00:00 2001 From: GabrieleTi Date: Wed, 15 Apr 2026 09:16:54 +0200 Subject: [PATCH 75/78] Fix mean fusion strategy --- src/models/components/geo_encoders/encoder_wrapper.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/models/components/geo_encoders/encoder_wrapper.py b/src/models/components/geo_encoders/encoder_wrapper.py index e9d7fff..dbdca85 100644 --- a/src/models/components/geo_encoders/encoder_wrapper.py +++ b/src/models/components/geo_encoders/encoder_wrapper.py @@ -129,7 +129,7 @@ def set_output_dim(self): if self.fusion_strategy == "concat": self.output_dim = sum(output_dims) elif self.fusion_strategy == "mean": - if set(output_dims) != 1: + if len(set(output_dims)) != 1: raise ValueError( f"Encoder branches produces different output dimensions {output_dims} and cannot be averaged." ) @@ -151,7 +151,7 @@ def forward(self, batch: Dict[str, torch.Tensor]) -> torch.Tensor: if self.extra_projector: feats = self.extra_projector(feats) else: - feats = torch.cat(branch_feats, dim=1) + feats = torch.stack(branch_feats, dim=0).mean(dim=0) if self.extra_projector: feats = self.extra_projector(feats) return feats From fdf9cd4425fa0e36d51baaf3b2fd132dfc4c4f65 Mon Sep 17 00:00:00 2001 From: GabrieleTi Date: Wed, 15 Apr 2026 09:21:44 +0200 Subject: [PATCH 76/78] Geo encoder optional in base model class --- src/models/base_model.py | 12 +++++++----- 1 file changed, 7 insertions(+), 5 deletions(-) diff --git a/src/models/base_model.py b/src/models/base_model.py index e97b53c..d0f92e8 100644 --- a/src/models/base_model.py +++ b/src/models/base_model.py @@ -16,7 +16,7 @@ class BaseModel(LightningModule, ABC): def __init__( self, trainable_modules: list[str], - geo_encoder: BaseGeoEncoder, + geo_encoder: BaseGeoEncoder | None, text_encoder: BaseTextEncoder | None, prediction_head: BasePredictionHead | None, optimizer: torch.optim.Optimizer | None, @@ -55,7 +55,8 @@ def __init__( ) self.trainable_modules = trainable_modules - self.geo_encoder = geo_encoder + if geo_encoder: + self.geo_encoder = geo_encoder if text_encoder: self.text_encoder = text_encoder if prediction_head: @@ -212,7 +213,8 @@ def configure_optimizers(self) -> Dict[str, Any]: def update_configs(self, cfg): """Update hyper-parameters from the model.""" - self.geo_encoder.update_configs(cfg["geo_encoder"]) + if hasattr(self, "geo_encoder"): + self.geo_encoder.update_configs(cfg["geo_encoder"]) if hasattr(self, "text_encoder"): self.text_encoder.cfg_dict = cfg["text_encoder"] @@ -256,8 +258,8 @@ def on_save_checkpoint(self, checkpoint): } ) - checkpoint["hyper_parameters"]["geo_encoder"] = self.geo_encoder.cfg_dict - + if hasattr(self, "geo_encoder"): + checkpoint["hyper_parameters"]["geo_encoder"] = self.geo_encoder.cfg_dict if hasattr(self, "prediction_head"): checkpoint["hyper_parameters"]["prediction_head"] = self.prediction_head.cfg_dict if hasattr(self, "text_encoder"): From c6b82a57a44ec2c2f88c50b2922cbdbf07c9c766 Mon Sep 17 00:00:00 2001 From: GabrieleTi Date: Wed, 15 Apr 2026 10:33:03 +0200 Subject: [PATCH 77/78] Geo encoder optional in base model class --- src/models/inference_model.py | 48 +++++++++++++++++++++-------------- 1 file changed, 29 insertions(+), 19 deletions(-) diff --git a/src/models/inference_model.py b/src/models/inference_model.py index bd47a6b..5c0023e 100644 --- a/src/models/inference_model.py +++ b/src/models/inference_model.py @@ -22,9 +22,9 @@ class InferenceModel(BaseModel): def __init__( self, - geo_encoder: BaseGeoEncoder, - text_encoder: BaseTextEncoder, - prediction_head: BasePredictionHead, + geo_encoder: BaseGeoEncoder | None, + text_encoder: BaseTextEncoder | None, + prediction_head: BasePredictionHead | None, num_classes: int | None, metrics: MetricsWrapper | None = None, ks: list[int] | None = [5, 10, 15], @@ -68,18 +68,20 @@ def _setup(self, stage: str) -> None: print("-------Model------------") # Configure encoders - self.geo_encoder.setup() - self.text_encoder.setup() - - # Configure optional extra projection so text embeddings match geo embeddings. - if self.text_encoder.output_dim != self.geo_encoder.output_dim: - if self.match_to_geo: - self.text_encoder.add_projector(projected_dim=self.geo_encoder.output_dim) - else: - self.geo_encoder.add_projector(projected_dim=self.text_encoder.output_dim) - + if hasattr(self, "geo_encoder"): + self.geo_encoder.setup() + if hasattr(self, "text_encoder"): + self.text_encoder.setup() + + if hasattr(self, "text_encoder") and hasattr(self, "geo_encoder"): + # Configure optional extra projection so text embeddings match geo embeddings. + if self.text_encoder.output_dim != self.geo_encoder.output_dim: + if self.match_to_geo: + self.text_encoder.add_projector(projected_dim=self.geo_encoder.output_dim) + else: + self.geo_encoder.add_projector(projected_dim=self.text_encoder.output_dim) # Configure prediction head - if self.prediction_head.net is None: + if hasattr(self, "prediction_head") and self.prediction_head.net is None: if self.num_classes is None: raise ValueError( "InferenceModel requires `num_classes` to build the prediction head." @@ -108,14 +110,22 @@ def forward( """Model forward logic.""" # Embed modalities - geo_feats = self.geo_encoder(batch) - text_feats = self.text_encoder(batch, mode) - pred_feats = self.prediction_head(geo_feats) + if hasattr(self, "geo_encoder"): + geo_feats = self.geo_encoder(batch) + if hasattr(self, "text_encoder"): + text_feats = self.text_encoder(batch, mode) + if hasattr(self, "prediction_head"): + pred_feats = self.prediction_head(geo_feats) # Change dtype of geo data if it does not match text dtype - if geo_feats.dtype != text_feats.dtype: + if ( + hasattr(self, "text_encoder") + and hasattr(self, "geo_encoder") + and geo_feats.dtype != text_feats.dtype + ): geo_feats = geo_feats.to(text_feats.dtype) - return pred_feats, geo_feats, text_feats + + return pred_feats, geo_feats, text_feats def concept_similarities(self, geo_embeds, concept=None) -> torch.Tensor: # Get concept embeddings From dda4714050d883ed82caa9118dda0dcec822157b Mon Sep 17 00:00:00 2001 From: GabrieleTi Date: Wed, 15 Apr 2026 10:55:25 +0200 Subject: [PATCH 78/78] Tessera under "eo" fix and center crops --- src/data/base_dataset.py | 28 +++++++++++++++++++- src/data/butterfly_dataset.py | 3 +-- src/data/heat_guatemala_dataset.py | 10 ++++--- src/data_preprocessing/tessera_data_check.py | 11 ++------ src/utils/data_utils.py | 8 ++++++ 5 files changed, 44 insertions(+), 16 deletions(-) create mode 100644 src/utils/data_utils.py diff --git a/src/data/base_dataset.py b/src/data/base_dataset.py index bd727d7..c1aea8b 100644 --- a/src/data/base_dataset.py +++ b/src/data/base_dataset.py @@ -9,6 +9,7 @@ from torch.utils.data import Dataset import src.data_preprocessing.data_utils as du +from src.utils.data_utils import center_crop_npy class BaseDataset(Dataset, ABC): @@ -237,7 +238,17 @@ def setup_tessera(self) -> None: if fname not in avail_files: print(f"Retrieving missing Tessera data: {fname}") gt = gt or GeoTessera(cache_dir=self.cache_dir) - get_tessera_embeds(rec.lon, rec.lat, rec.name_loc, year, dst_dir, size) + row = self.df[self.df["name_loc"] == rec["name_loc"]] + lon, lat = row.lon.item(), row.lat.item() + get_tessera_embeds( + lon, + lat, + rec["name_loc"], + year=year, + save_dir=dst_dir, + tile_size=size, + tessera_con=gt, + ) @final def setup_aef(self) -> None: @@ -280,6 +291,21 @@ def load_aef(self, filepath: str): """Loads AEF data from file as a tensor.""" im = du.load_tiff(filepath, datatype="np") + size = self.modalities["aef"]["size"] + if im.shape[1] != size: + im = center_crop_npy(im, (64, size, size)) if np.isinf(im).any(): im = np.clip(im, a_min=-0.5, a_max=0.5) + # TODO any normalisation needed + return torch.tensor(im).float() + + @final + def load_tessera(self, filepath: str) -> torch.Tensor: + """Loads.""" + size = self.modalities["tessera"]["size"] + arr = self.load_npy(filepath) + if arr.size()[1] != size: + arr = center_crop_npy(arr, (128, size, size)) + # TODO any normalisation needed + return arr diff --git a/src/data/butterfly_dataset.py b/src/data/butterfly_dataset.py index 5ecacf7..642bf9c 100644 --- a/src/data/butterfly_dataset.py +++ b/src/data/butterfly_dataset.py @@ -154,8 +154,7 @@ def __getitem__(self, idx: int) -> Dict[str, Any]: formatted_row["eo"][modality] = self.load_s2(row["s2_path"]) # TODO: augmentations elif modality == "tessera": - formatted_row["eo"][modality] = self.load_npy(row["tessera_path"]) - # TODO any normalisation needed + formatted_row["eo"][modality] = self.load_tessera(row["tessera_path"]) elif modality == "aef": formatted_row["eo"][modality] = self.load_aef(row["aef_path"]) diff --git a/src/data/heat_guatemala_dataset.py b/src/data/heat_guatemala_dataset.py index f4fabac..fa66041 100644 --- a/src/data/heat_guatemala_dataset.py +++ b/src/data/heat_guatemala_dataset.py @@ -84,10 +84,12 @@ def __getitem__(self, idx: int) -> Dict[str, Any]: # --- EO modalities --- for modality in self.modalities: - if modality == "coords": - sample["eo"]["coords"] = torch.tensor( - [row["lat"], row["lon"]], dtype=torch.float32 - ) + if modality in ["coords"]: + sample["eo"][modality] = torch.tensor([row["lat"], row["lon"]]) + elif modality == "tessera": + sample["eo"][modality] = self.load_tessera(row["tessera_path"]) + elif modality == "aef": + sample["eo"][modality] = self.load_aef(row["aef_path"]) # --- Tabular features (always included if present in CSV) --- if self.use_features and self.feat_names: diff --git a/src/data_preprocessing/tessera_data_check.py b/src/data_preprocessing/tessera_data_check.py index a7c9801..a6bf246 100644 --- a/src/data_preprocessing/tessera_data_check.py +++ b/src/data_preprocessing/tessera_data_check.py @@ -3,14 +3,7 @@ import numpy as np - -def center_crop(arr, target_shape): - slices = [] - for dim, target in zip(arr.shape, target_shape): - start = (dim - target) // 2 - end = start + target - slices.append(slice(start, end)) - return arr[tuple(slices)] +from src.utils.data_utils import center_crop_npy def main(paths): @@ -23,7 +16,7 @@ def main(paths): p_id = os.path.basename(p).split(".")[0].split("-")[-1] crop = img if s != img.shape[0]: - crop = center_crop(img, (s, s, 128)) + crop = center_crop_npy(img, (s, s, 128)) if crop.shape[0:2] != (s, s): with open(f"logs/tessera_size_mismatch_{s}.txt", "a") as f: diff --git a/src/utils/data_utils.py b/src/utils/data_utils.py new file mode 100644 index 0000000..3d9752e --- /dev/null +++ b/src/utils/data_utils.py @@ -0,0 +1,8 @@ +def center_crop_npy(arr, target_shape): + """Crops npy array to desired size.""" + slices = [] + for dim, target in zip(arr.shape, target_shape): + start = (dim - target) // 2 + end = start + target + slices.append(slice(start, end)) + return arr[tuple(slices)]