diff --git a/src/data/base_dataset.py b/src/data/base_dataset.py index 66349db..56ae6f5 100644 --- a/src/data/base_dataset.py +++ b/src/data/base_dataset.py @@ -5,12 +5,20 @@ import numpy as np import pandas as pd +import rasterio import torch from torch.utils.data import Dataset -import src.data_preprocessing.data_utils as du from src.utils.data_utils import center_crop_npy +TORCH_DTYPES = { + "float32": torch.float32, + "float64": torch.float64, + "int32": torch.int32, + "int64": torch.int64, + "bfloat16": torch.bfloat16, +} + class BaseDataset(Dataset, ABC): def __init__( @@ -27,6 +35,7 @@ def __init__( mock: bool = False, use_features: bool = True, csv_name: str = None, + dtype: str = "float32", ) -> None: """Interface for any use case dataset. @@ -50,18 +59,35 @@ def __init__( :param implemented_mod: implemented modalities for each dataset :param mock: whether to mock csv file :param use_features: if tabular feat_* columns should be included. Default True. + :param dtype: global dtype (used if not specified for each modality individually), also used for aux, target """ if mock: dataset_name = "mock" + # Dtype + if dtype not in TORCH_DTYPES: + raise ValueError(f"Unsupported dtype {dtype!r}; expected one of {sorted(TORCH_DTYPES)}.") + self.dtype: torch.dtype = TORCH_DTYPES[dtype] + # Modalities self.implemented_mod = implemented_mod self.modalities: dict = modalities - for mod in self.modalities.keys(): + + # Check modalities and set dtypes + for mod, configs in self.modalities.items(): if mod not in self.implemented_mod: raise ValueError(f"{mod} not in implemented modalities.") - # more precise dataset name (with modalities) + + if configs is not None: + m_dtype = configs.get("dtype", dtype) + self.modalities[mod]["dtype"] = m_dtype + print(f"Dtype of {mod} set to {m_dtype}") + else: + m_dtype = dtype + self.modalities[mod] = {"dtype": m_dtype} + + # More precise dataset name (with modalities) self.dataset_name: str = dataset_name + "_" + "_".join(modalities) # Set data attributes @@ -288,35 +314,75 @@ def pooch_setup(self) -> None:
self.pooch_cli.load_registry(self.registry_path) @final - def load_npy(self, filepath: str) -> torch.Tensor: - """Loads numpy array from file as a tensor.""" - im = np.load(filepath).transpose(2, 0, 1) - return torch.from_numpy(im).float() + def load_npy(self, filepath: str, dtype: np.dtype) -> np.ndarray: + """Loads a numpy array from file, moves its last axis first and casts it to dtype.""" + arr = np.load(filepath).transpose(2, 0, 1) + if arr.dtype != np.dtype(dtype): + arr = arr.astype(dtype=dtype, copy=False) + + return arr + + @final + def load_tiff(self, tiff_file_path: str, dtype: np.dtype) -> np.ndarray: + """Load tiff file as np array of a specified dtype.""" + + with rasterio.open(tiff_file_path) as f: + im = f.read() + assert isinstance(im, np.ndarray) + if im.dtype != np.dtype(dtype): + im = im.astype(dtype=dtype, copy=False) + return im @final def load_aef(self, filepath: str): """Loads AEF data from file as a tensor.""" - im = du.load_tiff(filepath, datatype="np") + # Modality settings size = self.modalities["aef"]["size"] - if im.shape[1] != size: + dtype = self.modalities["aef"]["dtype"] + dtype, is_bfloat16 = self.resolve_dtype(dtype) + + im = self.load_tiff(filepath, np.dtype(dtype)) + + if im.shape[-2:] != (size, size): im = center_crop_npy(im, (64, size, size)) - if np.isinf(im).any(): - im = np.clip(im, a_min=-0.5, a_max=0.5) - # TODO any normalisation needed - return torch.tensor(im).float() + # Clip every value in place to [-0.5, 0.5]; this also bounds infinite values + np.clip(im, -0.5, 0.5, out=im) + # TODO any other normalisation needed + + tensor = torch.from_numpy(im) + if is_bfloat16: + tensor = tensor.to(torch.bfloat16) + return tensor @final def load_tessera(self, filepath: str) -> torch.Tensor: """Loads.""" size = self.modalities["tessera"]["size"] - arr = self.load_npy(filepath) - if arr.size()[1] < size: - raise ValueError( - f"Requested tile size {size} is larger than actual available tile size {arr.size()[1]}" - ) - elif arr.size()[1] != size: + dtype = self.modalities["tessera"]["dtype"] + dtype, is_bfloat16 = self.resolve_dtype(dtype) +
arr = self.load_npy(filepath, np.dtype(dtype)) + if arr.shape[-2:] != (size, size): arr = center_crop_npy(arr, (128, size, size)) + + # Nans are 0 across all 128 channels + # mask = np.all(arr == 0, axis=0) + # arr[mask] = torch.nan # TODO any normalisation needed - return arr + + tensor = torch.from_numpy(arr) + if is_bfloat16: + tensor = tensor.to(torch.bfloat16) + return tensor + + @staticmethod + def resolve_dtype(dtype_str: str) -> tuple[np.dtype, bool]: + """Resolve dtype from string into numpy dtype and return flag for mixed precision dtype in + tensors.""" + is_bfloat16 = dtype_str == "bfloat16" + np_dtype = np.dtype(np.float32 if is_bfloat16 else dtype_str) + + return np_dtype, is_bfloat16 diff --git a/src/data/butterfly_dataset.py b/src/data/butterfly_dataset.py index 642bf9c..f33fcd4 100644 --- a/src/data/butterfly_dataset.py +++ b/src/data/butterfly_dataset.py @@ -5,9 +5,9 @@ import pooch import torch -import src.data_preprocessing.data_utils as du -from src.data.base_dataset import BaseDataset +from src.data.base_dataset import BaseDataset, TORCH_DTYPES from src.data_preprocessing.renaming_utils import rename_s2bms +from src.utils.data_utils import center_crop_npy from src.utils.errors import IllegalArgumentCombination @@ -21,6 +21,7 @@ def __init__( seed: int = 12345, cache_dir: str = None, mock: bool = False, + dtype: str = "float32", ) -> None: """A dataset implementation for the Butterfly diversity use case.
@@ -32,6 +33,8 @@ def __init__( :param seed: random seed :param cache_dir: path to cache dir :param mock: whether to mock csv file + :param dtype: global dtype (used if not specified for each modality individually), also used for aux, target + """ super().__init__( @@ -44,6 +47,7 @@ def __init__( cache_dir=cache_dir, implemented_mod={"s2", "tessera", "coords", "aef"}, mock=mock, + dtype=dtype, ) def setup(self): @@ -122,24 +126,40 @@ def zscore_image(self, im: np.ndarray): return im def load_s2(self, filepath: str): - im = du.load_tiff(filepath, datatype="np") + """Loads s2 image tile from file as a tensor.""" + + # Modality settings + size = self.modalities["s2"]["size"] + np_dtype, is_bfloat16 = self.resolve_dtype(self.modalities["s2"]["dtype"]) - if self.modalities["s2"]["channels"] == "4c": - pass - elif self.modalities["s2"]["channels"] == "rgb": + im = self.load_tiff(filepath, dtype=np.dtype('uint16')) + if self.modalities["s2"].get("channels", '') == "4c": + c = 4 + elif self.modalities["s2"].get("channels", '') == "rgb": im = im[:3, :, :] + c = 3 else: raise IllegalArgumentCombination( - f"Channel specification {self.n_bands} is not implemented." + f"Channel specification {self.modalities['s2'].get('channels', 'null')} is not implemented."
) - if self.modalities["s2"]["preprocessing"] == "zscored": + if self.modalities["s2"].get("preprocessing") == "zscored": im = im.astype(np.int32) im = self.zscore_image(im) else: im = np.clip(im, 0, 2000) im = im / 2000.0 - return torch.tensor(im).float() + + im = im.astype(dtype=np_dtype) + + # Crop + if im.shape[-2:] != (size, size): + im = center_crop_npy(im, (c, size, size)) + + tensor = torch.from_numpy(im) + if is_bfloat16: + tensor = tensor.to(torch.bfloat16) + return tensor @override def __getitem__(self, idx: int) -> Dict[str, Any]: @@ -149,7 +169,7 @@ def __getitem__(self, idx: int) -> Dict[str, Any]: for modality in self.modalities: if modality in ["coords"]: - formatted_row["eo"][modality] = torch.tensor([row["lat"], row["lon"]]) + formatted_row["eo"][modality] = torch.tensor([row["lat"], row["lon"]], dtype=TORCH_DTYPES[self.modalities[modality]['dtype']]) elif modality == "s2": formatted_row["eo"][modality] = self.load_s2(row["s2_path"]) # TODO: augmentations @@ -160,7 +180,7 @@ def __getitem__(self, idx: int) -> Dict[str, Any]: if self.use_target_data: formatted_row["target"] = torch.tensor( - [row[k] for k in self.target_names], dtype=torch.float32 + [row[k] for k in self.target_names], dtype=self.dtype ) if self.use_aux_data: @@ -168,7 +188,7 @@ def __getitem__(self, idx: int) -> Dict[str, Any]: for aux_cat, vals in self.use_aux_data.items(): if aux_cat == "aux": formatted_row["aux"][aux_cat] = torch.tensor( - [row[v] for v in vals], dtype=torch.float32 + [row[v] for v in vals], dtype=self.dtype ) else: formatted_row["aux"][aux_cat] = [row[v] for v in vals] diff --git a/src/models/components/pred_heads/base_pred_head.py b/src/models/components/pred_heads/base_pred_head.py index b3acd0c..02e2af6 100644 --- a/src/models/components/pred_heads/base_pred_head.py +++ b/src/models/components/pred_heads/base_pred_head.py @@ -59,3 +59,21 @@ def setup(self) -> List[str]: def _setup(self) -> None: """Configures specific prediction head.""" pass + + @property 
+ def device(self) -> torch.device | None: + devices = {p.device for p in self.parameters()} + if len(devices) > 1: + raise RuntimeError("Prediction head is on multiple devices") + elif len(devices) == 0: + return None + return devices.pop() + + @property + def dtype(self) -> torch.dtype | None: + dtypes = {p.dtype for p in self.parameters()} + if len(dtypes) > 1: + raise RuntimeError("Prediction head has multiple dtypes") + elif len(dtypes) == 0: + return None + return dtypes.pop() diff --git a/src/models/predictive_model.py b/src/models/predictive_model.py index 9d8ab9e..9e3c28a 100644 --- a/src/models/predictive_model.py +++ b/src/models/predictive_model.py @@ -112,7 +112,7 @@ def _setup_encoders_adapters(self): self.trainable_modules.extend(new_modules) if self.normalize_features: - self.normalizer = nn.LayerNorm(self.geo_encoder.output_dim) + self.normalizer = nn.LayerNorm(self.geo_encoder.output_dim, dtype=self.geo_encoder.dtype) self.trainable_modules.append("normalizer") print("Model set up to normalise geo_encoder features.") @@ -121,6 +121,9 @@ def _setup_encoders_adapters(self): input_dim=self.geo_encoder.output_dim, output_dim=self.num_classes ) self.prediction_head.setup() + if self.prediction_head.dtype != self.geo_encoder.dtype: + self.prediction_head = self.prediction_head.to(dtype=self.geo_encoder.dtype) + if "prediction_head" not in self.trainable_modules: self.trainable_modules.append("prediction_head")