diff --git a/configs/eval.yaml b/configs/eval.yaml index be31299..80436cd 100644 --- a/configs/eval.yaml +++ b/configs/eval.yaml @@ -1,18 +1,43 @@ # @package _global_ +# specify here default configuration +# order of defaults determines the order in which configs override each other defaults: - _self_ - - data: mnist # choose datamodule with `test_dataloader()` for evaluation - - model: mnist - - logger: null - - trainer: default - - paths: default + - data: butterfly_coords + - model: predictive_geoclip + - callbacks: default + - logger: ${oc.env:LOGGER,wandb} + - trainer: ${oc.env:TRAINER_PROFILE,default} + - paths: ${oc.env:STORAGE_MODE,local} - extras: default - hydra: default + - metrics: butterfly_predictive + # experiment configs allow for version control of specific hyperparameters + # e.g. best hyperparameters for given model and datamodule + - experiment: null + + # config for hyperparameter optimization + - hparams_search: null + + # optional local config for machine/user specific settings + # it's optional since it doesn't need to exist and is excluded from version control + - optional local: default + + # debugging config (enable through command line, e.g. `python train.py debug=default) + - debug: null + +# task name, determines output directory path task_name: "eval" +# tags to help you identify your experiments +# you can overwrite this in experiment configs +# overwrite from command line with `python train.py tags="[first_tag, second_tag]"` tags: ["dev"] +# seed for random number generators in pytorch, numpy and python.random +seed: 12345 + # passing checkpoint path is necessary for evaluation -ckpt_path: ??? 
+ckpt_path: aether/logs/train/runs/2026-03-15_13-15-08/checkpoints/epoch_098.ckpt diff --git a/configs/inference.yaml b/configs/inference.yaml new file mode 100644 index 0000000..45f5bcc --- /dev/null +++ b/configs/inference.yaml @@ -0,0 +1,25 @@ +defaults: + - _self_ + - paths: ${oc.env:STORAGE_MODE,local} + - extras: default + - hydra: default + +# task name, determines output directory path +task_name: "inference" +tags: ["dev"] + +# If set, inference.py will load this merged checkpoint directly. +#inference_ckpt_path: ${paths.log_dir}${task_name}/2026-04-07_12-56-16/inference_model.ckpt + +# If `inference_ckpt_path` is not set, stitch the inference model from: +# - predictive ckpt (provides prediction_head weights) +# - alignment ckpt (provides text_encoder weights + geo_encoder) +predictive_ckpt_path: ${paths.log_dir}train/runs/2026-04-02_15-54-53/checkpoints/epoch_000.ckpt +alignment_ckpt_path: ${paths.log_dir}train/runs/2026-04-02_15-40-03/checkpoints/epoch_000.ckpt + +# If set, inference.py will save a merged inference checkpoint you can reload with +# `inference_ckpt_path`. 
+save_inference_ckpt_path: ${paths.log_dir}${task_name}/${now:%Y-%m-%d}_${now:%H-%M-%S}/inference_model.ckpt +#save_inference_ckpt_path: null + +training_order: ["alignment_model", "prediction_model"] diff --git a/configs/paths/shared.yaml b/configs/paths/shared.yaml index a4702b0..26c1a8c 100644 --- a/configs/paths/shared.yaml +++ b/configs/paths/shared.yaml @@ -9,7 +9,7 @@ data_dir: ${oc.env:DATA_DIR,${oc.env:SHARED_ROOT,${paths.root_dir}}/data} cache_dir: ${oc.env:SHARED_CACHE,${paths.data_dir}/cache} # path to logging directory -log_dir: ${oc.env:SHARED_ROOT}/logs/ +log_dir: ${oc.env:SHARED_ROOT,${paths.root_dir}}/logs/ # path to output directory, created dynamically by hydra # path generation pattern is specified in `configs/hydra/local.yaml` @@ -21,4 +21,4 @@ work_dir: ${hydra:runtime.cwd} # huggingface cache directory # can be overridden via HF_HOME environment variable -huggingface_cache: ${oc.env:HF_HOME,oc.env:SHARED_CACHE/huggingface} +huggingface_cache: ${oc.env:HF_HOME,${oc.env:SHARED_CACHE}/huggingface} diff --git a/notebooks/08-GT-aligment-visualisation.ipynb b/notebooks/08-GT-aligment-visualisation.ipynb new file mode 100644 index 0000000..f8b4e9d --- /dev/null +++ b/notebooks/08-GT-aligment-visualisation.ipynb @@ -0,0 +1,330 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": null, + "id": "0", + "metadata": {}, + "outputs": [], + "source": [ + "import os\n", + "from typing import Any, Dict, List, Tuple\n", + "\n", + "import hydra\n", + "import torch\n", + "from lightning import LightningDataModule, LightningModule, Trainer\n", + "from lightning.pytorch.loggers import Logger\n", + "\n", + "if os.environ.get(\"TOKENIZERS_PARALLELISM\") is None:\n", + " os.environ[\"TOKENIZERS_PARALLELISM\"] = \"false\"\n", + "\n", + "from src.utils import (\n", + " instantiate_loggers,\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "1", + "metadata": {}, + "outputs": [], + "source": [ + "os.chdir(\"..\")\n", + 
"os.getcwd()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "2", + "metadata": {}, + "outputs": [], + "source": [ + "from hydra import compose, initialize\n", + "from hydra.core.hydra_config import HydraConfig\n", + "\n", + "with initialize(version_base=\"1.3\", config_path=\"../configs\"):\n", + " cfg = compose(\n", + " config_name=\"eval.yaml\",\n", + " return_hydra_config=True,\n", + " overrides=[\"experiment=alignment_v1\"], # 👈 loads configs/experiment/my_exp.yaml\n", + " )\n", + "\n", + "HydraConfig.instance().set_config(cfg)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "3", + "metadata": {}, + "outputs": [], + "source": [ + "assert cfg.ckpt_path" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "4", + "metadata": {}, + "outputs": [], + "source": [ + "datamodule: LightningDataModule = hydra.utils.instantiate(cfg.data)\n", + "\n", + "model: LightningModule = hydra.utils.instantiate(cfg.model)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "5", + "metadata": {}, + "outputs": [], + "source": [ + "model" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "6", + "metadata": {}, + "outputs": [], + "source": [ + "logger: List[Logger] = instantiate_loggers(cfg.get(\"logger\"))\n", + "\n", + "cfg[\"trainer\"][\"max_epochs\"] = 0\n", + "cfg[\"trainer\"][\"accelerator\"] = \"mps\"\n", + "trainer: Trainer = hydra.utils.instantiate(cfg.trainer, logger=logger)\n", + "\n", + "object_dict = {\n", + " \"cfg\": cfg,\n", + " \"datamodule\": datamodule,\n", + " \"model\": model,\n", + " \"logger\": logger,\n", + " \"trainer\": trainer,\n", + "}" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "7", + "metadata": {}, + "outputs": [], + "source": [ + "trainer.strategy.connect(model)\n", + "trainer._data_connector.attach_data(model, datamodule=datamodule)\n", + "\n", + "model.setup(stage=\"fit\")" + ] + }, + { + "cell_type": "code", + 
"execution_count": null, + "id": "8", + "metadata": {}, + "outputs": [], + "source": [ + "path = \"logs/train/runs/2026-03-15_13-15-08/checkpoints/epoch_098.ckpt\"\n", + "ckpt = torch.load(path, map_location=\"cpu\", weights_only=False)\n", + "model.load_state_dict(ckpt[\"state_dict\"])" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "9", + "metadata": {}, + "outputs": [], + "source": [ + "data = datamodule.test_dataloader().dataset[2]\n", + "data" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "10", + "metadata": {}, + "outputs": [], + "source": [ + "import matplotlib.pyplot as plt\n", + "import numpy as np\n", + "\n", + "\n", + "def show_rgb(tensor, stretch=True):\n", + " \"\"\"\n", + " tensor: numpy array or torch tensor of shape (C, H, W)\n", + " bands: tuple of 3 indices (R, G, B)\n", + " stretch: whether to normalize values for display\n", + " \"\"\"\n", + " # Convert to numpy if it's a torch tensor\n", + " if \"torch\" in str(type(tensor)):\n", + " tensor = tensor.detach().cpu().numpy()\n", + "\n", + " # Select bands\n", + " rgb = tensor[:3, :, :]\n", + "\n", + " # Move to (H, W, 3)\n", + " rgb = np.transpose(rgb, (1, 2, 0))\n", + "\n", + " # Normalize for visualization\n", + " if stretch:\n", + " rgb = rgb.astype(np.float32)\n", + " for i in range(3):\n", + " band = rgb[:, :, i]\n", + " band_min, band_max = band.min(), band.max()\n", + " if band_max > band_min:\n", + " rgb[:, :, i] = (band - band_min) / (band_max - band_min)\n", + "\n", + " plt.imshow(rgb)\n", + " plt.axis(\"off\")\n", + " plt.show()\n", + "\n", + "\n", + "show_rgb(data[\"eo\"][\"aef\"])" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "11", + "metadata": {}, + "outputs": [], + "source": [ + "batch = data\n", + "batch[\"eo\"][\"aef\"] = batch[\"eo\"][\"aef\"].unsqueeze(0)\n", + "batch[\"eo\"][\"aef\"].shape" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "12", + "metadata": {}, + "outputs": [], + 
"source": [ + "geo_feats = model.geo_encoder(data).squeeze(1)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "13", + "metadata": {}, + "outputs": [], + "source": [ + "concepts = [\n", + " \"Densely populated area with many houses\",\n", + " \"Very sparsely populated area with few houses\",\n", + " \"Arable land with crops for agriculture\",\n", + " \"Pasture fields with grass for grazing animals\",\n", + " \"Forested area with many trees\",\n", + " \"Water bodies such as lakes, rivers and sea\",\n", + "]" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "14", + "metadata": {}, + "outputs": [], + "source": [ + "text_tokens = model.text_encoder.processor(\n", + " text=concepts,\n", + " return_tensors=\"pt\",\n", + " padding=True,\n", + " truncation=True,\n", + " max_length=77,\n", + ")\n", + "\n", + "device = model.text_encoder.device\n", + "text_tokens = {k: v.to(device) for k, v in text_tokens.items()}\n", + "\n", + "text_embeds = model.text_encoder.model.get_text_features(**text_tokens)\n", + "\n", + "# Project\n", + "if model.text_encoder.projector is not None:\n", + " text_embeds = model.text_encoder.projector(text_embeds)\n", + "\n", + "if model.text_encoder.extra_projector is not None:\n", + " text_embeds = model.text_encoder.extra_projector(text_embeds)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "15", + "metadata": {}, + "outputs": [], + "source": [ + "text_embeds.shape" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "16", + "metadata": {}, + "outputs": [], + "source": [ + "import torch\n", + "import torch.nn.functional as F\n", + "\n", + "\n", + "def cosine_scores(x, Y):\n", + " \"\"\"\n", + " x: tensor of shape (1, 512)\n", + " Y: tensor of shape (6, 512)\n", + " returns: tensor of shape (6,)\n", + " \"\"\"\n", + " # Normalize\n", + " x = F.normalize(x, dim=1) # (1, 512)\n", + " Y = F.normalize(Y, dim=1) # (6, 512)\n", + "\n", + " # Compute cosine 
similarity\n", + " scores = torch.matmul(Y, x.T) # (6, 1)\n", + "\n", + " return scores.squeeze(1)\n", + "\n", + "\n", + "cosine_scores(geo_feats, text_embeds.cpu())" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "17", + "metadata": {}, + "outputs": [], + "source": [ + "scores = {\n", + " k: round(v.detach().item(), 2)\n", + " for k, v in zip(concepts, cosine_scores(geo_feats, text_embeds.cpu()))\n", + "}\n", + "scores" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "18", + "metadata": {}, + "outputs": [], + "source": [] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3 (ipykernel)", + "language": "python", + "name": "python3" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/src/data/base_dataset.py b/src/data/base_dataset.py index bd727d7..c1aea8b 100644 --- a/src/data/base_dataset.py +++ b/src/data/base_dataset.py @@ -9,6 +9,7 @@ from torch.utils.data import Dataset import src.data_preprocessing.data_utils as du +from src.utils.data_utils import center_crop_npy class BaseDataset(Dataset, ABC): @@ -237,7 +238,17 @@ def setup_tessera(self) -> None: if fname not in avail_files: print(f"Retrieving missing Tessera data: {fname}") gt = gt or GeoTessera(cache_dir=self.cache_dir) - get_tessera_embeds(rec.lon, rec.lat, rec.name_loc, year, dst_dir, size) + row = self.df[self.df["name_loc"] == rec["name_loc"]] + lon, lat = row.lon.item(), row.lat.item() + get_tessera_embeds( + lon, + lat, + rec["name_loc"], + year=year, + save_dir=dst_dir, + tile_size=size, + tessera_con=gt, + ) @final def setup_aef(self) -> None: @@ -280,6 +291,21 @@ def load_aef(self, filepath: str): """Loads AEF data from file as a tensor.""" im = du.load_tiff(filepath, datatype="np") + size = self.modalities["aef"]["size"] + if im.shape[1] != size: + im = center_crop_npy(im, (64, size, size)) if np.isinf(im).any(): im = np.clip(im, a_min=-0.5, a_max=0.5) + # TODO any normalisation needed + return 
torch.tensor(im).float() + + @final + def load_tessera(self, filepath: str) -> torch.Tensor: + """Loads.""" + size = self.modalities["tessera"]["size"] + arr = self.load_npy(filepath) + if arr.size()[1] != size: + arr = center_crop_npy(arr, (128, size, size)) + # TODO any normalisation needed + return arr diff --git a/src/data/butterfly_dataset.py b/src/data/butterfly_dataset.py index 5ecacf7..642bf9c 100644 --- a/src/data/butterfly_dataset.py +++ b/src/data/butterfly_dataset.py @@ -154,8 +154,7 @@ def __getitem__(self, idx: int) -> Dict[str, Any]: formatted_row["eo"][modality] = self.load_s2(row["s2_path"]) # TODO: augmentations elif modality == "tessera": - formatted_row["eo"][modality] = self.load_npy(row["tessera_path"]) - # TODO any normalisation needed + formatted_row["eo"][modality] = self.load_tessera(row["tessera_path"]) elif modality == "aef": formatted_row["eo"][modality] = self.load_aef(row["aef_path"]) diff --git a/src/data/heat_guatemala_dataset.py b/src/data/heat_guatemala_dataset.py index f4fabac..fa66041 100644 --- a/src/data/heat_guatemala_dataset.py +++ b/src/data/heat_guatemala_dataset.py @@ -84,10 +84,12 @@ def __getitem__(self, idx: int) -> Dict[str, Any]: # --- EO modalities --- for modality in self.modalities: - if modality == "coords": - sample["eo"]["coords"] = torch.tensor( - [row["lat"], row["lon"]], dtype=torch.float32 - ) + if modality in ["coords"]: + sample["eo"][modality] = torch.tensor([row["lat"], row["lon"]]) + elif modality == "tessera": + sample["eo"][modality] = self.load_tessera(row["tessera_path"]) + elif modality == "aef": + sample["eo"][modality] = self.load_aef(row["aef_path"]) # --- Tabular features (always included if present in CSV) --- if self.use_features and self.feat_names: diff --git a/src/data_preprocessing/tessera_data_check.py b/src/data_preprocessing/tessera_data_check.py new file mode 100644 index 0000000..a6bf246 --- /dev/null +++ b/src/data_preprocessing/tessera_data_check.py @@ -0,0 +1,50 @@ +import glob 
+import os + +import numpy as np + +from src.utils.data_utils import center_crop_npy + + +def main(paths): + + sizes = [256, 128] + + for p in paths: + img = np.load(p) + for s in sizes: + p_id = os.path.basename(p).split(".")[0].split("-")[-1] + crop = img + if s != img.shape[0]: + crop = center_crop_npy(img, (s, s, 128)) + + if crop.shape[0:2] != (s, s): + with open(f"logs/tessera_size_mismatch_{s}.txt", "a") as f: + print(f"{p_id} has shape {crop.shape[0:2]}") + f.write(f"{p_id}\n") + + if np.isinf(crop.any()): + with open(f"logs/tessera_nans_{s}.txt", "a") as f: + f.write(f"{p_id}\n") + + nulls = np.sum(crop == 0) + if nulls > (s * s * 128) * 0.5: + with open(f"logs/tessera_50per_empty_{s}.txt", "a") as f: + print(f"50% of {p_id} is 0") + f.write(f"{p_id}\n") + + if nulls > (s * s * 128) * 0.25: + with open(f"logs/tessera_25per_empty_{s}.txt", "a") as f: + print(f"25% of {p_id} is 0") + f.write(f"{p_id}\n") + + +if __name__ == "__main__": + os.chdir("../..") + print(os.getcwd()) + + paths = glob.glob("/lustre/backup/SHARED/AIN/aether/data/s2bms/eo/tessera/*.npy") + paths = glob.glob("data/s2bms/eo/tessera/*.npy") + paths.sort() + + main(paths) diff --git a/src/data_preprocessing/tessera_embeds.py b/src/data_preprocessing/tessera_embeds.py index c2414b8..2774687 100644 --- a/src/data_preprocessing/tessera_embeds.py +++ b/src/data_preprocessing/tessera_embeds.py @@ -1,8 +1,10 @@ +import concurrent.futures import math import os import threading import numpy as np +from affine import Affine # Serialises concurrent reads/writes to the per-directory meta.csv log file. 
_meta_csv_lock = threading.Lock() @@ -16,13 +18,24 @@ from rasterio.warp import Resampling, calculate_default_transform, reproject from src.data_preprocessing.crs_utils import ( - create_bbox_with_radius, crs_to_pixel_coords, get_point_utm_crs, point_reprojection, ) +class PartialTileError(Exception): + def __init__(self, message): + self.message = message + super().__init__(self.message) + + +class NoTileError(Exception): + def __init__(self, message): + self.message = message + super().__init__(self.message) + + def reproject_dataset(src_raster: MemoryFile, dst_crs: str) -> MemoryFile: """Reprojects Memory file if it's not in dst_crs. @@ -62,6 +75,37 @@ def reproject_dataset(src_raster: MemoryFile, dst_crs: str) -> MemoryFile: return dst, memfile +def get_tiles(lat_center, lon_center, half_size_m=75, year=2024): + """Find all 0.1deg tiles (referenced at 0.05deg) that overlap AOI.""" + + # Convert half-size from meters to degrees (approximate) + half_lat_deg = half_size_m / 111320.0 + half_lon_deg = half_size_m / (111320.0 * np.cos(np.radians(lat_center))) + + # AOI bounds + lat_min = lat_center - half_lat_deg + lat_max = lat_center + half_lat_deg + lon_min = lon_center - half_lon_deg + lon_max = lon_center + half_lon_deg + + tile_size = 0.1 + + # Find tile indices overlapping the AOI + i_min = int(np.floor(lat_min / tile_size)) + i_max = int(np.floor(lat_max / tile_size)) + j_min = int(np.floor(lon_min / tile_size)) + j_max = int(np.floor(lon_max / tile_size)) + + tiles = [] + for i in range(i_min, i_max + 1): + for j in range(j_min, j_max + 1): + ref_lat = i * tile_size + 0.05 # tile reference (center) + ref_lon = j * tile_size + 0.05 + tiles.append((year, round(ref_lon, 10), round(ref_lat, 10))) + + return tiles + + def get_tessera_embeds( lon: float, lat: float, @@ -70,6 +114,7 @@ def get_tessera_embeds( save_dir: str, tile_size: int, tessera_con: GeoTessera | None, + padding: int = 100, ) -> None: """Obtain tessera embedding tile with specified size for a 
given coordinates. @@ -80,9 +125,11 @@ def get_tessera_embeds( :param save_dir: data directory to save embeddings :param tile_size: tile size in pixels :param tessera_con: GeoTessera instance + :param padding: how many meters to pad initial bbox, fixes some inconsistencies when mosaicing :return: None """ + # Skip if tile exists embed_tile_name = os.path.join(save_dir, f"tessera_{name_loc}.npy") if os.path.exists(embed_tile_name): return @@ -91,13 +138,10 @@ def get_tessera_embeds( utm_crs = get_point_utm_crs(lon, lat) lon_utm, lat_utm = point_reprojection(lon, lat, "EPSG:4326", utm_crs) - # Bounding box - radius = math.ceil(tile_size / 2) + 10 - bbox = create_bbox_with_radius(lon, lat, radius=radius, utm_crs=utm_crs, return_wgs=True) - # Request to tessera - tiles_to_fetch = tessera_con.registry.load_blocks_for_region( - bounds=bbox.bounds, year=int(year) + radius = math.ceil(tile_size / 2) + padding + tiles_to_fetch = get_tiles( + lat_center=lat, lon_center=lon, half_size_m=radius * 10, year=int(year) ) # Mosaic returned tiles for the bbox @@ -126,13 +170,10 @@ def get_tessera_embeds( if reproject_memfile: memfiles.append(reproject_memfile) - if not tiles: - print( - f"No TESSERA tiles found for {name_loc} at ({lon:.4f}, {lat:.4f}) year={year}. Skipping." 
- ) + if len(tiles) == 0: for mf in memfiles: mf.close() - return + raise NoTileError(f"No tiles found for {name_loc}") # if no tiles, add to skipped.txt mosaic, mosaic_transform = merge(tiles) mosaic = mosaic.transpose(1, 2, 0) @@ -143,26 +184,54 @@ def get_tessera_embeds( mf.close() # Crop patch tile - col, row = crs_to_pixel_coords(lon_utm, lat_utm, mosaic_transform) + c, r = crs_to_pixel_coords(lon_utm, lat_utm, mosaic_transform) half = tile_size // 2 - row_min = row - half - row_max = row + tile_size - half # tile_size - half ensures correct size for odd tile_size - col_min = col - half - col_max = col + tile_size - half + row_min = r - half + row_max = r + half + col_min = c - half + col_max = c + half + + if row_min < 0 or row_max < 0 or col_min < 0 or col_max < 0: + # retry with bigger padding + if padding > 500: + raise NoTileError(f"Padding {padding} > 500") + get_tessera_embeds( + lon, lat, name_loc, year, save_dir, tile_size, tessera_con, padding=padding + 100 + ) + crop = mosaic[row_min:row_max, col_min:col_max, :] + if not crop.shape == (tile_size, tile_size, 128): + if crop.min() == 0.0 and crop.max() == 0.0: + raise NoTileError(f"No tiles found for {name_loc}") + raise PartialTileError(f"Crop {name_loc}, size is {crop.shape}") - if crop.shape[0] != tile_size or crop.shape[1] != tile_size: - print( - f"Unexpected crop shape {crop.shape} for {name_loc} " - f"(expected {tile_size}x{tile_size}). Skipping." 
- ) - return + if crop.min() == 0.0 and crop.max() == 0.0: + raise NoTileError(f"Crop {name_loc} has embeddings of 0.0s with tiles: {tiles_to_fetch}") # Save array os.makedirs(save_dir, exist_ok=True) np.save(embed_tile_name, crop) print(f"Array saved as {embed_tile_name}") + # Temp save tif too + crop_transform = mosaic_transform * Affine.translation(col_min, row_min) + height, width, channels = crop.shape + + with rasterio.open( + embed_tile_name[:-4] + ".tif", + "w", + driver="GTiff", + height=height, + width=width, + count=channels, + dtype=crop.dtype, + crs=utm_crs, + transform=crop_transform, + ) as dst: + for i in range(channels): + dst.write(crop[:, :, i], i + 1) + print(f"tif saved to {embed_tile_name[:-4]}.tif") + # Log its metadata meta_df = pd.DataFrame( {"id": [name_loc], "year": [year], "lon": [lon], "lat": [lat], "crs": [utm_crs]} @@ -186,6 +255,7 @@ def tessera_from_df( year: int, tile_size: int = 256, cache_dir: str = "temp/", + logs_dir: str = "logs", ) -> None: """Obtains Tessera embeddings from a CSV file for each (lon, lat). 
@@ -205,8 +275,15 @@ def tessera_from_df( n = len(model_ready_df) for i, row in model_ready_df.iterrows(): print(f"{i}/{n}") - # Get tessera embeds - get_tessera_embeds(row.lon, row.lat, row.name_loc, year, f"{data_dir}/", tile_size, gt) + try: + get_tessera_embeds(row.lon, row.lat, row.name_loc, year, f"{data_dir}/", tile_size, gt) + except Exception as e: + if isinstance(e, NoTileError): + path = os.path.join(logs_dir, "tessera_skipped.txt") + with open(path, "a") as f: + f.write(f"{row.name_loc}\n") + else: + print(f"{row.name_loc} did not get embedded because: {e}") def inspect_np_arr_as_tiff( @@ -270,10 +347,24 @@ def inspect_np_arr_as_tiff( if __name__ == "__main__": - os.chdir("../..") + # os.chdir('../..') + + print(os.getcwd()) - df = pd.read_csv("data/heat_guatemala/model_ready_heat_guatemala.csv") + # df = pd.read_csv("data/heat_guatemala/model_ready_heat_guatemala.csv") + # df = pd.read_csv("/lustre/backup/SHARED/AIN/aether/data/s2bms/model_ready_s2bms.csv") + df = pd.read_csv("data/s2bms/model_ready_s2bms.csv") + # df.sort_values(by="name_loc", inplace=True, ascending=False) + if os.path.exists("logs/tessera_skipped.txt"): + with open(os.path.join("logs", "tessera_skipped.txt")) as f: + skipped = set(f.read().splitlines()) + df = df[~df.name_loc.isin(skipped)] + # df.sort_values('name_loc', ascending=False, inplace=True) tessera_from_df( - df, "data/heat_guatemala/eo/tessera_2024", year=2024, tile_size=10, cache_dir="data/cache" + df, + "data/s2bms/eo/tessera", + year=2024, + tile_size=256, + cache_dir="data/cache", ) diff --git a/src/data_preprocessing/yield_africa_spatial_splits.py b/src/data_preprocessing/yield_africa_spatial_splits.py index 4c9aa98..d71b15b 100644 --- a/src/data_preprocessing/yield_africa_spatial_splits.py +++ b/src/data_preprocessing/yield_africa_spatial_splits.py @@ -79,13 +79,12 @@ def make_spatial_split( """Return a split-indices dict using DBSCAN spatial clustering. 
:param df: full model-ready dataframe (must contain 'lat', 'lon', 'name_loc') - :param distance_m: DBSCAN eps in metres — pairs of fields closer than this - value are assigned to the same cluster + :param distance_m: DBSCAN eps in metres — pairs of fields closer than this value are assigned + to the same cluster :param train_val_test_split: (train, val, test) proportions, must sum to 1.0 :param seed: random seed for GroupShuffleSplit - :return: dict with 'train_indices', 'val_indices', 'test_indices' as - pd.Series of name_loc strings, plus 'clusters' as a numpy array of - cluster labels (same length as df) + :return: dict with 'train_indices', 'val_indices', 'test_indices' as pd.Series of name_loc + strings, plus 'clusters' as a numpy array of cluster labels (same length as df) """ # Deduplicate to unique (lat, lon) locations before clustering. # yield_africa has ~9 rows per location (one per year); running DBSCAN on all @@ -271,9 +270,7 @@ def generate_splits( f"(train={n_train}, val={n_val}, test={n_test}, " f"total={n_train + n_val + n_test}/{len(df)})" ) - log.info( - f" {dist_km}km: train={n_train}, val={n_val}, test={n_test} -> {out_name}" - ) + log.info(f" {dist_km}km: train={n_train}, val={n_val}, test={n_test} -> {out_name}") def main() -> None: diff --git a/src/inference.py b/src/inference.py new file mode 100644 index 0000000..dc15b50 --- /dev/null +++ b/src/inference.py @@ -0,0 +1,46 @@ +from typing import Optional + +import hydra +import rootutils +from dotenv import load_dotenv +from omegaconf import DictConfig + +from src.models.inference_model import load_inference_model, merge_inference_model +from src.utils import extras + +rootutils.setup_root(__file__, indicator=".project-root", pythonpath=True) +load_dotenv() + +# Disable tokenizers parallelism to avoid warnings when using multiprocessing +import os + +if os.environ.get("TOKENIZERS_PARALLELISM") is None: + os.environ["TOKENIZERS_PARALLELISM"] = "false" + + +@hydra.main(version_base="1.3", 
config_path="../configs", config_name="inference.yaml") +def main(cfg: DictConfig) -> Optional[float]: + """Main entry point for training. + + :param cfg: DictConfig configuration composed by Hydra. + :return: Optional[float] with optimized metric value. + """ + # apply extra utilities + # (e.g. ask for tags if none are provided in cfg, print cfg tree, etc.) + extras(cfg) + + # If a merged inference ckpt is provided, just load it. + inference_ckpt_path = cfg.get("inference_ckpt_path") + if inference_ckpt_path: + model = load_inference_model(inference_ckpt_path) + # Otherwise merge model from two checkpoints + else: + model = merge_inference_model(cfg, save_ckpt=True) + + # TODO: do what you need with the inference model + + return + + +if __name__ == "__main__": + main() diff --git a/src/models/base_model.py b/src/models/base_model.py index 89fbaf6..d0f92e8 100644 --- a/src/models/base_model.py +++ b/src/models/base_model.py @@ -4,58 +4,127 @@ import torch from lightning import LightningModule +from src.models.components.geo_encoders.base_geo_encoder import BaseGeoEncoder from src.models.components.loss_fns.base_loss_fn import BaseLossFn from src.models.components.metrics.metrics_wrapper import MetricsWrapper +from src.models.components.pred_heads.base_pred_head import BasePredictionHead +from src.models.components.text_encoders.base_text_encoder import BaseTextEncoder +from src.utils.logging_utils import log_model_loading class BaseModel(LightningModule, ABC): def __init__( self, - trainable_modules: list[str] | None, - optimizer: torch.optim.Optimizer, - scheduler: torch.optim.lr_scheduler, - loss_fn: BaseLossFn, - metrics: MetricsWrapper, + trainable_modules: list[str], + geo_encoder: BaseGeoEncoder | None, + text_encoder: BaseTextEncoder | None, + prediction_head: BasePredictionHead | None, + optimizer: torch.optim.Optimizer | None, + scheduler: torch.optim.lr_scheduler.LRScheduler | None, + loss_fn: BaseLossFn | None, + metrics: MetricsWrapper | None, + 
num_classes: int | None, + tabular_dim: int | None, ) -> None: """Interface for any model. :param trainable_modules: which modules to train + :param geo_encoder: module for encoding geo data + :param text_encoder: module for encoding text data + :param prediction_head: module for making prediction from geo features :param optimizer: optimizer for the model weight update :param scheduler: scheduler for the model weight update :param loss_fn: loss function :param metrics: metrics to track for model performance estimation + :param num_classes: number of target classes + :param tabular_dim: number of tabular features """ super().__init__() + + # Ignore objects self.save_hyperparameters( - ignore=["loss_fn", "geo_encoder", "prediction_head", "text_encoder", "metrics"] + ignore=[ + "geo_encoder", + "text_encoder", + "prediction_head", + "optimizer", + "scheduler", + "loss_fn", + "metrics", + ] ) self.trainable_modules = trainable_modules - self.num_classes: int | None = None - self.tabular_dim: int | None = None + if geo_encoder: + self.geo_encoder = geo_encoder + if text_encoder: + self.text_encoder = text_encoder + if prediction_head: + self.prediction_head = prediction_head + + self.optimizer = optimizer + self.scheduler = scheduler + self.loss_fn = loss_fn self.metrics = metrics - @abstractmethod + self.num_classes = num_classes + self.tabular_dim = tabular_dim + + self.setup_flag = False + + @final def setup(self, stage: str) -> None: """Updates model based data-bound configurations (through datamodule), This method is called after trainer is initialized and datamodule is available.""" + if self.setup_flag: + print(f"Model {self.__str__()} is already set up!") + return + + # If trainer is attached get num_classes and tabular_dim from datamodule (data-dependent) + if self._trainer is not None: + self.num_classes = self.trainer.datamodule.num_classes + self.tabular_dim = self.trainer.datamodule.tabular_dim + + # Per model logic of setting up + self._setup(stage) + 
self.setup_flag = True + + # Freezing requested parts + if stage in ["inference", "test"]: + self.full_freezer() + self.trainable_modules = [] + else: + self.freezer() + + @abstractmethod + def _setup(self, stage: str) -> None: pass + @final + def full_freezer(self): + """Freeze the whole network.""" + for name, param in self.named_parameters(): + param.requires_grad = False + + for name, module in self.named_modules(): + module.eval() + + return + @final def freezer(self) -> None: """Freezes modules based on provided trainable modules.""" - trainable_modules = tuple(self.trainable_modules) or tuple() - - # Store higher level module names for printing of trainable parts - trainable = set() + # Convert for checking with .startswith() + trainable_set = tuple(set(self.trainable_modules)) or tuple() + expanded_trainable = set() # Freeze modules for name, param in self.named_parameters(): - # Enable exceptions - if name.startswith(trainable_modules): + # Unfreeze trainable parts + if name.startswith(trainable_set): param.requires_grad = True - trainable.add(name) + expanded_trainable.add(name) else: # Freeze the rest param.requires_grad = False @@ -69,8 +138,8 @@ def freezer(self) -> None: # - it is the root module (""), which must be train when any child is. 
def on_save_checkpoint(self, checkpoint):
    """Slim down the checkpoint and attach the model configuration.

    - Drops bookkeeping entries that are not needed to rebuild the model.
    - Stores only the weights of the trainable parts of the model.
    - Appends the data-dependent dimensions and the sub-module configurations to
      ``checkpoint["hyper_parameters"]``.

    :param checkpoint: checkpoint dictionary, modified in place
    :raises ValueError: if the model has not been set up yet
    """
    if not self.setup_flag:
        raise ValueError("Model cannot be saved as it was not set up.")

    # Remove unnecessary keys. Some of them are optional (e.g. the datamodule entries when no
    # datamodule is attached), so a missing key must not abort saving.
    for key in (
        "state_dict",
        "loops",
        "hparams_name",
        "datamodule_hyper_parameters",
        "datamodule_hparams_name",
    ):
        checkpoint.pop(key, None)

    # Save only trainable parts. `str.startswith` accepts a tuple of prefixes; an empty tuple
    # matches nothing, which keeps the state dict empty when nothing is trainable.
    trainable_prefixes = tuple(self.trainable_modules)
    checkpoint["state_dict"] = {
        k: v for k, v in self.state_dict().items() if k.startswith(trainable_prefixes)
    }

    # Update model configurations
    hparams = checkpoint.setdefault("hyper_parameters", {})
    hparams.update(
        {
            "num_classes": self.num_classes,
            "tabular_dim": self.tabular_dim,
            "trainable_modules": self.trainable_modules,
        }
    )
    for part in ("geo_encoder", "prediction_head", "text_encoder"):
        if hasattr(self, part):
            hparams[part] = getattr(self, part).cfg_dict
def adopt_encoder(ckpt_path: str) -> BaseGeoEncoder:
    """Return geo_encoder from a provided checkpoint.

    The encoder is re-instantiated from the configuration stored under
    ``hyper_parameters["geo_encoder"]`` and then filled with the checkpoint weights whose keys
    start with ``"geo_encoder."``.

    :param ckpt_path: path to checkpoint file
    :return: trained geo encoder
    :raises ValueError: if the checkpoint stores no geo encoder configuration
    """
    # Get checkpoint
    # NOTE(security): weights_only=False unpickles arbitrary objects - only load trusted files.
    ckpt = torch.load(ckpt_path, map_location="cpu", weights_only=False)

    # Get skeleton
    geo_config = ckpt["hyper_parameters"].get("geo_encoder")
    if geo_config is None:
        raise ValueError(f"Checkpoint {ckpt_path} does not contain a 'geo_encoder' configuration.")
    encoder: BaseGeoEncoder = hydra.utils.instantiate(geo_config)
    print("---Adopted encoder------")
    encoder.setup()
    encoder.cfg_dict = geo_config
    print("------------------------")

    # Load in the weights. Only the leading "geo_encoder." (the attribute name on the parent
    # model) is stripped: the encoder keeps its own sub-module under an attribute that is also
    # called `geo_encoder`, so removing every occurrence would produce keys that match nothing
    # and, with strict=False, silently leave the encoder untrained.
    prefix = "geo_encoder."
    state_dict = {
        k.removeprefix(prefix): v for k, v in ckpt["state_dict"].items() if k.startswith(prefix)
    }
    res = encoder.load_state_dict(state_dict, strict=False)
    log_model_loading("geo_encoder_ckpt", res)

    return encoder
+ def _setup(self) -> List[str]: + """Configures modules and returns newly initialised, trainable module names.""" - Gets called in model.setup() method. Returns names of any new module configured to be added - to the trainable modules list. - """ self.output_dim = self.dict_n_bands_default[self.geo_data_name] self.geo_encoder = nn.Identity() - print(f"Model set up with average geo-encoder for {self.geo_data_name}") - return [] @override @@ -43,4 +37,10 @@ def forward(self, batch: Dict[str, torch.Tensor]) -> torch.Tensor: """Data forward pass through the encoder.""" tile = batch.get("eo", {}).get(self.geo_data_name) feats = self.geo_encoder(tile.mean(dim=(-2, -1))) + if self.extra_projector: + feats = self.extra_projector(feats) return feats + + @property + def device(self): + return diff --git a/src/models/components/geo_encoders/base_geo_encoder.py b/src/models/components/geo_encoders/base_geo_encoder.py index f79f3f4..167e53f 100644 --- a/src/models/components/geo_encoders/base_geo_encoder.py +++ b/src/models/components/geo_encoders/base_geo_encoder.py @@ -1,5 +1,5 @@ from abc import ABC, abstractmethod -from typing import Dict +from typing import Dict, List, final import torch from torch import nn @@ -8,13 +8,45 @@ class BaseGeoEncoder(nn.Module, ABC): def __init__(self) -> None: super().__init__() + + # Modules self.geo_encoder: nn.Module | None = None + self.extra_projector: nn.Module | None = None + self.output_dim: int | None = None + self.setup_flag: bool = False + self.cfg_dict: Dict = {} - # placeholders self.allowed_geo_data_names: list[str] | None = None self.geo_data_name: str | None = None + def update_configs(self, cfg): + if len(self.cfg_dict) == 0: + self.cfg_dict = cfg + else: + print("Configs for geo encoder not updated") + + @final + def setup(self) -> List[str]: + """Configures modules. + + Gets called in model.setup() method. Returns names of any new module configured to be added + to the trainable modules list. 
+ """ + if self.setup_flag: + print(f"Module {self.__str__()} is already set up.") + return [] + else: + trainable_modules = self._setup() + print(f"Model set up with {self.__str__()}") + self.setup_flag = True + return trainable_modules + + @abstractmethod + def _setup(self) -> List[str]: + """Configures modules and returns newly initialised, trainable module names.""" + pass + @abstractmethod def forward(self, batch: Dict[str, torch.Tensor]) -> torch.Tensor: pass @@ -37,15 +69,7 @@ def dtype(self) -> torch.dtype | None: return None return dtypes.pop() - @abstractmethod - def setup(self) -> list[str]: - """Configures networks, data-dependent parts. - - Gets called in model.setup() method. Returns names of any new module configured to be added - to the trainable modules list. - """ - pass - + @final def add_projector(self, projected_dim: int) -> None: """Adds an extra linear projection layer to the geo encoder. diff --git a/src/models/components/geo_encoders/cnn_encoder.py b/src/models/components/geo_encoders/cnn_encoder.py index 34f8f48..4ece11d 100644 --- a/src/models/components/geo_encoders/cnn_encoder.py +++ b/src/models/components/geo_encoders/cnn_encoder.py @@ -45,8 +45,6 @@ def __init__( ), f"input_n_bands must be int >=3, got {self.input_n_bands}" self.output_dim = output_dim - self.geo_encoder = self.get_backbone() - def set_n_input_bands(self, n_bands: int | None = None) -> None: """Sets number of input bands based on geo_data_name if n_bands is None. @@ -132,10 +130,9 @@ def get_backbone(self): raise ValueError(f"Unsupported backbone: {self.backbone}") @override - def setup(self) -> List[str]: + def _setup(self) -> List[str]: # TODO: could you make sure new layers are returned here to be added to trainable parts? - # Maybe move the get_backbone method in here? 
- print(f"Model setup with cnn geo-encoder for {self.geo_data_name}") + self.geo_encoder = self.get_backbone() return [] @override @@ -159,8 +156,7 @@ def forward( # n_nans == 0 # ), f"CNNEncoder output contains {n_nans}/{feats.numel()} NaNs PRIOR to normalization with data min {eo_data[self.geo_data_name].min()} and max {eo_data[self.geo_data_name].max()}." - return feats.to(dtype) - + if self.extra_projector: + feats = self.extra_projector(feats) -if __name__ == "__main__": - _ = CNNEncoder(None, None, None, None, None, None, None, None) + return feats.to(dtype) diff --git a/src/models/components/geo_encoders/encoder_wrapper.py b/src/models/components/geo_encoders/encoder_wrapper.py index 9cdf8ad..dbdca85 100644 --- a/src/models/components/geo_encoders/encoder_wrapper.py +++ b/src/models/components/geo_encoders/encoder_wrapper.py @@ -4,7 +4,9 @@ import torch.nn as nn from src.models.components.geo_encoders.base_geo_encoder import BaseGeoEncoder +from src.models.components.geo_encoders.identity_encoder import IdentityEncoder from src.models.components.geo_encoders.tabular_encoder import TabularEncoder +from src.utils.errors import IllegalArgumentCombination class EncoderWrapper(BaseGeoEncoder): @@ -39,7 +41,26 @@ def _reformat_set_branches(self, encoder_branches: List[Dict[str, Any]]): self.encoder_branches.append(module_dict) @override - def setup(self) -> List[str]: + def update_configs(self, cfg): + """Update model configurations.""" + # If adopted encoder -> it should already saved the configs + if ( + cfg["_target_"] == "src.models.components.geo_encoders.adopt_encoder.adopt_encoder" + and len(self.cfg_dict) != 0 + ): + return + + for i, branch in enumerate(cfg["encoder_branches"]): + if ( + branch["encoder"]["_target_"] + == "src.models.components.geo_encoders.adopt_encoder.adopt_encoder" + ): + branch["encoder"] = self.encoder_branches[i]["encoder"].cfg_dict + + self.cfg_dict = cfg + + @override + def _setup(self) -> List[str]: new_modules = [] # 
Configure/initialise missing/conditional parts @@ -62,6 +83,10 @@ def setup(self) -> List[str]: # Configure adapter/projector if requested if "projector" in branch: + if isinstance(encoder, IdentityEncoder): + raise IllegalArgumentCombination( + "Identity encoder cannot have linear projector" + ) projector = branch["projector"] intermediate_dim = encoder.output_dim @@ -104,7 +129,7 @@ def set_output_dim(self): if self.fusion_strategy == "concat": self.output_dim = sum(output_dims) elif self.fusion_strategy == "mean": - if set(output_dims) != 1: + if len(set(output_dims)) != 1: raise ValueError( f"Encoder branches produces different output dimensions {output_dims} and cannot be averaged." ) @@ -122,8 +147,14 @@ def forward(self, batch: Dict[str, torch.Tensor]) -> torch.Tensor: branch_feats.append(feats) if self.fusion_strategy == "concat": - return torch.cat(branch_feats, dim=1) - return torch.mean(branch_feats, dim=1) + feats = torch.cat(branch_feats, dim=1) + if self.extra_projector: + feats = self.extra_projector(feats) + else: + feats = torch.stack(branch_feats, dim=0).mean(dim=0) + if self.extra_projector: + feats = self.extra_projector(feats) + return feats @property def device(self): diff --git a/src/models/components/geo_encoders/geoclip.py b/src/models/components/geo_encoders/geoclip.py index 38dffdb..9224d60 100644 --- a/src/models/components/geo_encoders/geoclip.py +++ b/src/models/components/geo_encoders/geoclip.py @@ -21,10 +21,9 @@ def __init__( self.geo_data_name = geo_data_name @override - def setup(self) -> List[str]: + def _setup(self) -> List[str]: self.geo_encoder = LocationEncoder() self.output_dim = self.geo_encoder.LocEnc0.head[0].out_features - print("Model setup with GeoClip coordinate encoder") return [] @override @@ -39,6 +38,8 @@ def forward( if coords.dtype != dtype: coords = coords.to(dtype) feats = self.geo_encoder(coords) + if self.extra_projector: + feats = self.extra_projector(feats) return feats.to(dtype) diff --git 
class IdentityEncoder(BaseGeoEncoder):
    """Pass-through geo encoder: returns the tile of one modality unchanged."""

    def __init__(
        self,
        geo_data_name="aef",
    ) -> None:
        """Encoder that forwards the raw tile of a single modality without transforming it.

        :param geo_data_name: modality name; must be one of ``"s2"``, ``"aef"``, ``"tessera"``
        """
        super().__init__()

        self.dict_n_bands_default = {"s2": 4, "aef": 64, "tessera": 128}
        self.allowed_geo_data_names: list[str] = list(self.dict_n_bands_default.keys())
        assert (
            geo_data_name in self.allowed_geo_data_names
        ), f"geo_data_name must be one of {self.allowed_geo_data_names}, got {geo_data_name}"
        self.geo_data_name = geo_data_name

    @override
    def _setup(self) -> List[str]:
        """Configures modules and returns newly initialised, trainable module names."""
        # The output size depends on the data and is only known after the first forward pass.
        self.output_dim = None
        self.geo_encoder = nn.Identity()
        return []

    @override
    def forward(self, batch: Dict[str, torch.Tensor]) -> torch.Tensor:
        """Data forward pass through the encoder.

        :param batch: batch dictionary holding the tiles under ``batch["eo"][geo_data_name]``
        :return: the unchanged tile, passed through the extra projector when one is configured
        :raises KeyError: if the batch does not contain the requested modality
        """
        tile = batch.get("eo", {}).get(self.geo_data_name)
        if tile is None:
            raise KeyError(f"Batch does not contain 'eo' data for modality '{self.geo_data_name}'.")

        if self.output_dim is None:
            # NOTE(review): this records the full shape of the first batch (batch dimension
            # included), not a single int - confirm this is what consumers of output_dim expect.
            self.output_dim = tile.shape

        feats = self.geo_encoder(tile)
        if self.extra_projector is not None:
            feats = self.extra_projector(feats)
        return feats

    @property
    def device(self):
        """Always ``None`` for this encoder."""
        return None
def forward(self, x: torch.Tensor) -> torch.Tensor:
    """Project ``x`` through the MLP and, when configured, through the extra projector.

    :param x: input features
    :return: projected features
    """
    hidden = self.net(x)
    # Early exit when no additional projection layer is configured.
    if not self.extra_projector:
        return hidden
    return self.extra_projector(hidden)
+ Normalises RMSE by the mean absolute value of the target, giving a unit-free percentage error. + This makes results comparable across crops and regions with different absolute yield scales + (e.g. t/ha ranges differ significantly between maize in Zambia and rice in Rwanda). - Returns a fraction (e.g. 0.15 = 15 % error). Multiply by 100 for - percentage when reporting. + Returns a fraction (e.g. 0.15 = 15 % error). Multiply by 100 for percentage when reporting. """ def __init__(self) -> None: diff --git a/src/models/components/metrics/metrics_wrapper.py b/src/models/components/metrics/metrics_wrapper.py index 33d86de..0d40212 100644 --- a/src/models/components/metrics/metrics_wrapper.py +++ b/src/models/components/metrics/metrics_wrapper.py @@ -19,6 +19,6 @@ def forward(self, mode="train", **kwargs) -> Dict[str, torch.float]: for metric in self.metrics: metric_results = metric(mode=mode, return_label=True, **kwargs) for k, v in metric_results.items(): - compiled_dict[f"{mode}_{k}"] = v + compiled_dict[k] = v return compiled_dict diff --git a/src/models/components/pred_heads/base_pred_head.py b/src/models/components/pred_heads/base_pred_head.py index 5ed63c4..b3acd0c 100644 --- a/src/models/components/pred_heads/base_pred_head.py +++ b/src/models/components/pred_heads/base_pred_head.py @@ -1,5 +1,5 @@ from abc import ABC, abstractmethod -from typing import final +from typing import List, final import torch from torch import nn @@ -9,9 +9,14 @@ class BasePredictionHead(nn.Module, ABC): def __init__(self) -> None: """Base prediction head interface class.""" super().__init__() + + # Modules self.net: nn.Module | None = None + self.input_dim: int | None = None self.output_dim: int | None = None + self.setup_flag: bool = False + self.cfg_dict = {} @abstractmethod def forward(self, feats: torch.Tensor) -> torch.Tensor: @@ -34,11 +39,23 @@ def set_dim(self, input_dim: int, output_dim: int) -> None: self.input_dim = input_dim self.output_dim = output_dim - @abstractmethod 
- def setup(self) -> None: - """Configures networks, data-dependent parts. + @final + def setup(self) -> List[str]: + """Configures modules. Gets called in model.setup() method. Returns names of any new module configured to be added to the trainable modules list. """ + if self.setup_flag: + print(f"Module {self.__str__()} is already set up.") + return [] + else: + self._setup() + print(f"Model set up with {self.__str__()}") + self.setup_flag = True + return ["prediction_head"] + + @abstractmethod + def _setup(self) -> None: + """Configures specific prediction head.""" pass diff --git a/src/models/components/pred_heads/linear_pred_head.py b/src/models/components/pred_heads/linear_pred_head.py index 94338a7..92c7818 100644 --- a/src/models/components/pred_heads/linear_pred_head.py +++ b/src/models/components/pred_heads/linear_pred_head.py @@ -1,4 +1,4 @@ -from typing import override +from typing import List, override import torch from torch import nn @@ -24,16 +24,11 @@ def __init__( @override def forward(self, feats: torch.Tensor) -> torch.Tensor: """Forward pass through the prediction head.""" - return torch.sigmoid(self.net(feats)) @override - def setup(self) -> None: - """Configures networks, data-dependent parts. - - Gets called in model.setup() method. Returns names of any new module configured to be added - to the trainable modules list. 
- """ + def _setup(self) -> None: + """Configures specific prediction head.""" assert type(self.input_dim) is int, self.input_dim assert type(self.output_dim) is int, self.output_dim self.net = nn.Linear(self.input_dim, self.output_dim) diff --git a/src/models/components/pred_heads/mlp_pred_head.py b/src/models/components/pred_heads/mlp_pred_head.py index 6d4c124..7d4cd48 100644 --- a/src/models/components/pred_heads/mlp_pred_head.py +++ b/src/models/components/pred_heads/mlp_pred_head.py @@ -34,12 +34,8 @@ def forward(self, feats: torch.Tensor) -> torch.Tensor: return torch.sigmoid(self.net(feats)) @override - def setup(self) -> None: - """Configures networks, data-dependent parts. - - Gets called in model.setup() method. Returns names of any new module configured to be added - to the trainable modules list. - """ + def _setup(self) -> None: + """Configures specific prediction head.""" assert type(self.input_dim) is int, self.input_dim assert type(self.output_dim) is int, self.output_dim layers = [] diff --git a/src/models/components/pred_heads/mlp_regression_head.py b/src/models/components/pred_heads/mlp_regression_head.py index d9553ec..10088a1 100644 --- a/src/models/components/pred_heads/mlp_regression_head.py +++ b/src/models/components/pred_heads/mlp_regression_head.py @@ -49,12 +49,8 @@ def forward(self, feats: torch.Tensor) -> torch.Tensor: return self.net(feats) @override - def setup(self) -> None: - """Configures networks, data-dependent parts. - - Gets called in model.setup() method. Returns names of any new module configured to be added - to the trainable modules list. 
- """ + def _setup(self) -> None: + """Configures specific prediction head.""" assert isinstance(self.input_dim, int), self.input_dim assert isinstance(self.output_dim, int), self.output_dim @@ -70,3 +66,4 @@ def setup(self) -> None: layers.append(nn.Linear(in_dim, self.output_dim)) self.net = nn.Sequential(*layers) + return diff --git a/src/models/components/text_encoders/base_text_encoder.py b/src/models/components/text_encoders/base_text_encoder.py index d934aab..4f54a2f 100644 --- a/src/models/components/text_encoders/base_text_encoder.py +++ b/src/models/components/text_encoders/base_text_encoder.py @@ -1,5 +1,5 @@ from abc import ABC, abstractmethod -from typing import Dict +from typing import Dict, List, final import torch from torch import nn @@ -8,12 +8,37 @@ class BaseTextEncoder(nn.Module, ABC): def __init__(self) -> None: super().__init__() + + # modules self.processor: nn.Module | None = None self.model: nn.Module = None self.projector: nn.Module | None = None - self.output_dim: int | None = None self.extra_projector: nn.Module | None = None + self.output_dim: int | None = None + self.setup_flag: bool = False + self.cfg_dict: Dict = {} + + @final + def setup(self): + """Configures modules. + + Gets called in model.setup() method. Returns names of any new module configured to be added + to the trainable modules list. 
+ """ + if self.setup_flag: + print(f"Module {self.__str__()} is already set up.") + return [] + else: + trainable_modules = self._setup() + print(f"Model set up with {self.__str__()}") + self.setup_flag = True + return trainable_modules + + def _setup(self) -> List[str]: + """Configures modules and returns newly initialised, trainable module names.""" + return [] + @abstractmethod def forward(self, batch: Dict[str, torch.Tensor], mode: str) -> torch.Tensor: pass diff --git a/src/models/components/text_encoders/clip_text_encoder.py b/src/models/components/text_encoders/clip_text_encoder.py index 075a390..5052ca5 100644 --- a/src/models/components/text_encoders/clip_text_encoder.py +++ b/src/models/components/text_encoders/clip_text_encoder.py @@ -29,8 +29,6 @@ def __init__(self, hf_cache_dir: str = "../.cache", output_normalization="l2") - self.output_dim = 512 - print("Model set up with CLIP text encoder") - @override def forward(self, batch: Dict[str, torch.Tensor], mode: str) -> torch.Tensor: # Get text inputs diff --git a/src/models/inference_model.py b/src/models/inference_model.py new file mode 100644 index 0000000..5c0023e --- /dev/null +++ b/src/models/inference_model.py @@ -0,0 +1,266 @@ +import os +from typing import Dict, Tuple, override + +import hydra +import torch +import torch.nn.functional as F + +from src.models.base_model import BaseModel +from src.models.components.geo_encoders.base_geo_encoder import BaseGeoEncoder +from src.models.components.metrics.metrics_wrapper import MetricsWrapper +from src.models.components.pred_heads.linear_pred_head import BasePredictionHead +from src.models.components.text_encoders.base_text_encoder import ( + BaseTextEncoder, +) +from src.utils import RankedLogger +from src.utils.errors import FileNotSpecified +from src.utils.logging_utils import log_model_loading + +log = RankedLogger(__name__, rank_zero_only=True) + + +class InferenceModel(BaseModel): + def __init__( + self, + geo_encoder: BaseGeoEncoder | None, + 
text_encoder: BaseTextEncoder | None, + prediction_head: BasePredictionHead | None, + num_classes: int | None, + metrics: MetricsWrapper | None = None, + ks: list[int] | None = [5, 10, 15], + match_to_geo: bool = True, + **kwargs, + ) -> None: + """Inference model. + + :param geo_encoder: module for encoding geo data + :param text_encoder: module for encoding text data + :param prediction_head: module for making prediction from geo features + :param num_classes: number of target classes + :param metrics: metrics to track for model performance estimation + :param ks: list of ks + :param match_to_geo: whether to match dimensions of text encoder to geo_encoder or visa- + versa + """ + + super().__init__( + trainable_modules=[], + geo_encoder=geo_encoder, + text_encoder=text_encoder, + prediction_head=prediction_head, + optimizer=None, + scheduler=None, + loss_fn=None, + metrics=metrics, + num_classes=num_classes, + tabular_dim=None, + ) + + # Params from alignment model + self.match_to_geo = match_to_geo + self.ks = ks + + @override + def _setup(self, stage: str) -> None: + """Set up the network.""" + if stage != "inference": + raise ValueError(f"Trying to {stage} inference model") + + print("-------Model------------") + # Configure encoders + if hasattr(self, "geo_encoder"): + self.geo_encoder.setup() + if hasattr(self, "text_encoder"): + self.text_encoder.setup() + + if hasattr(self, "text_encoder") and hasattr(self, "geo_encoder"): + # Configure optional extra projection so text embeddings match geo embeddings. 
+ if self.text_encoder.output_dim != self.geo_encoder.output_dim: + if self.match_to_geo: + self.text_encoder.add_projector(projected_dim=self.geo_encoder.output_dim) + else: + self.geo_encoder.add_projector(projected_dim=self.text_encoder.output_dim) + # Configure prediction head + if hasattr(self, "prediction_head") and self.prediction_head.net is None: + if self.num_classes is None: + raise ValueError( + "InferenceModel requires `num_classes` to build the prediction head." + ) + self.prediction_head.set_dim( + input_dim=self.geo_encoder.output_dim, output_dim=self.num_classes + ) + self.prediction_head.setup() + print("------------------------") + + @override + def _step( + self, + batch: Dict[str, torch.Tensor], + mode: str = "train", + ) -> torch.Tensor: + """Step forward computation of the model.""" + pass + + @override + def forward( + self, + batch: Dict[str, torch.Tensor], + mode: str = "train", + ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + """Model forward logic.""" + + # Embed modalities + if hasattr(self, "geo_encoder"): + geo_feats = self.geo_encoder(batch) + if hasattr(self, "text_encoder"): + text_feats = self.text_encoder(batch, mode) + if hasattr(self, "prediction_head"): + pred_feats = self.prediction_head(geo_feats) + + # Change dtype of geo data if it does not match text dtype + if ( + hasattr(self, "text_encoder") + and hasattr(self, "geo_encoder") + and geo_feats.dtype != text_feats.dtype + ): + geo_feats = geo_feats.to(text_feats.dtype) + + return pred_feats, geo_feats, text_feats + + def concept_similarities(self, geo_embeds, concept=None) -> torch.Tensor: + # Get concept embeddings + if concept is not None: + # If only one concept is provided + if isinstance(concept, str): + concept = [concept] + with torch.no_grad(): + concept_embeds = self.text_encoder({"text": concept}, mode="train") + + elif self.concept_embeds is not None: + concept_embeds = self.concept_embeds + else: + with torch.no_grad(): + concept_embeds = 
def _is_prefix_trained(trainable_modules: list[str], prefix: str) -> bool:
    """True if any trainable module starts with `prefix` (before dot)."""
    return any(m.split(".")[0] == prefix for m in trainable_modules)


def merge_inference_model(cfg, save_ckpt=False) -> InferenceModel | None:
    """Configures the inference model from the predictive + alignment checkpoints.

    The model skeleton is built from the alignment checkpoint's hyper-parameters, combined with
    the prediction head configuration, ``num_classes`` and geo encoder configuration of the
    predictive checkpoint. Weights are then loaded part by part (text encoder, geo encoder,
    prediction head).

    :param cfg: A DictConfig configuration composed by Hydra.
    :param save_ckpt: Whether to save the model or not.
    :return: an InferenceModel with pre-trained weights, or ``None`` if the user declines to
        proceed with mismatching geo encoder configurations
    :raises FileNotSpecified: if one of the two checkpoint paths is not configured
    :raises ValueError: if both checkpoints trained the geo encoder
    """
    # Stitch the inference model from the predictive + alignment checkpoints.
    pred_ckpt_path = cfg.get("predictive_ckpt_path")
    if not pred_ckpt_path:
        raise FileNotSpecified(
            'You must specify predictive model weight path as "predictive_ckpt_path"'
        )
    align_ckpt_path = cfg.get("alignment_ckpt_path")
    if not align_ckpt_path:
        raise FileNotSpecified(
            'You must specify alignment model weight path as "alignment_ckpt_path"'
        )

    # TODO: remove dataset saving into the checkpoint
    # NOTE(security): weights_only=False unpickles arbitrary objects - only load trusted files.
    pred_ckpt = torch.load(pred_ckpt_path, map_location="cpu", weights_only=False)
    align_ckpt = torch.load(align_ckpt_path, map_location="cpu", weights_only=False)

    # Sanity check: ensure geo encoder configs match. The comparison has to happen BEFORE the
    # alignment config is overwritten, otherwise both sides are always identical.
    pred_geo_config = pred_ckpt["hyper_parameters"].get("geo_encoder")
    if pred_geo_config != align_ckpt["hyper_parameters"].get("geo_encoder"):
        log.warning("Geo encoder configs differ between checkpoints; results may be invalid.")
        try:
            answer = input("Do you want to proceed? y/n")
        except EOFError:
            # Non-interactive run: nobody can answer, so continue after the warning.
            answer = ""
        if answer.strip().lower() == "n":
            return None
    # The inference model is always built with the predictive checkpoint's geo encoder config.
    align_ckpt["hyper_parameters"]["geo_encoder"] = pred_geo_config

    pred_trainable_modules = pred_ckpt["hyper_parameters"].get("trainable_modules", [])
    align_trainable_modules = align_ckpt["hyper_parameters"].get("trainable_modules", [])

    geo_pred_encoder_trained = _is_prefix_trained(pred_trainable_modules, "geo_encoder")
    geo_align_encoder_trained = _is_prefix_trained(align_trainable_modules, "geo_encoder")

    if geo_pred_encoder_trained and geo_align_encoder_trained:
        raise ValueError("Models are not aligned: both checkpoints trained geo_encoder.")

    # Instantiate InferenceModel via hydra, using alignment encoder configs with prediction
    # model head configs
    inference_hparams = align_ckpt["hyper_parameters"]
    inference_hparams.update(
        {
            "_target_": "src.models.inference_model.InferenceModel",
            "prediction_head": pred_ckpt["hyper_parameters"].get("prediction_head"),
            "num_classes": pred_ckpt["hyper_parameters"].get("num_classes"),
        }
    )
    inference_hparams["text_encoder"]["hf_cache_dir"] = os.path.join(
        cfg.paths.cache_dir, "huggingface"
    )

    model: InferenceModel = hydra.utils.instantiate(inference_hparams)
    model.setup("inference")

    # Load alignment weights first (text encoder).
    text_state = {
        k: v for k, v in align_ckpt["state_dict"].items() if k.startswith("text_encoder.")
    }
    res = model.load_state_dict(text_state, strict=False)
    log_model_loading("text_encoder", res)

    # Geo encoder weights come from the predictive checkpoint only when the prediction model
    # was trained first and the alignment run did not train the geo encoder afterwards.
    if cfg.training_order[0] == "prediction_model" and not geo_align_encoder_trained:
        geo_source = pred_ckpt
    else:
        geo_source = align_ckpt
    geo_state = {
        k: v for k, v in geo_source["state_dict"].items() if k.startswith("geo_encoder.")
    }
    res = model.load_state_dict(geo_state, strict=False)
    log_model_loading("geo_encoder", res)

    # Load prediction head weights from predictive ckpt.
    head_state = {
        k: v for k, v in pred_ckpt["state_dict"].items() if k.startswith("prediction_head.")
    }
    res = model.load_state_dict(head_state, strict=False)
    log_model_loading("Predictive_head", res)

    # Save model
    if save_ckpt:
        save_path = cfg.get("save_inference_ckpt_path")
        if not save_path:
            # Nothing to write to: report it and still hand back the merged model.
            print("Model could not be saved as save_path was not provided")
        else:
            # A bare file name has no directory component; os.makedirs("") would fail.
            save_dir = os.path.dirname(save_path)
            if save_dir:
                os.makedirs(save_dir, exist_ok=True)
            torch.save(
                {"state_dict": model.state_dict(), "hyper_parameters": inference_hparams},
                save_path,
            )
            log.info(f"Saved merged inference checkpoint to: {save_path}")

    return model
@override
def _setup(self, stage: str) -> None:
    """Configure the data-dependent parts of the predictive model.

    Passes the tabular input dimension to tabular-capable geo encoders, runs the
    geo encoder and prediction head setup, and registers every module name those
    setups return as trainable.

    NOTE(review): presumably invoked once the trainer and datamodule are
    available (or `num_classes` / `tabular_dim` were given explicitly) - confirm
    against the base class.

    :param stage: stage name; the setup itself does not branch on it
    """
    # `extend` is called below in every stage, so the container must be mutable
    # regardless of the stage (a tuple would fail on `.extend`).
    if isinstance(self.trainable_modules, tuple):
        self.trainable_modules = list(self.trainable_modules)

    print("-------Model------------")
    # If tabular encoder used, we need to specify tabular dim
    if isinstance(self.geo_encoder, (TabularEncoder, EncoderWrapper)):
        self.geo_encoder.set_tabular_input_dim(self.tabular_dim)

    # Setup encoders; a setup returning None is treated as "no new modules"
    new_modules = [f"geo_encoder.{i}" for i in self.geo_encoder.setup() or []]
    self.trainable_modules.extend(new_modules)

    # Configure prediction head based on geo-encoder output_dim
    self.prediction_head.set_dim(
        input_dim=self.geo_encoder.output_dim, output_dim=self.num_classes
    )
    self.trainable_modules.extend(self.prediction_head.setup() or [])
    print("------------------------")
def __init__(
    self,
    trainable_modules: list[str],
    geo_encoder: BaseGeoEncoder,
    text_encoder: BaseTextEncoder,
    optimizer: torch.optim.Optimizer,
    scheduler: torch.optim.lr_scheduler.LRScheduler | None,
    loss_fn: BaseLossFn,
    metrics: MetricsWrapper,
    num_classes: int | None = None,
    tabular_dim: int | None = None,
    ks: list[int] | None = [5, 10, 15],
    match_to_geo: bool = True,
) -> None:
    """Implementation of contrastive text-eo modality alignment model.

    :param trainable_modules: which modules to train
    :param geo_encoder: module for encoding geo data
    :param text_encoder: module for encoding text data
    :param optimizer: optimizer for the model weight update
    :param scheduler: scheduler for the model weight update
    :param loss_fn: loss function
    :param metrics: metrics to track for model performance estimation
    :param num_classes: number of target classes
    :param tabular_dim: number of tabular features
    :param ks: list of ks; stored as a per-instance copy
    :param match_to_geo: whether to match dimensions of text encoder to geo_encoder or vice
        versa
    """
    super().__init__(
        trainable_modules=trainable_modules,
        geo_encoder=geo_encoder,
        text_encoder=text_encoder,
        prediction_head=None,
        optimizer=optimizer,
        scheduler=scheduler,
        loss_fn=loss_fn,
        metrics=metrics,
        num_classes=num_classes,
        tabular_dim=tabular_dim,
    )

    # Metrics
    # The signature default is kept for interface compatibility, but it is a single
    # list object shared by every call; copy it so instances never alias each other.
    self.ks = list(ks) if ks is not None else None
    self.log_kwargs = dict(on_step=False, on_epoch=True, prog_bar=True, sync_dist=True)

    self.match_to_geo = match_to_geo
(through + datamodule), This method is called after trainer is initialized and datamodule is + available. + Otherwise, some configuration variables must be made available + """ # Set up encoders and missing adapters/projectors print("-------Model------------") - self.setup_encoders_adapters() - print("------------------------") - - # Freeze requested parts - self.freezer() - - # Configure contrastive retrieval evaluation - self.setup_retrieval_evaluation() - - def setup_encoders_adapters(self): - """Set up encoders and missing adapters/projectors.""" - # We don't use tabular encoders for wrapping - # if ( - # isinstance(self.geo_encoder, MultiModalEncoder) - # and self.geo_encoder.use_tabular - # and not self.geo_encoder._tabular_ready - # ): - # self.geo_encoder.build_tabular_branch(self.tabular_dim) - - # Setup encoders that need data-depended configurations - new_modules = [f"geo_encoder.{i}" for i in self.geo_encoder.setup()] + new_modules = [f"geo_encoder.{i}" for i in self.geo_encoder.setup() or []] + new_modules.extend([f"text_encoder.{i}" for i in self.text_encoder.setup() or []]) self.trainable_modules.extend(new_modules) # Extra projector for text encoder if eo and text dim not match @@ -102,21 +88,51 @@ def setup_encoders_adapters(self): self.geo_encoder.add_projector(projected_dim=self.text_encoder.output_dim) self.trainable_modules.append("geo_encoder.extra_projector") - # Configure prediction head based on geo-encoder output_dim - if self.prediction_head is not None: - self.prediction_head.set_dim( - input_dim=self.geo_encoder.output_dim, output_dim=self.num_classes - ) - self.prediction_head.setup() - - # # Unify dtypes -> moving to data part, rather than changing parameter type - # if self.geo_encoder.dtype != self.text_encoder.dtype: - # self.geo_encoder = self.geo_encoder.to(self.text_encoder.dtype) - # print(f"Geo encoder dtype changed to {self.geo_encoder.dtype}") + # Configure contrastive retrieval evaluation + self.setup_retrieval_evaluation() 
def setup_retrieval_evaluation(self):
    """Prepare concept metadata and per-split baselines for retrieval evaluation.

    Reads `concept_configs` from the datamodule, derives captions and short concept
    names, and for each available split ("train", "val", "test") computes, per
    concept, the percentage of items on the concept's side of the elbow threshold.
    Finally creates the `RetrievalContrastiveValidation` helper and resets the
    epoch output memory.
    """
    datamodule = self.trainer.datamodule
    self.concept_configs = datamodule.concept_configs
    self.concepts = [c["concept_caption"] for c in self.concept_configs]
    self.concept_names = [
        f"{c['col'].replace('aux_', '')}_{'max' if c['is_max'] else 'min'}"
        for c in self.concept_configs
    ]

    self.dynamic_k_baselines = {}
    for dataset_name in ("train", "val", "test"):
        # A split attribute may be absent, or present but still None / empty;
        # none of those can provide a baseline, so skip them instead of failing
        # on `len(None)` or on an empty array in the elbow search.
        split_ds = getattr(datamodule, f"data_{dataset_name}", None)
        if split_ds is None:
            continue
        n_ds = len(split_ds)
        if n_ds == 0:
            continue

        self.dynamic_k_baselines[dataset_name] = {}

        # Collect, per concept, the auxiliary value of every item in the split
        aux_vals_per_concept = {i: [] for i in range(len(self.concept_configs))}
        for item in split_ds:
            aux_data = item["aux"]["aux"]
            for i_c, c in enumerate(self.concept_configs):
                aux_vals_per_concept[i_c].append(aux_data[c["id"]])

        # Compute per concept
        for i_c, c in enumerate(self.concept_configs):
            c_name = self.concept_names[i_c]
            aux_vals_current_ds = aux_vals_per_concept[i_c]

            theta_k = self.find_elbow_point(aux_vals_current_ds)
            # Assign new theta_k to concept_configs for later use in validation.
            # NOTE(review): this is overwritten for every split, so the value of the
            # last processed split wins - confirm that is intended.
            self.concept_configs[i_c]["theta_k"] = theta_k
            if c["is_max"]:
                n_baseline = sum(aux_val >= theta_k for aux_val in aux_vals_current_ds)
            else:
                n_baseline = sum(aux_val <= theta_k for aux_val in aux_vals_current_ds)
            self.dynamic_k_baselines[dataset_name][c_name] = n_baseline / n_ds * 100

    self.contrastive_val = RetrievalContrastiveValidation(self.ks, self.concept_configs)
    self.outputs_epoch_memory = []
self.dynamic_k_baselines[mode][self.concept_names[i]]) / ( + 100 - self.dynamic_k_baselines[mode][self.concept_names[i]] + ) + self.log(f"dyn_k_index_{self.concept_names[i]}", indexed_v, **self.log_kwargs) print(f"Top-{k}: {v:.1f}%") avr_scores[f"{mode}_avr_top-{k}"].append(v) @@ -254,3 +276,23 @@ def concept_similarities(self, geo_embeds, concept=None) -> torch.Tensor: similarity_matrix = concept_embeds @ geo_embeds.T return similarity_matrix + + @staticmethod + def find_elbow_point(vals): + vals = np.sort(vals) + x = np.arange(len(vals)) / len(vals) + y = vals + slope = (y[-1] - y[0]) / (x[-1] - x[0]) # diagonal from first to last point + intercept = y[0] - slope * x[0] + orthogonal_slope = -1 / slope + + intercepts_orthogonal = y - orthogonal_slope * x + intersection_diagonal_orthogonal = (intercepts_orthogonal - intercept) / ( + slope - orthogonal_slope + ) + distances = np.sqrt( + (x - intersection_diagonal_orthogonal) ** 2 + (y - (slope * x + intercept)) ** 2 + ) # distance to diagonal + elbow_index = np.argmax(distances) + elbow_point = y[elbow_index] + return elbow_point diff --git a/src/train.py b/src/train.py index 8348fbb..af1686f 100644 --- a/src/train.py +++ b/src/train.py @@ -6,7 +6,7 @@ from dotenv import load_dotenv from lightning import Callback, LightningModule, Trainer from lightning.pytorch.loggers import Logger -from omegaconf import DictConfig +from omegaconf import DictConfig, OmegaConf from src.data.base_datamodule import BaseDataModule @@ -53,6 +53,10 @@ def train(cfg: DictConfig) -> Tuple[Dict[str, Any], Dict[str, Any]]: log.info(f"Instantiating model <{cfg.model._target_}>") model: LightningModule = hydra.utils.instantiate(cfg.model) + # Append model hparams from config to be saved in ckpg + raw_model_cfg = OmegaConf.to_container(cfg.model, resolve=True) + model.update_configs(raw_model_cfg) + log.info("Instantiating callbacks...") callbacks: List[Callback] = instantiate_callbacks(cfg.get("callbacks")) diff --git a/src/utils/data_utils.py 
def center_crop_npy(arr, target_shape):
    """Crop a numpy array around its centre to the desired size.

    Axes of `arr` are paired in order with the entries of `target_shape`; when
    `target_shape` has fewer entries than `arr` has axes, the remaining trailing
    axes are returned in full.

    :param arr: array to crop
    :param target_shape: requested extent for each leading axis
    :return: the centred crop of `arr`
    :raises ValueError: if a requested extent is negative or exceeds the axis size
    """
    slices = []
    for axis, (dim, target) in enumerate(zip(arr.shape, target_shape)):
        # A target larger than the axis would give a negative start and silently
        # select the wrong elements, so reject it explicitly.
        if not 0 <= target <= dim:
            raise ValueError(
                f"Cannot crop axis {axis} of size {dim} to target size {target}"
            )
        start = (dim - target) // 2
        slices.append(slice(start, start + target))
    return arr[tuple(slices)]
def log_model_loading(tag: str, result) -> None:
    """Report the outcome of a `load_state_dict` call.

    Logs an info message when nothing is missing or unexpected; otherwise logs one
    warning per non-empty category with the number of affected keys per group.

    :param tag: label placed in front of every log message
    :param result: pair of (missing keys, unexpected keys) returned by `load_state_dict`
    """
    missing_keys, unexpected_keys = result

    # Clean load: nothing to summarise
    if not missing_keys and not unexpected_keys:
        log.info(f"[{tag}] Module weights loaded successfully")
        return

    if missing_keys:
        per_group = _group_keys(list(missing_keys))
        summary = {group: len(names) for group, names in per_group.items()}
        log.warning(f"[{tag}] Missing keys: {summary}")

    if unexpected_keys:
        per_group = _group_keys(list(unexpected_keys))
        summary = {group: len(names) for group, names in per_group.items()}
        log.warning(f"[{tag}] Unexpected keys: {summary}")
+ pred_head.set_dim(input_dim=eo_encoder.output_dim, output_dim=ds.num_classes) assert hasattr( pred_head, "input_dim" - ), f"'input_dim' attribute missing in {pred_head_class.__name__}." + ), f"'input_dim' attribute missing in {pred_head.__name__}." assert hasattr( pred_head, "output_dim" - ), f"'output_dim' attribute missing in {pred_head_class.__name__}." - assert hasattr( - pred_head, "setup" - ), f"'setup' method missing in {pred_head_class.__name__}." + ), f"'output_dim' attribute missing in {pred_head.__name__}." + assert hasattr(pred_head, "setup"), f"'setup' method missing in {pred_head.__name__}." assert callable( getattr(pred_head, "setup") - ), f"'setup' is not callable in {pred_head_class.__name__}." + ), f"'setup' is not callable in {pred_head.__name__}." pred_head.setup() - assert hasattr(pred_head, "net"), f"'net' attribute missing in {pred_head_class.__name__}." - assert hasattr( - pred_head, "forward" - ), f"'forward' method missing in {pred_head_class.__name__}." + assert hasattr(pred_head, "net"), f"'net' attribute missing in {pred_head.__name__}." + assert hasattr(pred_head, "forward"), f"'forward' method missing in {pred_head.__name__}." assert callable( getattr(pred_head, "forward") - ), f"'forward' is not callable in {pred_head_class.__name__}." + ), f"'forward' is not callable in {pred_head.__name__}." out = pred_head.forward(feats) assert isinstance( out, torch.Tensor - ), f"'forward' method of {pred_head_class.__name__} does not return a torch.Tensor." + ), f"'forward' method of {pred_head.__name__} does not return a torch.Tensor." assert out.shape == ( dm.batch_size_per_device, ds.num_classes, - ), f"Output shape mismatch in {pred_head_class.__name__}." + ), f"Output shape mismatch in {pred_head.__name__}."