Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
103 changes: 84 additions & 19 deletions src/data/base_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,12 +5,20 @@

import numpy as np
import pandas as pd
import rasterio
import torch
from torch.utils.data import Dataset

import src.data_preprocessing.data_utils as du
from src.utils.data_utils import center_crop_npy

TORCH_DTYPES = {
"float32": torch.float32,
"float64": torch.float64,
"int32": torch.int32,
"int64": torch.int64,
"bfloat16": torch.bfloat16,
}


class BaseDataset(Dataset, ABC):
def __init__(
Expand All @@ -27,6 +35,7 @@ def __init__(
mock: bool = False,
use_features: bool = True,
csv_name: str = None,
dtype: str = "float32",
) -> None:
"""Interface for any use case dataset.

Expand All @@ -50,18 +59,34 @@ def __init__(
:param implemented_mod: implemented modalities for each dataset
:param mock: whether to mock csv file
:param use_features: if tabular feat_* columns should be included. Default True.
:param dtype: global dtype (used if not specified for each modality individually), also used for aux, target
"""

if mock:
dataset_name = "mock"

# Dtype
assert dtype in TORCH_DTYPES.keys()
self.dtype: str = TORCH_DTYPES[dtype]

# Modalities
self.implemented_mod = implemented_mod
self.modalities: dict = modalities
for mod in self.modalities.keys():

# Check modalities and set dtypes
for mod, configs in self.modalities.items():
if mod not in self.implemented_mod:
raise ValueError(f"{mod} not in implemented modalities.")
# more precise dataset name (with modalities)

if configs is not None:
m_dtype = configs.get("dtype", dtype)
self.modalities[mod]["dtype"] = m_dtype
print(f"Dtype of {mod} set to {m_dtype}")
else:
m_dtype = dtype
self.modalities[mod] = {"dtype": m_dtype}

# More precise dataset name (with modalities)
self.dataset_name: str = dataset_name + "_" + "_".join(modalities)

# Set data attributes
Expand Down Expand Up @@ -288,35 +313,75 @@ def pooch_setup(self) -> None:
self.pooch_cli.load_registry(self.registry_path)

@final
def load_npy(self, filepath: str) -> torch.Tensor:
def load_npy(self, filepath: str, dtype: np.dtype) -> np.ndarray:
"""Loads numpy array from file as a tensor."""
im = np.load(filepath).transpose(2, 0, 1)
return torch.from_numpy(im).float()
arr = np.load(filepath).transpose(2, 0, 1)
if arr.dtype != np.dtype(dtype):
arr = arr.astype(dtype=dtype, copy=False)

return arr

@final
def load_tiff(self, tiff_file_path: str, dtype: np.dtype) -> np.ndarray:
"""Load tiff file as np array of a specified dtype."""

with rasterio.open(tiff_file_path) as f:
im = f.read()
assert isinstance(im, np.ndarray)
if im.dtype != np.dtype(dtype):
im = im.astype(dtype=dtype, copy=False)
return im

@final
def load_aef(self, filepath: str):
"""Loads AEF data from file as a tensor."""

im = du.load_tiff(filepath, datatype="np")
# Modality settings
size = self.modalities["aef"]["size"]
if im.shape[1] != size:
dtype = self.modalities["aef"].get("dtype")
dtype, is_bfloat16 = self.resolve_dtype(dtype)

im = self.load_tiff(filepath, np.dtype(dtype))

if im.shape[-2:] != (size, size):
im = center_crop_npy(im, (64, size, size))
if np.isinf(im).any():
im = np.clip(im, a_min=-0.5, a_max=0.5)
# TODO any normalisation needed

return torch.tensor(im).float()
# Scan for inf values and clip them (in memory)
np.clip(im, -0.5, 0.5, out=im)
# TODO any other normalisation needed

tensor = torch.from_numpy(im)
if is_bfloat16:
tensor = tensor.to(torch.bfloat16)
return tensor

@final
def load_tessera(self, filepath: str) -> torch.Tensor:
"""Loads."""
size = self.modalities["tessera"]["size"]
arr = self.load_npy(filepath)
if arr.size()[1] < size:
raise ValueError(
f"Requested tile size {size} is larger than actual available tile size {arr.size()[1]}"
)
elif arr.size()[1] != size:
dtype = self.modalities["tessera"]["dtype"]
dtype, is_bfloat16 = self.resolve_dtype(dtype)

arr = self.load_npy(filepath, np.dtype(dtype))

if arr.shape[-2:] != (size, size):
arr = center_crop_npy(arr, (128, size, size))

# Nans are 0 across all 128 channels
# mask = np.all(arr == 0, axis=0)
# arr[mask] = torch.nan
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

A bit of a design choice I guess. The NaNs will cause errors, the zeros will flow through silently and affect results. Is 0 a valid value for a tessera (or other EO embedding) dimension?

# TODO any normalisation needed
return arr

tensor = torch.from_numpy(arr)
if is_bfloat16:
tensor = tensor.to(torch.bfloat16)
return tensor

@staticmethod
def resolve_dtype(dtype_str: str):
"""Resolve dtype from string into numpy dtype and return flag for mixed precision dtype in
tensors."""
is_bfloat16 = dtype_str == "bfloat16"
np_dtype = np.float32 if is_bfloat16 else np.dtype(dtype_str)

return np_dtype, is_bfloat16
44 changes: 32 additions & 12 deletions src/data/butterfly_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,9 @@
import pooch
import torch

import src.data_preprocessing.data_utils as du
from src.data.base_dataset import BaseDataset
from src.data.base_dataset import BaseDataset, TORCH_DTYPES
from src.data_preprocessing.renaming_utils import rename_s2bms
from src.utils.data_utils import center_crop_npy
from src.utils.errors import IllegalArgumentCombination


Expand All @@ -21,6 +21,7 @@ def __init__(
seed: int = 12345,
cache_dir: str = None,
mock: bool = False,
dtype: str = "float32",
) -> None:
"""A dataset implementation for the Butterfly diversity use case.

Expand All @@ -32,6 +33,8 @@ def __init__(
:param seed: random seed
:param cache_dir: path to cache dir
:param mock: whether to mock csv file
:param dtype: global dtype (used if not specified for each modality individually), also used for aux, target

"""

super().__init__(
Expand All @@ -44,6 +47,7 @@ def __init__(
cache_dir=cache_dir,
implemented_mod={"s2", "tessera", "coords", "aef"},
mock=mock,
dtype=dtype
)

def setup(self):
Expand Down Expand Up @@ -122,24 +126,40 @@ def zscore_image(self, im: np.ndarray):
return im

def load_s2(self, filepath: str):
im = du.load_tiff(filepath, datatype="np")
"""Loads s2 image tile from file as a tensor."""

# Modality settings
size = self.modalities["s2"]["size"]
np_dtype, is_bfloat16 = self.resolve_dtype(self.modalities["s2"]["dtype"])

if self.modalities["s2"]["channels"] == "4c":
pass
elif self.modalities["s2"]["channels"] == "rgb":
im = self.load_tiff(filepath, dtype=np.dtype('uint16'))
if self.modalities["s2"].get("channels", '') == "4c":
c = 4
elif self.modalities["s2"].egt("channels", '') == "rgb":
im = im[:3, :, :]
c = 3
else:
raise IllegalArgumentCombination(
f"Channel specification {self.n_bands} is not implemented."
f"Channel specification {self.modalities["s2"].get("channels", 'null')} is not implemented."
)

if self.modalities["s2"]["preprocessing"] == "zscored":
if self.modalities["s2"].get("preprocessing") == "zscored":
im = im.astype(np.int32)
im = self.zscore_image(im)
else:
im = np.clip(im, 0, 2000)
im = im / 2000.0
return torch.tensor(im).float()

im = im.astype(dtype=np_dtype)

# Crop
if im.shape[-2:] != (size, size):
im = center_crop_npy(im, (c, size, size))

tensor = torch.from_numpy(im)
if is_bfloat16:
tensor = tensor.to(torch.bfloat16)
return tensor

@override
def __getitem__(self, idx: int) -> Dict[str, Any]:
Expand All @@ -149,7 +169,7 @@ def __getitem__(self, idx: int) -> Dict[str, Any]:

for modality in self.modalities:
if modality in ["coords"]:
formatted_row["eo"][modality] = torch.tensor([row["lat"], row["lon"]])
formatted_row["eo"][modality] = torch.tensor([row["lat"], row["lon"]], dtype=TORCH_DTYPES[self.modalities[modality]['dtype']])
elif modality == "s2":
formatted_row["eo"][modality] = self.load_s2(row["s2_path"])
# TODO: augmentations
Expand All @@ -160,15 +180,15 @@ def __getitem__(self, idx: int) -> Dict[str, Any]:

if self.use_target_data:
formatted_row["target"] = torch.tensor(
[row[k] for k in self.target_names], dtype=torch.float32
[row[k] for k in self.target_names], dtype=self.dtype
)

if self.use_aux_data:
formatted_row["aux"] = {}
for aux_cat, vals in self.use_aux_data.items():
if aux_cat == "aux":
formatted_row["aux"][aux_cat] = torch.tensor(
[row[v] for v in vals], dtype=torch.float32
[row[v] for v in vals], dtype=self.dtype
)
else:
formatted_row["aux"][aux_cat] = [row[v] for v in vals]
Expand Down
18 changes: 18 additions & 0 deletions src/models/components/pred_heads/base_pred_head.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,3 +59,21 @@ def setup(self) -> List[str]:
def _setup(self) -> None:
"""Configures specific prediction head."""
pass

@property
def device(self) -> torch.device | None:
devices = {p.device for p in self.parameters()}
if len(devices) > 1:
raise RuntimeError("Prediction head is on multiple devices")
elif len(devices) == 0:
return None
return devices.pop()

@property
def dtype(self) -> torch.dtype | None:
dtypes = {p.dtype for p in self.parameters()}
if len(dtypes) > 1:
raise RuntimeError("Prediction head has multiple dtypes")
elif len(dtypes) == 0:
return None
return dtypes.pop()
5 changes: 4 additions & 1 deletion src/models/predictive_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -112,7 +112,7 @@ def _setup_encoders_adapters(self):
self.trainable_modules.extend(new_modules)

if self.normalize_features:
self.normalizer = nn.LayerNorm(self.geo_encoder.output_dim)
self.normalizer = nn.LayerNorm(self.geo_encoder.output_dim, dtype=self.geo_encoder.dtype)
self.trainable_modules.append("normalizer")
print("Model set up to normalise geo_encoder features.")

Expand All @@ -121,6 +121,9 @@ def _setup_encoders_adapters(self):
input_dim=self.geo_encoder.output_dim, output_dim=self.num_classes
)
self.prediction_head.setup()
if self.prediction_head.dtype != self.geo_encoder.dtype:
self.prediction_head = self.prediction_head.to(dtype=self.geo_encoder.dtype)

if "prediction_head" not in self.trainable_modules:
self.trainable_modules.append("prediction_head")

Expand Down
Loading