Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
76 changes: 40 additions & 36 deletions xrspatial/morphology.py
Original file line number Diff line number Diff line change
Expand Up @@ -132,57 +132,63 @@ def _validate_kernel(kernel, func_name):

@ngjit
def _erode_kernel_numpy(data, kernel, rows, cols, ky, kx):
"""Erosion (local minimum) on a padded array."""
"""Erosion (local minimum) on a padded array.

Only kernel cells with non-zero entries contribute to the output.
The centre cell is included only when ``kernel[hy, hx]`` is non-zero.
NaN neighbours included by the kernel propagate to NaN. Cells where
the kernel covers no non-zero entries return NaN.
"""
out = np.empty((rows, cols), dtype=data.dtype)
hy = ky // 2
hx = kx // 2
for i in prange(rows):
for j in range(cols):
val = data[i + hy, j + hx]
if val != val: # NaN
out[i, j] = val
continue
mn = val
mn = np.nan
seen = False
for dy in range(ky):
for dx in range(kx):
if kernel[dy, dx] == 0:
continue
v = data[i + dy, j + dx]
if v != v: # NaN neighbour
if v != v: # NaN neighbour propagates
mn = v
seen = True
break
if v < mn:
if not seen or v < mn:
mn = v
if mn != mn: # propagate NaN
seen = True
if seen and mn != mn:
break
out[i, j] = mn
return out


@ngjit
def _dilate_kernel_numpy(data, kernel, rows, cols, ky, kx):
"""Dilation (local maximum) on a padded array."""
"""Dilation (local maximum) on a padded array.

Only kernel cells with non-zero entries contribute to the output.
The centre cell is included only when ``kernel[hy, hx]`` is non-zero.
NaN neighbours included by the kernel propagate to NaN. Cells where
the kernel covers no non-zero entries return NaN.
"""
out = np.empty((rows, cols), dtype=data.dtype)
hy = ky // 2
hx = kx // 2
for i in prange(rows):
for j in range(cols):
val = data[i + hy, j + hx]
if val != val: # NaN
out[i, j] = val
continue
mx = val
mx = np.nan
seen = False
for dy in range(ky):
for dx in range(kx):
if kernel[dy, dx] == 0:
continue
v = data[i + dy, j + dx]
if v != v: # NaN neighbour
if v != v: # NaN neighbour propagates
mx = v
seen = True
break
if v > mx:
if not seen or v > mx:
mx = v
if mx != mx:
seen = True
if seen and mx != mx:
break
out[i, j] = mx
return out
Expand Down Expand Up @@ -233,22 +239,21 @@ def _erode_gpu(data, kernel, out, hy, hx, ky, kx):
rows = out.shape[0]
cols = out.shape[1]
if i < rows and j < cols:
val = data[i + hy, j + hx]
if val != val:
out[i, j] = val
return
mn = val
mn = np.nan
seen = False
for dy in range(ky):
for dx in range(kx):
if kernel[dy, dx] == 0:
continue
v = data[i + dy, j + dx]
if v != v:
mn = v
seen = True
break
if v < mn:
if not seen or v < mn:
mn = v
if mn != mn:
seen = True
if seen and mn != mn:
break
out[i, j] = mn

Expand All @@ -259,22 +264,21 @@ def _dilate_gpu(data, kernel, out, hy, hx, ky, kx):
rows = out.shape[0]
cols = out.shape[1]
if i < rows and j < cols:
val = data[i + hy, j + hx]
if val != val:
out[i, j] = val
return
mx = val
mx = np.nan
seen = False
for dy in range(ky):
for dx in range(kx):
if kernel[dy, dx] == 0:
continue
v = data[i + dy, j + dx]
if v != v:
mx = v
seen = True
break
if v > mx:
if not seen or v > mx:
mx = v
if mx != mx:
seen = True
if seen and mx != mx:
break
out[i, j] = mx

Expand Down
98 changes: 98 additions & 0 deletions xrspatial/tests/test_morphology.py
Original file line number Diff line number Diff line change
Expand Up @@ -360,6 +360,104 @@ def test_erode_5x5_kernel():
assert r5.data[2, 2] == 0.0


# ---------------------------------------------------------------------------
# Centre-zero kernels (issue #1397)
# ---------------------------------------------------------------------------

