diff --git a/README.md b/README.md index 46dbcd14..3d9ed650 100644 --- a/README.md +++ b/README.md @@ -169,6 +169,7 @@ open_geotiff('dem.tif', dtype='float32') # half memory open_geotiff('dem.tif', dtype='float32', chunks=512) # Dask + half memory to_geotiff(data, 'out.tif', compression_level=1) # fast scratch write to_geotiff(data, 'out.tif', compression_level=22) # max compression +to_geotiff(dask_da, 'out.tif') # stream Dask to single TIFF to_geotiff(dask_da, 'mosaic.vrt') # stream Dask to VRT # Accessor methods diff --git a/examples/user_guide/47_Streaming_GeoTIFF_Write.ipynb b/examples/user_guide/47_Streaming_GeoTIFF_Write.ipynb new file mode 100644 index 00000000..fd00ea97 --- /dev/null +++ b/examples/user_guide/47_Streaming_GeoTIFF_Write.ipynb @@ -0,0 +1,257 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Streaming GeoTIFF write from dask arrays\n", + "\n", + "When you call `to_geotiff()` on a dask-backed DataArray, the data is written one tile-row at a time. Only one tile-row lives in memory at once, so you can write rasters larger than RAM without switching to VRT output.\n", + "\n", + "This notebook shows the three write modes for dask data:\n", + "1. **Streaming to a single TIFF** (automatic when the input is dask-backed)\n", + "2. **Streaming to a VRT** (one file per chunk, stitched by an XML index)\n", + "3. 
**Eager write** (materialise first, then write; needed for COG with overviews)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "%matplotlib inline\n", + "import tempfile\n", + "import os\n", + "\n", + "import numpy as np\n", + "import xarray as xr\n", + "import dask.array as da\n", + "import matplotlib.pyplot as plt\n", + "\n", + "from xrspatial.geotiff import open_geotiff, to_geotiff" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Build a dask-backed raster\n", + "\n", + "A 2000x2000 terrain surface chunked into 500x500 blocks. Four chunks along each axis, sixteen chunks total." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "rng = np.random.default_rng(1084)\n", + "H, W = 2000, 2000\n", + "\n", + "yy, xx = np.meshgrid(\n", + " np.linspace(0, 6 * np.pi, H),\n", + " np.linspace(0, 6 * np.pi, W),\n", + " indexing='ij',\n", + ")\n", + "terrain = (500 + 200 * np.sin(yy) * np.cos(xx * 0.7)\n", + " + 30 * rng.standard_normal((H, W))).astype(np.float32)\n", + "\n", + "y = np.linspace(45.0, 44.0, H)\n", + "x = np.linspace(-122.0, -121.0, W)\n", + "\n", + "raster = xr.DataArray(\n", + " terrain, dims=['y', 'x'],\n", + " coords={'y': y, 'x': x},\n", + " attrs={'crs': 4326, 'nodata': -9999.0},\n", + ")\n", + "\n", + "dask_raster = raster.chunk({'y': 500, 'x': 500})\n", + "print(f'Shape: {dask_raster.shape}')\n", + "print(f'Chunks: {dask_raster.chunks}')\n", + "print(f'dtype: {dask_raster.dtype}')" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "fig, ax = plt.subplots(figsize=(6, 6))\n", + "raster.plot.imshow(ax=ax, cmap='terrain', add_colorbar=True)\n", + "ax.set_title('Synthetic terrain (2000x2000)')\n", + "ax.set_axis_off()\n", + "plt.tight_layout()\n", + "plt.show()" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## 
1. Streaming write to a single TIFF\n", + "\n", + "Pass the dask-backed DataArray to `to_geotiff()` the same way you would a numpy array. The streaming path kicks in automatically." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "tmpdir = tempfile.mkdtemp(prefix='xrs_stream_nb_')\n", + "\n", + "tif_path = os.path.join(tmpdir, 'streamed.tif')\n", + "to_geotiff(dask_raster, tif_path)\n", + "\n", + "print(f'File size: {os.path.getsize(tif_path):,} bytes')\n", + "\n", + "# Read back and verify\n", + "loaded = open_geotiff(tif_path)\n", + "print(f'Shape: {loaded.shape}')\n", + "print(f'CRS: {loaded.attrs.get(\"crs\")}')\n", + "print(f'Match: {np.allclose(loaded.values, raster.values)}')" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "That's it. Same API, same output, but peak memory was roughly `tile_size * width * 4 bytes` instead of the full 2000x2000 array." + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## 2. Streaming write to a VRT\n", + "\n", + "If you want one tile per dask chunk (useful when chunks are large or you plan to read subregions later), write to a `.vrt` path instead." 
+ ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "vrt_path = os.path.join(tmpdir, 'tiled.vrt')\n", + "to_geotiff(dask_raster, vrt_path)\n", + "\n", + "tiles_dir = os.path.join(tmpdir, 'tiled_tiles')\n", + "tile_files = sorted(os.listdir(tiles_dir))\n", + "print(f'VRT file: {os.path.getsize(vrt_path):,} bytes')\n", + "print(f'Tile count: {len(tile_files)}')\n", + "print(f'Tiles: {tile_files}')\n", + "\n", + "mosaic = open_geotiff(vrt_path)\n", + "print(f'\\nMosaic shape: {mosaic.shape}')\n", + "print(f'Match: {np.allclose(mosaic.values, raster.values)}')" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Four chunks along each axis produce 16 tile files, stitched by a lightweight XML index." + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## 3. Compression and layout options\n", + "\n", + "The compression and layout options of `to_geotiff` (`compression`, `compression_level`, `tiled`, `tile_size`, `predictor`) are passed through to the streaming path. Try different codecs and see the file size difference." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "codecs = ['none', 'deflate', 'zstd', 'lzw']\n", + "sizes = {}\n", + "\n", + "for codec in codecs:\n", + " p = os.path.join(tmpdir, f'stream_{codec}.tif')\n", + " to_geotiff(dask_raster, p, compression=codec)\n", + " sizes[codec] = os.path.getsize(p)\n", + "\n", + "for codec, sz in sizes.items():\n", + " ratio = sz / sizes['none']\n", + " print(f'{codec:>8s}: {sz:>12,} bytes ({ratio:.2%} of uncompressed)')" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## 4. When streaming doesn't apply\n", + "\n", + "COG output with `cog=True` needs overviews, which are built from the full array. In that case `to_geotiff` falls through to the eager path and calls `.compute()` as before."
+ ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "cog_path = os.path.join(tmpdir, 'eager_cog.tif')\n", + "to_geotiff(dask_raster, cog_path, cog=True)\n", + "\n", + "print(f'COG size: {os.path.getsize(cog_path):,} bytes')\n", + "cog = open_geotiff(cog_path)\n", + "print(f'Match: {np.allclose(cog.values, raster.values)}')" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "If the full array doesn't fit in memory, use VRT output instead of COG." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "import shutil\n", + "shutil.rmtree(tmpdir, ignore_errors=True)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Summary\n", + "\n", + "| Write mode | Path | Peak memory | When to use |\n", + "|:-----------|:-----|:------------|:------------|\n", + "| Streaming TIFF | `out.tif` | ~1 tile-row | Default for dask input |\n", + "| Streaming VRT | `out.vrt` | ~1 chunk | Need per-chunk files |\n", + "| Eager (COG) | `out.tif`, `cog=True` | Full array | Need overviews |" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3", + "language": "python", + "name": "python3" + }, + "language_info": { + "name": "python", + "version": "3.11.0" + } + }, + "nbformat": 4, + "nbformat_minor": 4 +} diff --git a/xrspatial/geotiff/__init__.py b/xrspatial/geotiff/__init__.py index c0fe3044..2fe46115 100644 --- a/xrspatial/geotiff/__init__.py +++ b/xrspatial/geotiff/__init__.py @@ -387,6 +387,12 @@ def to_geotiff(data: xr.DataArray | np.ndarray, path: str, *, gpu: bool | None = None) -> None: """Write data as a GeoTIFF or Cloud Optimized GeoTIFF. + Dask-backed DataArrays are written in streaming mode: one tile-row + at a time, without materialising the full array into RAM. Peak + memory is roughly ``tile_size * width * bytes_per_sample``. 
COG + output (``cog=True``) still materialises because overviews need the + full array. + Automatically dispatches to GPU compression when: - ``gpu=True`` is passed, or - The input data is CuPy-backed (auto-detected) @@ -483,25 +489,14 @@ def to_geotiff(data: xr.DataArray | np.ndarray, path: str, *, wkt_fallback = crs if isinstance(data, xr.DataArray): - # Handle CuPy-backed DataArrays: convert to numpy for CPU write raw = data.data - if hasattr(raw, 'get'): - arr = raw.get() # CuPy -> numpy - elif hasattr(raw, 'compute'): - arr = raw.compute() # Dask -> numpy - if hasattr(arr, 'get'): - arr = arr.get() # Dask+CuPy -> numpy - else: - arr = np.asarray(raw) - # Handle band-first dimension order (band, y, x) -> (y, x, band) - if arr.ndim == 3 and data.dims[0] in ('band', 'bands', 'channel'): - arr = np.moveaxis(arr, 0, -1) + + # Extract metadata from DataArray attrs (no materialisation needed) if geo_transform is None: geo_transform = _coords_to_transform(data) if epsg is None and crs is None: crs_attr = data.attrs.get('crs') if isinstance(crs_attr, str): - # WKT string from reproject() or other source epsg = _wkt_to_epsg(crs_attr) if epsg is None and wkt_fallback is None: wkt_fallback = crs_attr @@ -517,22 +512,75 @@ def to_geotiff(data: xr.DataArray | np.ndarray, path: str, *, nodata = data.attrs.get('nodata') if data.attrs.get('raster_type') == 'point': raster_type = RASTER_PIXEL_IS_POINT - # GDAL metadata from attrs (prefer raw XML, fall back to dict) gdal_meta_xml = data.attrs.get('gdal_metadata_xml') if gdal_meta_xml is None: gdal_meta_dict = data.attrs.get('gdal_metadata') if isinstance(gdal_meta_dict, dict): from ._geotags import _build_gdal_metadata_xml gdal_meta_xml = _build_gdal_metadata_xml(gdal_meta_dict) - # Extra tags for pass-through extra_tags_list = data.attrs.get('extra_tags') - # Resolution / DPI from attrs x_res = data.attrs.get('x_resolution') y_res = data.attrs.get('y_resolution') unit_str = data.attrs.get('resolution_unit') if unit_str is not 
None: _unit_ids = {'none': 1, 'inch': 2, 'centimeter': 3} res_unit = _unit_ids.get(str(unit_str), None) + + # Dask-backed: stream tiles to avoid materialising the full array. + # COG requires overviews from the full array, so it falls through + # to the eager path. + if hasattr(raw, 'dask') and not cog: + dask_arr = raw + # Handle band-first dimension order (band, y, x) -> (y, x, band) + if raw.ndim == 3 and data.dims[0] in ('band', 'bands', 'channel'): + import dask.array as da + dask_arr = da.moveaxis(raw, 0, -1) + if dask_arr.ndim not in (2, 3): + raise ValueError( + f"Expected 2D or 3D array, got {dask_arr.ndim}D") + # Validate compression_level + if compression_level is not None: + level_range = _LEVEL_RANGES.get(compression.lower()) + if level_range is not None: + lo, hi = level_range + if not (lo <= compression_level <= hi): + raise ValueError( + f"compression_level={compression_level} out of " + f"range for {compression} (valid: {lo}-{hi})") + from ._writer import write_streaming + write_streaming( + dask_arr, path, + geo_transform=geo_transform, + crs_epsg=epsg, + crs_wkt=wkt_fallback if epsg is None else None, + nodata=nodata, + compression=compression, + compression_level=compression_level, + tiled=tiled, + tile_size=tile_size, + predictor=predictor, + raster_type=raster_type, + x_resolution=x_res, + y_resolution=y_res, + resolution_unit=res_unit, + gdal_metadata_xml=gdal_meta_xml, + extra_tags=extra_tags_list, + bigtiff=bigtiff, + ) + return + + # Eager compute (numpy, CuPy, or dask+COG) + if hasattr(raw, 'get'): + arr = raw.get() # CuPy -> numpy + elif hasattr(raw, 'compute'): + arr = raw.compute() # Dask -> numpy + if hasattr(arr, 'get'): + arr = arr.get() # Dask+CuPy -> numpy + else: + arr = np.asarray(raw) + # Handle band-first dimension order (band, y, x) -> (y, x, band) + if arr.ndim == 3 and data.dims[0] in ('band', 'bands', 'channel'): + arr = np.moveaxis(arr, 0, -1) else: if hasattr(data, 'get'): arr = data.get() # CuPy -> numpy diff --git 
def _compress_block(arr, block_w, block_h, samples, dtype, bytes_per_sample,
                    predictor, compression, compression_level=None):
    """Compress one tile or strip and return the encoded bytes.

    *arr* must be C-contiguous and already sized ``block_h x block_w``
    (times *samples* for multi-band data); no padding happens here.

    *compression* is the numeric TIFF compression tag, not the codec name.
    JPEG is handed the raw pixels directly. For every other codec the
    horizontal predictor is applied first when *predictor* is truthy and
    the block is actually being compressed.
    """
    if compression == COMPRESSION_JPEG:
        return jpeg_compress(arr.tobytes(), block_w, block_h, samples)

    if predictor and compression != COMPRESSION_NONE:
        # predictor_encode works on a flat byte buffer; copy so the
        # caller's array is never modified.
        buf = arr.view(np.uint8).ravel().copy()
        buf = predictor_encode(buf, block_w, block_h,
                               bytes_per_sample * samples)
        raw_data = buf.tobytes()
    else:
        raw_data = arr.tobytes()

    if compression == COMPRESSION_JPEG2000:
        from ._compression import jpeg2000_compress
        return jpeg2000_compress(raw_data, block_w, block_h,
                                 samples=samples, dtype=dtype)
    if compression == COMPRESSION_LERC:
        from ._compression import lerc_compress
        return lerc_compress(raw_data, block_w, block_h,
                             samples=samples, dtype=dtype)
    if compression_level is None:
        return compress(raw_data, compression)
    return compress(raw_data, compression, level=compression_level)


def _fetch_rows(dask_data, r0, r1, out_dtype, nodata):
    """Materialise rows ``r0:r1`` of *dask_data* as a host numpy array.

    The slice is computed, moved off the device if the result exposes a
    ``.get()`` method (CuPy-style), cast to *out_dtype*, and has NaNs
    replaced by *nodata* when *nodata* is a non-NaN value and the data is
    floating point.
    """
    if dask_data.ndim == 3:
        block = dask_data[r0:r1, :, :].compute()
    else:
        block = dask_data[r0:r1, :].compute()

    # Device -> host must happen *before* np.asarray: an ndarray has no
    # ``.get``, and np.asarray on a device array raises instead of copying.
    if hasattr(block, 'get'):
        block = block.get()
    block = np.asarray(block)

    if block.dtype != out_dtype:
        block = block.astype(out_dtype)

    # NaN -> nodata sentinel
    if (nodata is not None and block.dtype.kind == 'f'
            and not np.isnan(nodata)):
        nan_mask = np.isnan(block)
        if nan_mask.any():
            # compute() may hand back a view of the caller's data.
            block = block.copy()
            block[nan_mask] = block.dtype.type(nodata)
    return block


def _pad_to_tile(block, tile_h, tile_w):
    """Return *block* as a contiguous ``tile_h x tile_w`` array, zero-padded
    on the bottom/right when it is smaller than a full tile."""
    rows, cols = block.shape[:2]
    if rows == tile_h and cols == tile_w:
        return np.ascontiguousarray(block)
    padded = np.zeros((tile_h, tile_w) + block.shape[2:], dtype=block.dtype)
    padded[:rows, :cols] = block
    return padded


# ---------------------------------------------------------------------------
# Streaming writer (dask -> monolithic TIFF without full materialisation)
# ---------------------------------------------------------------------------

def write_streaming(dask_data, path: str, *,
                    geo_transform: 'GeoTransform | None' = None,
                    crs_epsg: int | None = None,
                    crs_wkt: str | None = None,
                    nodata=None,
                    compression: str = 'zstd',
                    compression_level: int | None = None,
                    tiled: bool = True,
                    tile_size: int = 256,
                    predictor: bool = False,
                    raster_type: int = 1,
                    x_resolution: float | None = None,
                    y_resolution: float | None = None,
                    resolution_unit: int | None = None,
                    gdal_metadata_xml: str | None = None,
                    extra_tags: list | None = None,
                    bigtiff: bool | None = None) -> None:
    """Write a dask array as a GeoTIFF by streaming one tile-row at a time.

    Peak memory is approximately
    ``tile_size * width * samples * bytes_per_sample`` for tiled output,
    or ``rows_per_strip * width * samples * bytes_per_sample`` for
    stripped output.

    The file is produced in two passes on a temporary file next to
    *path*: first the header, an IFD with zeroed offsets/byte-counts and
    the pixel data are written sequentially; then the IFD is rewritten in
    place with the real offsets. The temporary file is renamed over
    *path* on success and removed on any failure.

    Parameters
    ----------
    dask_data
        2D ``(y, x)`` or 3D ``(y, x, band)`` array; slices of it must
        expose ``.compute()``.
    path : str
        Local destination path. fsspec URIs are rejected.
    geo_transform, crs_epsg, crs_wkt, nodata, raster_type
        Forwarded to ``build_geo_tags``. NaN pixels in floating-point
        data are replaced by *nodata* when it is a non-NaN value.
    compression : str
        Codec name, resolved through ``_compression_tag``.
    compression_level : int, optional
        Passed to the codec when given.
    tiled : bool
        Square tiles of *tile_size* when true, otherwise strips of at
        most 256 rows.
    predictor : bool
        Apply the horizontal predictor (tag value 2) when the data is
        compressed.
    x_resolution, y_resolution, resolution_unit
        Written as the corresponding TIFF resolution tags when not None.
    gdal_metadata_xml : str, optional
        Written verbatim as the GDAL metadata tag.
    extra_tags : list, optional
        ``(tag_id, type_id, count, value)`` tuples; tags whose id is
        already present are ignored.
    bigtiff : bool, optional
        Force (or forbid) BigTIFF. When None, BigTIFF is used if the
        uncompressed pixel data exceeds 4 GiB.

    Raises
    ------
    NotImplementedError
        If *path* is an fsspec URI.
    ValueError
        If JPEG compression is requested for non-uint8 data or a band
        count other than 1 or 3.
    RuntimeError
        If the patched IFD does not have the size reserved for it.
    """
    import os
    import tempfile

    # Fail fast for unsupported destinations
    if _is_fsspec_uri(path):
        raise NotImplementedError(
            "Streaming dask write to cloud storage is not yet supported. "
            "Use .compute() first or write to a .vrt file.")

    height, width = dask_data.shape[:2]
    samples = dask_data.shape[2] if dask_data.ndim == 3 else 1
    dtype = dask_data.dtype

    # Match the eager path's dtype promotion
    out_dtype = dtype
    if out_dtype == np.float16:
        out_dtype = np.float32
    elif out_dtype == np.bool_:
        out_dtype = np.uint8

    bits_per_sample, sample_format = numpy_to_tiff_dtype(out_dtype)
    bytes_per_sample = out_dtype.itemsize
    comp_tag = _compression_tag(compression)

    if comp_tag == COMPRESSION_JPEG:
        if out_dtype != np.uint8:
            raise ValueError(
                f"JPEG compression requires uint8 data, got {out_dtype}.")
        if samples not in (1, 3):
            raise ValueError(
                f"JPEG compression requires 1 or 3 bands, got {samples}")

    # Layout parameters
    if tiled:
        tw = th = tile_size
        tiles_across = math.ceil(width / tw)
        tiles_down = math.ceil(height / th)
        n_entries = tiles_across * tiles_down
    else:
        rows_per_strip = min(256, height)
        n_entries = math.ceil(height / rows_per_strip)

    # BigTIFF detection (use uncompressed size as conservative estimate)
    uncompressed_bytes = height * width * bytes_per_sample * samples
    UINT32_MAX = 0xFFFFFFFF
    if bigtiff is not None:
        use_bigtiff = bigtiff
    else:
        use_bigtiff = uncompressed_bytes > UINT32_MAX

    header_size = 16 if use_bigtiff else 8

    # ---- Build tag list (mirrors _assemble_tiff for level 0) ----
    tags = []
    tags.append((TAG_IMAGE_WIDTH, LONG, 1, width))
    tags.append((TAG_IMAGE_LENGTH, LONG, 1, height))
    if samples > 1:
        tags.append((TAG_BITS_PER_SAMPLE, SHORT, samples,
                     [bits_per_sample] * samples))
    else:
        tags.append((TAG_BITS_PER_SAMPLE, SHORT, 1, bits_per_sample))
    tags.append((TAG_COMPRESSION, SHORT, 1, comp_tag))
    photometric = 2 if samples >= 3 else 1
    tags.append((TAG_PHOTOMETRIC, SHORT, 1, photometric))
    tags.append((TAG_SAMPLES_PER_PIXEL, SHORT, 1, samples))
    if samples > 1:
        tags.append((TAG_SAMPLE_FORMAT, SHORT, samples,
                     [sample_format] * samples))
    else:
        tags.append((TAG_SAMPLE_FORMAT, SHORT, 1, sample_format))

    if photometric == 2 and samples > 3:
        n_extra = samples - 3
        extra_vals = [2] + [0] * (n_extra - 1)
        tags.append((TAG_EXTRA_SAMPLES, SHORT, n_extra, extra_vals))
    elif photometric == 1 and samples > 1:
        n_extra = samples - 1
        extra_vals = [0] * n_extra
        tags.append((TAG_EXTRA_SAMPLES, SHORT, n_extra, extra_vals))

    pred_val = 2 if (predictor and comp_tag != COMPRESSION_NONE) else 1
    if pred_val != 1:
        tags.append((TAG_PREDICTOR, SHORT, 1, pred_val))

    if x_resolution is not None:
        tags.append((TAG_X_RESOLUTION, RATIONAL, 1, x_resolution))
    if y_resolution is not None:
        tags.append((TAG_Y_RESOLUTION, RATIONAL, 1, y_resolution))
    if resolution_unit is not None:
        tags.append((TAG_RESOLUTION_UNIT, SHORT, 1, resolution_unit))

    # Layout tags with placeholder offsets / byte-counts.
    # NOTE: offsets use TIFF type LONG (uint32).  For BigTIFF files
    # exceeding 4 GB these would need LONG8 -- same limitation as the
    # eager writer.
    placeholder = [0] * n_entries
    if tiled:
        tags.append((TAG_TILE_WIDTH, SHORT, 1, tile_size))
        tags.append((TAG_TILE_LENGTH, SHORT, 1, tile_size))
        tags.append((TAG_TILE_OFFSETS, LONG, n_entries, list(placeholder)))
        tags.append((TAG_TILE_BYTE_COUNTS, LONG, n_entries,
                     list(placeholder)))
    else:
        tags.append((TAG_ROWS_PER_STRIP, SHORT, 1, rows_per_strip))
        tags.append((TAG_STRIP_OFFSETS, LONG, n_entries, list(placeholder)))
        tags.append((TAG_STRIP_BYTE_COUNTS, LONG, n_entries,
                     list(placeholder)))

    # Geo tags
    geo_tags_dict = {}
    if geo_transform is not None:
        geo_tags_dict = build_geo_tags(
            geo_transform, crs_epsg, nodata, raster_type=raster_type,
            crs_wkt=crs_wkt)
    elif crs_epsg is not None or crs_wkt is not None or nodata is not None:
        # No transform supplied: build with a default one, then drop the
        # georeferencing tags it produced.
        geo_tags_dict = build_geo_tags(
            GeoTransform(), crs_epsg, nodata, raster_type=raster_type,
            crs_wkt=crs_wkt)
        geo_tags_dict.pop(TAG_MODEL_PIXEL_SCALE, None)
        geo_tags_dict.pop(TAG_MODEL_TIEPOINT, None)

    for gtag, gval in geo_tags_dict.items():
        if gtag == TAG_MODEL_PIXEL_SCALE:
            tags.append((gtag, DOUBLE, 3, list(gval)))
        elif gtag == TAG_MODEL_TIEPOINT:
            tags.append((gtag, DOUBLE, 6, list(gval)))
        elif gtag == TAG_GEO_KEY_DIRECTORY:
            tags.append((gtag, SHORT, len(gval), list(gval)))
        elif gtag == TAG_GEO_ASCII_PARAMS:
            tags.append((gtag, ASCII, len(str(gval)) + 1, str(gval)))
        elif gtag == TAG_GDAL_NODATA:
            tags.append((gtag, ASCII, len(str(gval)) + 1, str(gval)))

    if gdal_metadata_xml is not None:
        tags.append((TAG_GDAL_METADATA, ASCII,
                     len(gdal_metadata_xml) + 1, gdal_metadata_xml))

    if extra_tags is not None:
        existing_ids = {t[0] for t in tags}
        for etag_id, etype_id, ecount, evalue in extra_tags:
            if etag_id not in existing_ids:
                tags.append((etag_id, etype_id, ecount, evalue))

    # ---- Pre-compute IFD reservation size ----
    sorted_tags = sorted(tags, key=lambda t: t[0])
    entry_size = 20 if use_bigtiff else 12
    count_size = 8 if use_bigtiff else 2
    next_size = 8 if use_bigtiff else 4
    num_tags = len(sorted_tags)
    ifd_block_size = count_size + entry_size * num_tags + next_size
    overflow_base = header_size + ifd_block_size
    # Built once: these bytes both size the reservation and are written
    # as the placeholder in pass 1.
    placeholder_ifd, placeholder_overflow = _build_ifd(
        sorted_tags, overflow_base, bigtiff=use_bigtiff)
    pixel_data_start = overflow_base + len(placeholder_overflow)

    dir_name = os.path.dirname(os.path.abspath(path))
    os.makedirs(dir_name, exist_ok=True)
    fd, tmp_path = tempfile.mkstemp(dir=dir_name, suffix='.tif.tmp')

    try:
        # -- Pass 1: write header + placeholder IFD + streaming pixel data --
        actual_offsets = []
        actual_counts = []
        current_offset = pixel_data_start

        with os.fdopen(fd, 'wb') as f:

            def _emit(payload):
                """Append one compressed block and record where it landed."""
                nonlocal current_offset
                actual_offsets.append(current_offset)
                actual_counts.append(len(payload))
                f.write(payload)
                current_offset += len(payload)

            # Header
            f.write(b'II')
            if use_bigtiff:
                f.write(struct.pack(f'{BO}H', 43))
                f.write(struct.pack(f'{BO}H', 8))
                f.write(struct.pack(f'{BO}H', 0))
                f.write(struct.pack(f'{BO}Q', header_size))
            else:
                f.write(struct.pack(f'{BO}H', 42))
                f.write(struct.pack(f'{BO}I', header_size))

            # Placeholder IFD + overflow
            f.write(placeholder_ifd)
            f.write(placeholder_overflow)

            # Stream pixel data
            if tiled:
                for tr in range(tiles_down):
                    r0 = tr * th
                    r1 = min(r0 + th, height)

                    # Compute one tile-row from the dask graph
                    row_np = _fetch_rows(
                        dask_data, r0, r1, out_dtype, nodata)

                    for tc in range(tiles_across):
                        c0 = tc * tw
                        c1 = min(c0 + tw, width)
                        tile_arr = _pad_to_tile(row_np[:, c0:c1], th, tw)
                        _emit(_compress_block(
                            tile_arr, tw, th, samples, out_dtype,
                            bytes_per_sample, predictor, comp_tag,
                            compression_level))

                    # Release the tile-row before computing the next one.
                    del row_np
            else:
                # Strip layout
                for i in range(n_entries):
                    r0 = i * rows_per_strip
                    r1 = min(r0 + rows_per_strip, height)
                    strip_rows = r1 - r0

                    strip_np = _fetch_rows(
                        dask_data, r0, r1, out_dtype, nodata)
                    _emit(_compress_block(
                        np.ascontiguousarray(strip_np),
                        width, strip_rows, samples, out_dtype,
                        bytes_per_sample, predictor, comp_tag,
                        compression_level))

                    del strip_np

        # -- Pass 2: patch IFD with actual offsets --
        patched_tags = []
        for tag_id, type_id, count, values in sorted_tags:
            if tag_id in (TAG_TILE_OFFSETS, TAG_STRIP_OFFSETS):
                patched_tags.append((tag_id, LONG, n_entries, actual_offsets))
            elif tag_id in (TAG_TILE_BYTE_COUNTS, TAG_STRIP_BYTE_COUNTS):
                patched_tags.append((tag_id, LONG, n_entries, actual_counts))
            else:
                patched_tags.append((tag_id, type_id, count, values))

        ifd_bytes, overflow_bytes = _build_ifd(
            patched_tags, overflow_base, bigtiff=use_bigtiff)
        # The patched IFD is written over the placeholder; if it were any
        # larger it would silently overwrite the first pixel block.
        if (len(ifd_bytes) != len(placeholder_ifd)
                or len(overflow_bytes) != len(placeholder_overflow)):
            raise RuntimeError(
                "Patched IFD size does not match the reserved space")

        with open(tmp_path, 'r+b') as f:
            f.seek(header_size)
            f.write(ifd_bytes)
            f.write(overflow_bytes)

        # Post-write validation
        from ._header import parse_header as _ph
        with open(tmp_path, 'rb') as f:
            try:
                _ph(f.read(16))
            except Exception as e:
                import warnings
                warnings.warn(
                    f"Written file may be corrupt: {e}", stacklevel=2)

        os.replace(tmp_path, path)
    except BaseException:
        # Never leave a partial temp file behind, even on KeyboardInterrupt.
        try:
            os.unlink(tmp_path)
        except OSError:
            pass
        raise
sample_raster, dask_raster, tmp_path): + path = str(tmp_path / 'tiled_deflate_1084.tif') + to_geotiff(dask_raster, path, compression='deflate') + result = open_geotiff(path) + np.testing.assert_array_almost_equal( + result.values, sample_raster.values, decimal=5) + + def test_tiled_lzw(self, sample_raster, dask_raster, tmp_path): + path = str(tmp_path / 'tiled_lzw_1084.tif') + to_geotiff(dask_raster, path, compression='lzw') + result = open_geotiff(path) + np.testing.assert_array_almost_equal( + result.values, sample_raster.values, decimal=5) + + def test_tiled_uncompressed(self, sample_raster, dask_raster, tmp_path): + path = str(tmp_path / 'tiled_none_1084.tif') + to_geotiff(dask_raster, path, compression='none') + result = open_geotiff(path) + np.testing.assert_array_almost_equal( + result.values, sample_raster.values, decimal=5) + + def test_stripped(self, sample_raster, dask_raster, tmp_path): + path = str(tmp_path / 'stripped_1084.tif') + to_geotiff(dask_raster, path, tiled=False) + result = open_geotiff(path) + np.testing.assert_array_almost_equal( + result.values, sample_raster.values, decimal=5) + + def test_predictor(self, sample_raster, dask_raster, tmp_path): + path = str(tmp_path / 'pred_1084.tif') + to_geotiff(dask_raster, path, predictor=True) + result = open_geotiff(path) + np.testing.assert_array_almost_equal( + result.values, sample_raster.values, decimal=5) + + def test_compression_level(self, sample_raster, dask_raster, tmp_path): + path = str(tmp_path / 'level_1084.tif') + to_geotiff(dask_raster, path, compression='deflate', + compression_level=1) + result = open_geotiff(path) + np.testing.assert_array_almost_equal( + result.values, sample_raster.values, decimal=5) + + def test_matches_eager_write(self, sample_raster, dask_raster, tmp_path): + """Streaming and eager paths should produce identical pixel data.""" + eager_path = str(tmp_path / 'eager_1084.tif') + stream_path = str(tmp_path / 'stream_1084.tif') + + to_geotiff(sample_raster, 
eager_path) # numpy -> eager + to_geotiff(dask_raster, stream_path) # dask -> streaming + + eager = open_geotiff(eager_path) + stream = open_geotiff(stream_path) + np.testing.assert_array_equal(eager.values, stream.values) + + +# -- Geo metadata preservation ----------------------------------------------- + +class TestStreamingGeoMetadata: + def test_crs_preserved(self, dask_raster, tmp_path): + path = str(tmp_path / 'crs_1084.tif') + to_geotiff(dask_raster, path) + result = open_geotiff(path) + assert result.attrs.get('crs') == 4326 + + def test_nodata_preserved(self, dask_raster, tmp_path): + path = str(tmp_path / 'nd_1084.tif') + to_geotiff(dask_raster, path) + result = open_geotiff(path) + assert float(result.attrs.get('nodata')) == pytest.approx(-9999.0) + + def test_coordinates_preserved(self, sample_raster, dask_raster, tmp_path): + path = str(tmp_path / 'coords_1084.tif') + to_geotiff(dask_raster, path) + result = open_geotiff(path) + np.testing.assert_array_almost_equal( + result.coords['x'].values, sample_raster.coords['x'].values, + decimal=6) + np.testing.assert_array_almost_equal( + result.coords['y'].values, sample_raster.coords['y'].values, + decimal=6) + + +# -- Edge cases --------------------------------------------------------------- + +class TestStreamingEdgeCases: + def test_nan_to_nodata(self, tmp_path): + """NaN pixels should round-trip through the nodata sentinel.""" + arr = np.ones((100, 100), dtype=np.float32) + arr[10:20, 10:20] = np.nan + da = xr.DataArray(arr, dims=['y', 'x'], + attrs={'nodata': -9999.0}) + dask_da = da.chunk({'y': 50, 'x': 50}) + + path = str(tmp_path / 'nan_1084.tif') + to_geotiff(dask_da, path) + result = open_geotiff(path) + + assert np.isnan(result.values[15, 15]) + assert result.values[0, 0] == pytest.approx(1.0) + + def test_single_chunk(self, sample_raster, tmp_path): + """Single chunk = whole array, but still goes through streaming.""" + dask_da = sample_raster.chunk({'y': 200, 'x': 200}) + path = str(tmp_path / 
'single_1084.tif') + to_geotiff(dask_da, path) + result = open_geotiff(path) + np.testing.assert_array_almost_equal( + result.values, sample_raster.values, decimal=5) + + def test_uneven_chunks(self, tmp_path): + """Chunks that don't divide evenly into tile_size.""" + arr = np.arange(150 * 170, dtype=np.float32).reshape(150, 170) + da = xr.DataArray(arr, dims=['y', 'x']) + dask_da = da.chunk({'y': 64, 'x': 64}) + + path = str(tmp_path / 'uneven_1084.tif') + to_geotiff(dask_da, path, tile_size=128) + result = open_geotiff(path) + np.testing.assert_array_equal(result.values, arr) + + def test_small_raster(self, tmp_path): + """Raster smaller than one tile.""" + arr = np.array([[1.0, 2.0], [3.0, 4.0]], dtype=np.float32) + da = xr.DataArray(arr, dims=['y', 'x']) + dask_da = da.chunk({'y': 2, 'x': 2}) + + path = str(tmp_path / 'tiny_1084.tif') + to_geotiff(dask_da, path, tile_size=256) + result = open_geotiff(path) + np.testing.assert_array_equal(result.values, arr) + + def test_uint16(self, tmp_path): + arr = np.arange(10000, dtype=np.uint16).reshape(100, 100) + da = xr.DataArray(arr, dims=['y', 'x']) + dask_da = da.chunk({'y': 50, 'x': 50}) + + path = str(tmp_path / 'u16_1084.tif') + to_geotiff(dask_da, path) + result = open_geotiff(path) + np.testing.assert_array_equal(result.values, arr) + + def test_int32(self, tmp_path): + arr = np.arange(10000, dtype=np.int32).reshape(100, 100) + da = xr.DataArray(arr, dims=['y', 'x']) + dask_da = da.chunk({'y': 50, 'x': 50}) + + path = str(tmp_path / 'i32_1084.tif') + to_geotiff(dask_da, path) + result = open_geotiff(path) + np.testing.assert_array_equal(result.values, arr) + + def test_float64(self, tmp_path): + arr = np.random.default_rng(1084).random((80, 80)) + da = xr.DataArray(arr, dims=['y', 'x']) + dask_da = da.chunk({'y': 40, 'x': 40}) + + path = str(tmp_path / 'f64_1084.tif') + to_geotiff(dask_da, path) + result = open_geotiff(path) + np.testing.assert_array_almost_equal(result.values, arr) + + +# -- Multiband 
class TestStreamingMultiband:
    """Multi-band dask input written through the streaming path."""

    def test_3d_band_last(self, tmp_path):
        """3D array with (y, x, band) layout."""
        rng = np.random.default_rng(1084)
        pixels = rng.random((100, 100, 3), dtype=np.float32)
        lazy = xr.DataArray(pixels, dims=['y', 'x', 'band']).chunk(
            {'y': 50, 'x': 50})

        out_file = str(tmp_path / 'band_last_1084.tif')
        to_geotiff(lazy, out_file)

        np.testing.assert_array_almost_equal(
            open_geotiff(out_file).values, pixels, decimal=5)

    def test_3d_band_first(self, tmp_path):
        """Band-first (band, y, x) DataArray gets transposed automatically."""
        rng = np.random.default_rng(1084)
        pixels = rng.random((3, 100, 100), dtype=np.float32)
        lazy = xr.DataArray(pixels, dims=['band', 'y', 'x']).chunk(
            {'y': 50, 'x': 50})

        out_file = str(tmp_path / 'band_first_1084.tif')
        to_geotiff(lazy, out_file)

        # The file stores (y, x, band), so move the band axis last
        # before comparing.
        expected = np.moveaxis(pixels, 0, -1)
        np.testing.assert_array_almost_equal(
            open_geotiff(out_file).values, expected, decimal=5)


