Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions src/dolphin/io/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from ._background import *
from ._blocks import *
from ._core import *
from ._process import *
from ._readers import *
from ._writers import *
59 changes: 59 additions & 0 deletions src/dolphin/io/_process.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,59 @@
from concurrent.futures import Future, ThreadPoolExecutor
from typing import Protocol, Sequence

from numpy.typing import ArrayLike
from tqdm.auto import tqdm

from dolphin.utils import DummyProcessPoolExecutor

from ._blocks import iter_blocks
from ._readers import StackReader
from ._writers import DatasetWriter

__all__ = ["BlockProcessor", "process_blocks"]


class BlockProcessor(Protocol):
    """Structural type for a callable that processes one block of a stack.

    An implementation is called with the input readers plus the row and
    column slices delimiting a block. It returns a 3-tuple: the array-like
    result for the block, followed by a row slice and a column slice.
    """

    def __call__(
        self, readers: Sequence[StackReader], rows: slice, cols: slice
    ) -> tuple[ArrayLike, slice, slice]: ...


def process_blocks(
    readers: Sequence[StackReader],
    writer: DatasetWriter,
    func: BlockProcessor,
    block_shape: tuple[int, int] = (512, 512),
    num_threads: int = 5,
):
    """Perform block-wise processing over blocks in `readers`, writing to `writer`.

    The last two dimensions of `readers[0].shape` are split into blocks of
    `block_shape`. `func` is submitted to an executor once per block, and each
    result is assigned to `writer[..., rows, cols]` as soon as it is ready.

    Note that the parallelism happens using a `ThreadPoolExecutor`, so `func` should
    be a function which releases the GIL during computation (e.g. using numpy).

    Parameters
    ----------
    readers : Sequence[StackReader]
        Input readers; the first one determines the 2D shape that is blocked.
    writer : DatasetWriter
        Destination that receives each processed block via item assignment.
    func : BlockProcessor
        Callable invoked as `func(readers=readers, rows=rows, cols=cols)` which
        returns `(data, rows, cols)`.
    block_shape : tuple[int, int]
        Shape of each block, passed to `iter_blocks`.
    num_threads : int
        Number of workers. If not greater than 1, `DummyProcessPoolExecutor`
        is used instead of `ThreadPoolExecutor`.

    Raises
    ------
    Exception
        The first exception raised by `func` or by writing a block to `writer`.
        Once a failure is recorded, no further blocks are submitted.
    """
    shape = readers[0].shape[-2:]
    slices = list(iter_blocks(shape, block_shape=block_shape))

    # `concurrent.futures` logs and then ignores exceptions raised inside a
    # done-callback, so failures are collected here and re-raised once the
    # executor has shut down; otherwise they would be silently dropped.
    errors: list[Exception] = []

    Executor = ThreadPoolExecutor if num_threads > 1 else DummyProcessPoolExecutor
    # Use the progress bar as a context manager so it is always closed.
    with tqdm(total=len(slices)) as pbar:

        # Define the callback to write the result to the output DatasetWriter
        def write_callback(fut: Future):
            try:
                data, rows, cols = fut.result()
                writer[..., rows, cols] = data
            except Exception as e:
                errors.append(e)
            else:
                pbar.update()

        with Executor(num_threads) as exc:
            for rows, cols in slices:
                if errors:
                    # A block already failed: stop queueing more work.
                    break
                future = exc.submit(func, readers=readers, rows=rows, cols=cols)
                future.add_done_callback(write_callback)

    if errors:
        raise errors[0]
20 changes: 13 additions & 7 deletions src/dolphin/io/_writers.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@

import numpy as np
import rasterio
import rasterio.errors
from numpy.typing import ArrayLike, DTypeLike
from rasterio.windows import Window

Expand Down Expand Up @@ -304,13 +305,18 @@ def __setitem__(self, key: tuple[Index, ...], value: np.ndarray, /) -> None:
elif len(key) == 3:
_, rows, cols = _unpack_3d_slices(key)
else:
raise ValueError(f"Invalid key: {key!r}")
window = Window.from_slices(
rows,
cols,
height=dataset.height,
width=dataset.width,
)
raise ValueError(
f"Invalid key for {self.__class__!r}.__setitem__: {key!r}"
)
try:
window = Window.from_slices(
rows,
cols,
height=dataset.height,
width=dataset.width,
)
except rasterio.errors.WindowError as e:
raise ValueError(f"Error creating window: {key = }, {value = }") from e
return dataset.write(value, self.band, window=window)


Expand Down
Loading