From 2884fd0a3c77c86ebce755e5bbbae1d59215d3f8 Mon Sep 17 00:00:00 2001 From: GabrieleTi Date: Wed, 13 May 2026 09:46:03 +0200 Subject: [PATCH 1/6] Introduce dtype parameter in configs --- src/data/base_dataset.py | 124 +++++++++++++++++++++++++--------- src/data/butterfly_dataset.py | 40 +++++++---- 2 files changed, 120 insertions(+), 44 deletions(-) diff --git a/src/data/base_dataset.py b/src/data/base_dataset.py index 66349db..6aebc3c 100644 --- a/src/data/base_dataset.py +++ b/src/data/base_dataset.py @@ -8,25 +8,33 @@ import torch from torch.utils.data import Dataset -import src.data_preprocessing.data_utils as du from src.utils.data_utils import center_crop_npy +TORCH_DTYPES = { + 'float32': torch.float32, + 'float64': torch.float64, + 'int32': torch.int32, + 'int64': torch.int64, + 'bfloat16': torch.bfloat16, +} + class BaseDataset(Dataset, ABC): def __init__( - self, - data_dir: str, - modalities: dict, - use_target_data: bool = True, - use_aux_data: Dict[str, List[str] | str] | str | None = None, - dataset_name: str = "BaseDataset", - seed: int = 12345, - mode: str = "train", - cache_dir: str = None, - implemented_mod: set[str] = None, - mock: bool = False, - use_features: bool = True, - csv_name: str = None, + self, + data_dir: str, + modalities: dict, + use_target_data: bool = True, + use_aux_data: Dict[str, List[str] | str] | str | None = None, + dataset_name: str = "BaseDataset", + seed: int = 12345, + mode: str = "train", + cache_dir: str = None, + implemented_mod: set[str] = None, + mock: bool = False, + use_features: bool = True, + csv_name: str = None, + dtype: str = "float32", ) -> None: """Interface for any use case dataset. @@ -50,18 +58,31 @@ def __init__( :param implemented_mod: implemented modalities for each dataset :param mock: whether to mock csv file :param use_features: if tabular feat_* columns should be included. Default True. 
+ :param dtype: global dtype (used if not specified for each modality individually), also used for aux, target """ if mock: dataset_name = "mock" + # Dtype + assert dtype in TORCH_DTYPES.keys() + self.dtype: str = TORCH_DTYPES[dtype] + # Modalities self.implemented_mod = implemented_mod self.modalities: dict = modalities - for mod in self.modalities.keys(): + + # Check modalities and set dtypes + for mod, configs in self.modalities.items(): if mod not in self.implemented_mod: raise ValueError(f"{mod} not in implemented modalities.") - # more precise dataset name (with modalities) + + if 'dtype' in configs: + m_dtype = configs.get('dtype', dtype) + self.modalities[mod]['dtype'] = m_dtype + print(f'Dtype of {mod} set to {m_dtype}') + + # More precise dataset name (with modalities) self.dataset_name: str = dataset_name + "_" + "_".join(modalities) # Set data attributes @@ -288,35 +309,74 @@ def pooch_setup(self) -> None: self.pooch_cli.load_registry(self.registry_path) @final - def load_npy(self, filepath: str) -> torch.Tensor: + def load_npy(self, filepath: str, dtype: np.dtype) -> np.ndarray: """Loads numpy array from file as a tensor.""" - im = np.load(filepath).transpose(2, 0, 1) - return torch.from_numpy(im).float() + arr = np.load(filepath).transpose(2, 0, 1) + if arr.dtype != np.dtype(dtype): + arr = arr.astype(dtype=dtype, copy=False) + + return arr + + @final + def load_tiff(self, tiff_file_path: str, dtype: np.dtype) -> np.ndarray: + """Load tiff file as np array of a specified dtype""" + + with rasterio.open(tiff_file_path) as f: + im = f.read() + assert isinstance(im, np.ndarray) + if im.dtype != np.dtype(dtype): + im = im.astype(dtype=dtype, copy=False) + return im @final def load_aef(self, filepath: str): """Loads AEF data from file as a tensor.""" - im = du.load_tiff(filepath, datatype="np") + # Modality settings size = self.modalities["aef"]["size"] - if im.shape[1] != size: + dtype = self.modalities["aef"].get("dtype") + dtype, is_bfloat16 = 
self.resolve_dtype(dtype) + + im = self.load_tiff(filepath, np.dtype(dtype)) + + if im.shape[-2:] != (size, size): im = center_crop_npy(im, (64, size, size)) - if np.isinf(im).any(): - im = np.clip(im, a_min=-0.5, a_max=0.5) - # TODO any normalisation needed - return torch.tensor(im).float() + # Clip all values to [-0.5, 0.5] in place (this also bounds +/-inf) + np.clip(im, -0.5, 0.5, out=im) + # TODO any other normalisation needed + + tensor = torch.from_numpy(im) + if is_bfloat16: + tensor = tensor.to(torch.bfloat16) + return tensor @final def load_tessera(self, filepath: str) -> torch.Tensor: """Loads.""" size = self.modalities["tessera"]["size"] - arr = self.load_npy(filepath) - if arr.size()[1] < size: - raise ValueError( - f"Requested tile size {size} is larger than actual available tile size {arr.size()[1]}" - ) - elif arr.size()[1] != size: + dtype = self.modalities["tessera"]["dtype"] + dtype, is_bfloat16 = self.resolve_dtype(dtype) + + arr = self.load_npy(filepath, np.dtype(dtype)) + + if arr.shape[-2:] != (size, size): arr = center_crop_npy(arr, (128, size, size)) + + # Nans are 0 across all 128 channels + # mask = np.all(arr == 0, axis=0) + # arr[mask] = torch.nan # TODO any normalisation needed - return arr + + tensor = torch.from_numpy(arr) + if is_bfloat16: + tensor = tensor.to(torch.bfloat16) + return tensor + + @staticmethod + def resolve_dtype(dtype_str: str): + """Resolve dtype from string into numpy dtype and return flag for mixed precision dtype in tensors""" + is_bfloat16 = dtype_str == "bfloat16" + np_dtype = np.float32 if is_bfloat16 else np.dtype(dtype_str) + + return np_dtype, is_bfloat16 diff --git a/src/data/butterfly_dataset.py b/src/data/butterfly_dataset.py index 642bf9c..5882791 100644 --- a/src/data/butterfly_dataset.py +++ b/src/data/butterfly_dataset.py @@ -5,9 +5,9 @@ import pooch import torch -import src.data_preprocessing.data_utils as du -from src.data.base_dataset import BaseDataset +from src.data.base_dataset import BaseDataset, TORCH_DTYPES 
from src.data_preprocessing.renaming_utils import rename_s2bms +from src.utils.data_utils import center_crop_npy from src.utils.errors import IllegalArgumentCombination @@ -122,24 +122,40 @@ def zscore_image(self, im: np.ndarray): return im def load_s2(self, filepath: str): - im = du.load_tiff(filepath, datatype="np") + """Loads s2 image tile from file as a tensor.""" - if self.modalities["s2"]["channels"] == "4c": - pass - elif self.modalities["s2"]["channels"] == "rgb": + # Modality settings + size = self.modalities["s2"]["size"] + np_dtype, is_bfloat16 = self.resolve_dtype(self.modalities["s2"]["dtype"]) + + im = self.load_tiff(filepath, dtype=np.dtype('uint16')) + if self.modalities["s2"].get("channels", '') == "4c": + c = 4 + elif self.modalities["s2"].get("channels", '') == "rgb": im = im[:3, :, :] + c = 3 else: raise IllegalArgumentCombination( - f"Channel specification {self.n_bands} is not implemented." + f"Channel specification {self.modalities['s2'].get('channels', 'null')} is not implemented." 
) - if self.modalities["s2"]["preprocessing"] == "zscored": + if self.modalities["s2"].get("preprocessing") == "zscored": im = im.astype(np.int32) im = self.zscore_image(im) else: im = np.clip(im, 0, 2000) im = im / 2000.0 - return torch.tensor(im).float() + + im = im.astype(dtype=np_dtype) + + # Crop + if im.shape[-2:] != (size, size): + im = center_crop_npy(im, (c, size, size)) + + tensor = torch.from_numpy(im) + if is_bfloat16: + tensor = tensor.to(torch.bfloat16) + return tensor @override def __getitem__(self, idx: int) -> Dict[str, Any]: @@ -149,7 +165,7 @@ def __getitem__(self, idx: int) -> Dict[str, Any]: for modality in self.modalities: if modality in ["coords"]: - formatted_row["eo"][modality] = torch.tensor([row["lat"], row["lon"]]) + formatted_row["eo"][modality] = torch.tensor([row["lat"], row["lon"]], dtype=TORCH_DTYPES[self.modalities[modality]['dtype']]) elif modality == "s2": formatted_row["eo"][modality] = self.load_s2(row["s2_path"]) # TODO: augmentations @@ -160,7 +176,7 @@ def __getitem__(self, idx: int) -> Dict[str, Any]: if self.use_target_data: formatted_row["target"] = torch.tensor( - [row[k] for k in self.target_names], dtype=torch.float32 + [row[k] for k in self.target_names], dtype=self.dtype ) if self.use_aux_data: @@ -168,7 +184,7 @@ def __getitem__(self, idx: int) -> Dict[str, Any]: for aux_cat, vals in self.use_aux_data.items(): if aux_cat == "aux": formatted_row["aux"][aux_cat] = torch.tensor( - [row[v] for v in vals], dtype=torch.float32 + [row[v] for v in vals], dtype=self.dtype ) else: formatted_row["aux"][aux_cat] = [row[v] for v in vals] From db816c472ffde87c968be52a3970711aaa6fbc9f Mon Sep 17 00:00:00 2001 From: GabrieleTi Date: Wed, 13 May 2026 09:46:51 +0200 Subject: [PATCH 2/6] Add dtype property for prediction heads --- .../components/pred_heads/base_pred_head.py | 18 ++++++++++++++++++ 1 file changed, 18 insertions(+) diff --git a/src/models/components/pred_heads/base_pred_head.py 
b/src/models/components/pred_heads/base_pred_head.py index b3acd0c..02e2af6 100644 --- a/src/models/components/pred_heads/base_pred_head.py +++ b/src/models/components/pred_heads/base_pred_head.py @@ -59,3 +59,21 @@ def setup(self) -> List[str]: def _setup(self) -> None: """Configures specific prediction head.""" pass + + @property + def device(self) -> torch.device | None: + devices = {p.device for p in self.parameters()} + if len(devices) > 1: + raise RuntimeError("Prediction head is on multiple devices") + elif len(devices) == 0: + return None + return devices.pop() + + @property + def dtype(self) -> torch.dtype | None: + dtypes = {p.dtype for p in self.parameters()} + if len(dtypes) > 1: + raise RuntimeError("Prediction head has multiple dtypes") + elif len(dtypes) == 0: + return None + return dtypes.pop() \ No newline at end of file From 6b975f6f3469ade88d78769eb04aff3b2a8c2e59 Mon Sep 17 00:00:00 2001 From: GabrieleTi Date: Wed, 13 May 2026 09:48:03 +0200 Subject: [PATCH 3/6] Add dtype coherence check for prediction head and encoder --- src/models/predictive_model.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/src/models/predictive_model.py b/src/models/predictive_model.py index 9d8ab9e..83b1d14 100644 --- a/src/models/predictive_model.py +++ b/src/models/predictive_model.py @@ -121,6 +121,9 @@ def _setup_encoders_adapters(self): input_dim=self.geo_encoder.output_dim, output_dim=self.num_classes ) self.prediction_head.setup() + if self.prediction_head.dtype != self.geo_encoder.dtype: + self.prediction_head = self.prediction_head.to(dtype=self.geo_encoder.dtype) + if "prediction_head" not in self.trainable_modules: self.trainable_modules.append("prediction_head") From 4abd5489b35a9bf6f3513916edb001bad7c99854 Mon Sep 17 00:00:00 2001 From: GabrieleTi Date: Wed, 13 May 2026 09:48:26 +0200 Subject: [PATCH 4/6] Set normalisation dtype as geo_encoder --- src/models/predictive_model.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git 
a/src/models/predictive_model.py b/src/models/predictive_model.py index 83b1d14..9e3c28a 100644 --- a/src/models/predictive_model.py +++ b/src/models/predictive_model.py @@ -112,7 +112,7 @@ def _setup_encoders_adapters(self): self.trainable_modules.extend(new_modules) if self.normalize_features: - self.normalizer = nn.LayerNorm(self.geo_encoder.output_dim) + self.normalizer = nn.LayerNorm(self.geo_encoder.output_dim, dtype=self.geo_encoder.dtype) self.trainable_modules.append("normalizer") print("Model set up to normalise geo_encoder features.") From 8c06519eb986a27450f00b880a9680d8e0a7f9cb Mon Sep 17 00:00:00 2001 From: GabrieleTi Date: Wed, 13 May 2026 10:02:13 +0200 Subject: [PATCH 5/6] Fix dtype for when no modality parameters are provided --- src/data/base_dataset.py | 55 ++++++++++++++++++++++------------------ 1 file changed, 30 insertions(+), 25 deletions(-) diff --git a/src/data/base_dataset.py b/src/data/base_dataset.py index 6aebc3c..56ae6f5 100644 --- a/src/data/base_dataset.py +++ b/src/data/base_dataset.py @@ -5,36 +5,37 @@ import numpy as np import pandas as pd +import rasterio import torch from torch.utils.data import Dataset from src.utils.data_utils import center_crop_npy TORCH_DTYPES = { - 'float32': torch.float32, - 'float64': torch.float64, - 'int32': torch.int32, - 'int64': torch.int64, - 'bfloat16': torch.bfloat16, + "float32": torch.float32, + "float64": torch.float64, + "int32": torch.int32, + "int64": torch.int64, + "bfloat16": torch.bfloat16, } class BaseDataset(Dataset, ABC): def __init__( - self, - data_dir: str, - modalities: dict, - use_target_data: bool = True, - use_aux_data: Dict[str, List[str] | str] | str | None = None, - dataset_name: str = "BaseDataset", - seed: int = 12345, - mode: str = "train", - cache_dir: str = None, - implemented_mod: set[str] = None, - mock: bool = False, - use_features: bool = True, - csv_name: str = None, - dtype: str = "float32", + self, + data_dir: str, + modalities: dict, + use_target_data: bool = 
True, + use_aux_data: Dict[str, List[str] | str] | str | None = None, + dataset_name: str = "BaseDataset", + seed: int = 12345, + mode: str = "train", + cache_dir: str = None, + implemented_mod: set[str] = None, + mock: bool = False, + use_features: bool = True, + csv_name: str = None, + dtype: str = "float32", ) -> None: """Interface for any use case dataset. @@ -77,10 +78,13 @@ def __init__( if mod not in self.implemented_mod: raise ValueError(f"{mod} not in implemented modalities.") - if 'dtype' in configs: - m_dtype = configs.get('dtype', dtype) - self.modalities[mod]['dtype'] = m_dtype - print(f'Dtype of {mod} set to {m_dtype}') + if configs is not None: + m_dtype = configs.get("dtype", dtype) + self.modalities[mod]["dtype"] = m_dtype + print(f"Dtype of {mod} set to {m_dtype}") + else: + m_dtype = dtype + self.modalities[mod] = {"dtype": m_dtype} # More precise dataset name (with modalities) self.dataset_name: str = dataset_name + "_" + "_".join(modalities) @@ -319,7 +323,7 @@ def load_npy(self, filepath: str, dtype: np.dtype) -> np.ndarray: @final def load_tiff(self, tiff_file_path: str, dtype: np.dtype) -> np.ndarray: - """Load tiff file as np array of a specified dtype""" + """Load tiff file as np array of a specified dtype.""" with rasterio.open(tiff_file_path) as f: im = f.read() @@ -375,7 +379,8 @@ def load_tessera(self, filepath: str) -> torch.Tensor: @staticmethod def resolve_dtype(dtype_str: str): - """Resolve dtype from string into numpy dtype and return flag for mixed precision dtype in tensors""" + """Resolve dtype from string into numpy dtype and return flag for mixed precision dtype in + tensors.""" is_bfloat16 = dtype_str == "bfloat16" np_dtype = np.float32 if is_bfloat16 else np.dtype(dtype_str) From 0289645d4a68c9a9e0b152de397600fecb14e7d0 Mon Sep 17 00:00:00 2001 From: GabrieleTi Date: Wed, 13 May 2026 11:18:37 +0200 Subject: [PATCH 6/6] Add dtype forwarding from butterfly_dataset to base_dataset --- src/data/butterfly_dataset.py | 4 ++++ 1 
file changed, 4 insertions(+) diff --git a/src/data/butterfly_dataset.py b/src/data/butterfly_dataset.py index 5882791..f33fcd4 100644 --- a/src/data/butterfly_dataset.py +++ b/src/data/butterfly_dataset.py @@ -21,6 +21,7 @@ def __init__( seed: int = 12345, cache_dir: str = None, mock: bool = False, + dtype: str = "float32", ) -> None: """A dataset implementation for the Butterfly diversity use case. @@ -32,6 +33,8 @@ def __init__( :param seed: random seed :param cache_dir: path to cache dir :param mock: whether to mock csv file + :param dtype: global dtype (used if not specified for each modality individually), also used for aux, target + """ super().__init__( @@ -44,6 +47,7 @@ def __init__( cache_dir=cache_dir, implemented_mod={"s2", "tessera", "coords", "aef"}, mock=mock, + dtype=dtype ) def setup(self):