# Cross-with-hole: centre cell is excluded from the structuring element.
_KERNEL_RING = np.array([
[1, 1, 1],
[1, 0, 1],
[1, 1, 1],
], dtype=np.uint8)


def test_erode_excludes_centre_when_kernel_centre_zero():
"""When kernel[centre]==0, the centre value must not contaminate erosion."""
data = np.full((5, 5), 5.0, dtype=np.float64)
data[2, 2] = 1.0 # centre lower than its neighbours
agg = create_test_raster(data)
result = morph_erode(agg, kernel=_KERNEL_RING, boundary='nearest')
# Centre is excluded; all 8 neighbours are 5.0, so min is 5.0.
assert result.data[2, 2] == 5.0


def test_dilate_excludes_centre_when_kernel_centre_zero():
"""When kernel[centre]==0, the centre value must not contaminate dilation."""
data = np.full((5, 5), 5.0, dtype=np.float64)
data[2, 2] = 100.0 # centre higher than its neighbours
agg = create_test_raster(data)
result = morph_dilate(agg, kernel=_KERNEL_RING, boundary='nearest')
assert result.data[2, 2] == 5.0


def test_erode_centre_zero_nan_centre_does_not_propagate():
"""A NaN at the centre must not propagate when kernel[centre]==0."""
data = np.full((5, 5), 5.0, dtype=np.float64)
data[2, 2] = np.nan
agg = create_test_raster(data)
result = morph_erode(agg, kernel=_KERNEL_RING, boundary='nearest')
# Centre NaN is excluded by kernel; neighbours are 5.0
assert result.data[2, 2] == 5.0


def test_dilate_centre_zero_nan_centre_does_not_propagate():
data = np.full((5, 5), 5.0, dtype=np.float64)
data[2, 2] = np.nan
agg = create_test_raster(data)
result = morph_dilate(agg, kernel=_KERNEL_RING, boundary='nearest')
assert result.data[2, 2] == 5.0


@dask_array_available
def test_erode_centre_zero_dask_matches_numpy():
"""Dask backend must agree with numpy for centre-zero kernels."""
data = np.full((10, 10), 5.0, dtype=np.float64)
data[5, 5] = 1.0
numpy_agg = create_test_raster(data, backend='numpy')
dask_agg = create_test_raster(data, backend='dask')
np_func = partial(morph_erode, kernel=_KERNEL_RING, boundary='nearest')
np_res = np_func(numpy_agg)
dk_res = np_func(dask_agg)
np.testing.assert_allclose(np_res.data, dk_res.data.compute())


@dask_array_available
def test_dilate_centre_zero_dask_matches_numpy():
data = np.full((10, 10), 5.0, dtype=np.float64)
data[5, 5] = 100.0
numpy_agg = create_test_raster(data, backend='numpy')
dask_agg = create_test_raster(data, backend='dask')
np_func = partial(morph_dilate, kernel=_KERNEL_RING, boundary='nearest')
np_res = np_func(numpy_agg)
dk_res = np_func(dask_agg)
np.testing.assert_allclose(np_res.data, dk_res.data.compute())


@cuda_and_cupy_available
def test_erode_centre_zero_cupy_matches_numpy():
data = np.full((10, 10), 5.0, dtype=np.float64)
data[5, 5] = 1.0
numpy_agg = create_test_raster(data, backend='numpy')
cupy_agg = create_test_raster(data, backend='cupy')
np_func = partial(morph_erode, kernel=_KERNEL_RING, boundary='nearest')
np_res = np_func(numpy_agg)
cp_res = np_func(cupy_agg)
np.testing.assert_allclose(np_res.data, cp_res.data.get())


@cuda_and_cupy_available
def test_dilate_centre_zero_cupy_matches_numpy():
data = np.full((10, 10), 5.0, dtype=np.float64)
data[5, 5] = 100.0
numpy_agg = create_test_raster(data, backend='numpy')
cupy_agg = create_test_raster(data, backend='cupy')
np_func = partial(morph_dilate, kernel=_KERNEL_RING, boundary='nearest')
np_res = np_func(numpy_agg)
cp_res = np_func(cupy_agg)
np.testing.assert_allclose(np_res.data, cp_res.data.get())


# ---------------------------------------------------------------------------
# Dataset support
# ---------------------------------------------------------------------------
Expand Down
Loading