Skip to content

Commit 31c8f77

Browse files
authored
geotiff: GPU + dask+GPU coverage for float16 read (#1941) (#1947)
* geotiff: GPU + dask+GPU backend coverage for float16 read (#1941) Issue #1941 added float16 auto-promotion on read and gated the GPU GDS path off for (bps=16, sf=float). The eager numpy and dask paths are covered by test_float16_read_1941.py; the cupy and dask+cupy paths had no targeted tests. A regression dropping the bps_mismatch fallback at _backends/gpu.py:357 or the float16 gate in _gds_chunk_path_available would silently mis-decode half-precision tiles and ship under existing CI. Adds 13 tests, all passing on a CUDA host: - read_geotiff_gpu on stripped + tiled (deflate, uncompressed) float16 - open_geotiff(gpu=True) dispatcher thread-through - windowed GPU reads on stripped + tiled float16 - open_geotiff(chunks=, gpu=True) and read_geotiff_gpu(chunks=) - _gds_chunk_path_available structural pin for (bps=16, sf=3) -> False plus a sanity check that float32 tiled files still allow GDS - cross-backend pixel-exact parity (numpy vs GPU, numpy vs dask+GPU, dask+numpy vs dask+GPU) - predictor=3 + float16 GPU round trip Mutation against bps_mismatch flipped 5 tests red; mutation against the GDS float16 gate flipped the structural test red. * geotiff: address PR #1947 review (kvikio gate, multi-tile fixture, importorskip ImportError)
1 parent b1579f8 commit 31c8f77

1 file changed

Lines changed: 341 additions & 0 deletions

File tree

Lines changed: 341 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,341 @@
1+
"""GPU backend coverage for issue #1941 (float16 read).
2+
3+
#1941 added float16 auto-promotion on read by making
4+
``tiff_dtype_to_numpy(16, SAMPLE_FORMAT_FLOAT)`` return ``float32`` and
5+
adding the on-disk ``tiff_storage_dtype`` companion. The eager numpy and
6+
dask paths are covered by ``test_float16_read_1941.py``; this module
7+
closes the GPU and dask+GPU coverage gap.
8+
9+
A regression that:
10+
11+
* dropped the ``bps_mismatch`` stripped/odd-bps fallback at
12+
``_backends/gpu.py:357`` would route float16 stripped reads through
13+
the tiled GPU decoder and mis-decode the half-precision samples;
14+
* dropped the ``bps_first == 16 and sample_format == SAMPLE_FORMAT_FLOAT``
15+
early-out at ``_backends/gpu.py:791`` in ``_gds_chunk_path_available``
16+
would send tiled float16 chunked reads down the kvikIO GDS path and
17+
mis-stride the buffer;
18+
* dropped the entry at ``(16, SAMPLE_FORMAT_FLOAT) -> float32`` in
19+
``tiff_dtype_to_numpy`` would surface as ``ValueError("Unsupported
20+
BitsPerSample=16, SampleFormat=3")`` from the GPU read paths.
21+
22+
Every test ships through ``read_geotiff_gpu`` directly or through
23+
``open_geotiff(..., gpu=True)`` so the dispatcher path is also wired in.
24+
``cuda-unavailable`` builds skip the suite via the project's standard
25+
``CUDA_AVAILABLE`` gate.
26+
"""
27+
from __future__ import annotations
28+
29+
import importlib.util
30+
31+
import numpy as np
32+
import pytest
33+
import xarray as xr
34+
35+
36+
def _gpu_available() -> bool:
37+
if importlib.util.find_spec("cupy") is None:
38+
return False
39+
try:
40+
import cupy
41+
42+
return bool(cupy.cuda.is_available())
43+
except Exception:
44+
return False
45+
46+
47+
_HAS_GPU = _gpu_available()
48+
pytestmark = pytest.mark.skipif(
49+
not _HAS_GPU, reason="cupy + CUDA required for GPU float16 read tests",
50+
)
51+
52+
53+
@pytest.fixture
54+
def float16_stripped_tif(tmp_path):
55+
"""Stripped float16 GeoTIFF: triggers the bps_mismatch CPU fallback.
56+
57+
``tifffile.imwrite`` without ``tile=`` produces a stripped layout, so
58+
the GPU reader hits ``bps_mismatch=True`` (file_dtype.itemsize*8 == 32
59+
but bps == 16) and falls back to ``_read_to_array`` on CPU before
60+
copying to device.
61+
"""
62+
tifffile = pytest.importorskip("tifffile")
63+
arr = np.array(
64+
[[0.0, 1.0, 2.0, 3.0],
65+
[-1.0, -2.0, -3.0, -4.0],
66+
[0.5, 1.5, 2.5, 3.5],
67+
[100.0, 200.0, 300.0, 400.0]],
68+
dtype=np.float16,
69+
)
70+
path = tmp_path / "f16_stripped.tif"
71+
tifffile.imwrite(str(path), arr, compression=None)
72+
return path, arr
73+
74+
75+
@pytest.fixture
76+
def float16_tiled_tif(tmp_path):
77+
"""Multi-tile float16 GeoTIFF: 32x32 image, 16x16 tiles (2x2 grid).
78+
79+
Tiled and deflate-compressed. The 2x2 tile grid exercises inter-tile
80+
reassembly in the decoder path so a regression that mis-stitched
81+
adjacent tiles would surface here. ``bps_mismatch`` short-circuits
82+
the tiled GPU decode path and routes through the CPU decoder; the
83+
GDS path is also gated off via ``_gds_chunk_path_available``
84+
returning False for (bps=16, sf=3).
85+
"""
86+
tifffile = pytest.importorskip("tifffile")
87+
arr = np.arange(1024, dtype=np.float16).reshape(32, 32)
88+
path = tmp_path / "f16_tiled.tif"
89+
tifffile.imwrite(
90+
str(path), arr, compression="deflate", tile=(16, 16))
91+
return path, arr
92+
93+
94+
@pytest.fixture
95+
def float16_tiled_uncompressed_tif(tmp_path):
96+
"""Tiled uncompressed float16 GeoTIFF.
97+
98+
Mirrors ``float16_tiled_tif`` but with ``compression=None`` so the
99+
tile-decode path is exercised without an extra deflate codec call.
100+
Tile size 16 is the smallest tifffile allows.
101+
"""
102+
tifffile = pytest.importorskip("tifffile")
103+
arr = np.arange(256, dtype=np.float16).reshape(16, 16)
104+
path = tmp_path / "f16_tiled_none.tif"
105+
tifffile.imwrite(
106+
str(path), arr, compression=None, tile=(16, 16))
107+
return path, arr
108+
109+
110+
class TestEagerGPUReadFloat16:
111+
"""``read_geotiff_gpu`` returns float32 for stripped float16 input."""
112+
113+
def test_read_geotiff_gpu_stripped_returns_float32(
114+
self, float16_stripped_tif
115+
):
116+
from xrspatial.geotiff import read_geotiff_gpu
117+
118+
path, arr = float16_stripped_tif
119+
result = read_geotiff_gpu(str(path))
120+
assert result.dtype == np.float32, (
121+
f"GPU read of float16 must return float32, got {result.dtype}"
122+
)
123+
np.testing.assert_array_equal(
124+
result.data.get(), arr.astype(np.float32))
125+
126+
def test_read_geotiff_gpu_tiled_returns_float32(
127+
self, float16_tiled_tif
128+
):
129+
from xrspatial.geotiff import read_geotiff_gpu
130+
131+
path, arr = float16_tiled_tif
132+
result = read_geotiff_gpu(str(path))
133+
assert result.dtype == np.float32
134+
np.testing.assert_array_equal(
135+
result.data.get(), arr.astype(np.float32))
136+
137+
def test_read_geotiff_gpu_tiled_uncompressed_returns_float32(
138+
self, float16_tiled_uncompressed_tif
139+
):
140+
from xrspatial.geotiff import read_geotiff_gpu
141+
142+
path, arr = float16_tiled_uncompressed_tif
143+
result = read_geotiff_gpu(str(path))
144+
assert result.dtype == np.float32
145+
np.testing.assert_array_equal(
146+
result.data.get(), arr.astype(np.float32))
147+
148+
def test_open_geotiff_gpu_dispatcher_float16(self, float16_tiled_tif):
149+
"""``open_geotiff(gpu=True)`` dispatches correctly for float16."""
150+
from xrspatial.geotiff import open_geotiff
151+
152+
path, arr = float16_tiled_tif
153+
result = open_geotiff(str(path), gpu=True)
154+
assert result.dtype == np.float32
155+
np.testing.assert_array_equal(
156+
result.data.get(), arr.astype(np.float32))
157+
158+
159+
class TestGPUWindowedFloat16:
160+
"""Windowed GPU reads honour the bps_mismatch fallback path."""
161+
162+
def test_read_geotiff_gpu_windowed_stripped(self, float16_stripped_tif):
163+
from xrspatial.geotiff import read_geotiff_gpu
164+
165+
path, arr = float16_stripped_tif
166+
result = read_geotiff_gpu(str(path), window=(0, 0, 2, 2))
167+
assert result.dtype == np.float32
168+
assert result.shape == (2, 2)
169+
np.testing.assert_array_equal(
170+
result.data.get(), arr[:2, :2].astype(np.float32))
171+
172+
def test_read_geotiff_gpu_windowed_tiled(self, float16_tiled_tif):
173+
from xrspatial.geotiff import read_geotiff_gpu
174+
175+
path, arr = float16_tiled_tif
176+
result = read_geotiff_gpu(str(path), window=(0, 0, 8, 8))
177+
assert result.dtype == np.float32
178+
assert result.shape == (8, 8)
179+
np.testing.assert_array_equal(
180+
result.data.get(), arr[:8, :8].astype(np.float32))
181+
182+
183+
class TestDaskGPUFloat16:
184+
"""``open_geotiff(chunks=, gpu=True)`` decodes float16 correctly."""
185+
186+
def test_dask_gpu_tiled_float16(self, float16_tiled_tif):
187+
from xrspatial.geotiff import open_geotiff
188+
189+
path, arr = float16_tiled_tif
190+
result = open_geotiff(str(path), chunks=8, gpu=True)
191+
assert result.dtype == np.float32, (
192+
f"dask+GPU read of float16 must return float32, got {result.dtype}"
193+
)
194+
# Compute the dask array; under dask+cupy, .compute() yields a
195+
# cupy-backed DataArray, so the .data.get() step pulls to host.
196+
computed = result.compute()
197+
np.testing.assert_array_equal(
198+
computed.data.get(), arr.astype(np.float32))
199+
200+
def test_read_geotiff_gpu_chunks_kwarg_float16(self, float16_tiled_tif):
201+
"""``read_geotiff_gpu(chunks=)`` also routes correctly."""
202+
from xrspatial.geotiff import read_geotiff_gpu
203+
204+
path, arr = float16_tiled_tif
205+
result = read_geotiff_gpu(str(path), chunks=8)
206+
assert result.dtype == np.float32
207+
computed = result.compute()
208+
np.testing.assert_array_equal(
209+
computed.data.get(), arr.astype(np.float32))
210+
211+
212+
class TestGDSPathGatedOffForFloat16:
213+
"""``_gds_chunk_path_available`` returns False for (bps=16, sf=3).
214+
215+
Direct structural test of the gating logic added in #1941 to keep the
216+
KvikIO GDS chunked path from mis-decoding half-precision tiles. A
217+
regression dropping the float16 guard would silently corrupt every
218+
chunked GPU read of a float16 source.
219+
"""
220+
221+
def test_gds_path_gated_off_for_float16(self, float16_tiled_tif):
222+
pytest.importorskip("kvikio", exc_type=ImportError)
223+
224+
from xrspatial.geotiff._backends.gpu import _gds_chunk_path_available
225+
from xrspatial.geotiff._header import parse_all_ifds, parse_header
226+
227+
path, _ = float16_tiled_tif
228+
with open(str(path), "rb") as f:
229+
data = f.read()
230+
header = parse_header(data)
231+
ifds = parse_all_ifds(data, header)
232+
ifd = ifds[0]
233+
234+
# Sanity-check fixture: tiled, bps=16, sample_format=3 (float)
235+
from xrspatial.geotiff._dtypes import SAMPLE_FORMAT_FLOAT
236+
assert ifd.is_tiled, "fixture sanity: tiled layout expected"
237+
# Mirror the production unpacking pattern at gpu.py:791
238+
# (bps_first[0] if bps_first else 0) so an empty BitsPerSample
239+
# tuple would not raise IndexError here.
240+
bps_first = ifd.bits_per_sample
241+
if isinstance(bps_first, tuple):
242+
bps = bps_first[0] if bps_first else 0
243+
else:
244+
bps = bps_first
245+
assert bps == 16, "fixture sanity: bps=16 expected"
246+
assert ifd.sample_format == SAMPLE_FORMAT_FLOAT
247+
248+
result = _gds_chunk_path_available(
249+
str(path), ifd, has_sparse_tile=False, orientation=1)
250+
assert result is False, (
251+
"_gds_chunk_path_available must return False for "
252+
"(bps=16, sf=float) so the GDS chunked path does not "
253+
"mis-decode half-precision tiles."
254+
)
255+
256+
def test_gds_path_allowed_for_float32_tiled(self, tmp_path):
257+
"""Sanity: GDS path remains allowed for a float32 tiled file.
258+
259+
Pins that the float16 guard at gpu.py:791 fires only on
260+
(bps=16, sf=float), not on every tiled float file. A regression
261+
widening the guard to all floats would silently disable the
262+
GDS path on every float32 tiled COG.
263+
"""
264+
tifffile = pytest.importorskip("tifffile")
265+
pytest.importorskip("kvikio", exc_type=ImportError)
266+
267+
arr = np.arange(256, dtype=np.float32).reshape(16, 16)
268+
path = tmp_path / "f32_tiled.tif"
269+
tifffile.imwrite(
270+
str(path), arr, compression="deflate", tile=(16, 16))
271+
272+
from xrspatial.geotiff._backends.gpu import _gds_chunk_path_available
273+
from xrspatial.geotiff._header import parse_all_ifds, parse_header
274+
275+
with open(str(path), "rb") as f:
276+
data = f.read()
277+
header = parse_header(data)
278+
ifds = parse_all_ifds(data, header)
279+
280+
result = _gds_chunk_path_available(
281+
str(path), ifds[0], has_sparse_tile=False, orientation=1)
282+
assert result is True, (
283+
"_gds_chunk_path_available must remain True for "
284+
"(bps=32, sf=float) tiled files so the kvikio GDS chunk "
285+
"path still applies."
286+
)
287+
288+
289+
class TestBackendParityFloat16:
290+
"""All four backends agree pixel-exact on float16 input."""
291+
292+
def test_eager_numpy_equals_gpu(self, float16_tiled_tif):
293+
from xrspatial.geotiff import open_geotiff
294+
295+
path, _ = float16_tiled_tif
296+
cpu = open_geotiff(str(path))
297+
gpu = open_geotiff(str(path), gpu=True)
298+
299+
assert cpu.dtype == gpu.dtype == np.float32
300+
np.testing.assert_array_equal(np.asarray(cpu), gpu.data.get())
301+
302+
def test_eager_numpy_equals_dask_gpu(self, float16_tiled_tif):
303+
from xrspatial.geotiff import open_geotiff
304+
305+
path, _ = float16_tiled_tif
306+
cpu = open_geotiff(str(path))
307+
dask_gpu = open_geotiff(str(path), chunks=8, gpu=True).compute()
308+
309+
assert cpu.dtype == dask_gpu.dtype == np.float32
310+
np.testing.assert_array_equal(
311+
np.asarray(cpu), dask_gpu.data.get())
312+
313+
def test_dask_numpy_equals_dask_gpu(self, float16_tiled_tif):
314+
from xrspatial.geotiff import open_geotiff, read_geotiff_dask
315+
316+
path, _ = float16_tiled_tif
317+
dask_cpu = read_geotiff_dask(str(path), chunks=8).compute()
318+
dask_gpu = open_geotiff(str(path), chunks=8, gpu=True).compute()
319+
320+
np.testing.assert_array_equal(
321+
np.asarray(dask_cpu), dask_gpu.data.get())
322+
323+
324+
class TestPredictor3Float16GPU:
325+
"""Predictor=3 + float16 on disk also decodes correctly on GPU."""
326+
327+
def test_predictor3_float16_gpu_round_trip(self, tmp_path):
328+
tifffile = pytest.importorskip("tifffile")
329+
pytest.importorskip("imagecodecs") # required for predictor=3
330+
331+
from xrspatial.geotiff import read_geotiff_gpu
332+
333+
arr = np.linspace(-1.0, 1.0, 16).astype(np.float16).reshape(4, 4)
334+
path = tmp_path / "pred3_f16.tif"
335+
tifffile.imwrite(
336+
str(path), arr, predictor=3, compression="deflate")
337+
338+
result = read_geotiff_gpu(str(path))
339+
assert result.dtype == np.float32
340+
np.testing.assert_array_equal(
341+
result.data.get(), arr.astype(np.float32))

0 commit comments

Comments
 (0)