# -- BigTIFF and error cases --------------------------------------------------

class TestStreamingBigTiffAndErrors:
    """Forced BigTIFF output and unsupported destinations."""

    def test_forced_bigtiff(self, tmp_path):
        """bigtiff=True on a small array should produce a valid BigTIFF."""
        pixels = np.arange(64, dtype=np.float32).reshape(8, 8)
        lazy = xr.DataArray(pixels, dims=['y', 'x']).chunk({'y': 4, 'x': 4})

        out_file = str(tmp_path / 'bigtiff_1084.tif')
        to_geotiff(lazy, out_file, bigtiff=True)

        np.testing.assert_array_equal(open_geotiff(out_file).values, pixels)

    def test_cloud_uri_raises(self, tmp_path):
        """Streaming to cloud storage should raise NotImplementedError."""
        pixels = np.ones((10, 10), dtype=np.float32)
        lazy = xr.DataArray(pixels, dims=['y', 'x']).chunk({'y': 5, 'x': 5})

        with pytest.raises(NotImplementedError, match='cloud'):
            to_geotiff(lazy, 's3://bucket/file.tif')
class TestCogFallback:
    """cog=True must keep working for dask input via the eager path."""

    def test_cog_with_dask_still_works(self, sample_raster, tmp_path):
        """cog=True with dask input should fall through to eager compute."""
        out_file = str(tmp_path / 'cog_1084.tif')
        lazy = sample_raster.chunk({'y': 100, 'x': 100})

        to_geotiff(lazy, out_file, cog=True)

        round_tripped = open_geotiff(out_file)
        np.testing.assert_array_almost_equal(
            round_tripped.values, sample_raster.values, decimal=5)