From 74bfdc166ed73ff88eb431f66a3578afd3ddbca2 Mon Sep 17 00:00:00 2001
From: Brendan Collins
Date: Mon, 30 Mar 2026 11:34:42 -0700
Subject: [PATCH 1/2] Fix three accuracy bugs in zonal stats dask backend (#1090)

1. Dask sum/count/min/max now return NaN (not 0) for zones with all-NaN
   values, matching the numpy backend. Uses a _nanreduce_preserve_allnan
   wrapper around np.nansum/np.nanmax/np.nanmin.

2. Dask std/var now use the Chan-Golub-LeVeque parallel merge algorithm
   instead of the naive one-pass formula, which suffers catastrophic
   cancellation when the mean is large relative to the standard deviation.

3. _calc_stats and the crosstab helpers now skip the ``!= nodata_values``
   comparison when nodata_values is None, avoiding a numpy FutureWarning.
---
 xrspatial/zonal.py | 128 ++++++++++++++++++++++++++++++++++++---------
 1 file changed, 103 insertions(+), 25 deletions(-)

diff --git a/xrspatial/zonal.py b/xrspatial/zonal.py
index eaefaabd..8555c614 100644
--- a/xrspatial/zonal.py
+++ b/xrspatial/zonal.py
@@ -202,17 +202,76 @@ def _stats_majority(data):
     )
 
 
+def _nanreduce_preserve_allnan(blocks, func):
+    """Reduce across blocks, returning NaN when ALL blocks are NaN for a zone.
+
+    ``np.nansum`` returns 0 for all-NaN input; we want NaN so that zones
+    with no valid values propagate NaN, consistent with the numpy backend.
+    """
+    result = func(blocks, axis=0)
+    all_nan = np.all(np.isnan(blocks), axis=0)
+    result[all_nan] = np.nan
+    return result
+
+
 _DASK_STATS = dict(
-    max=lambda block_maxes: np.nanmax(block_maxes, axis=0),
-    min=lambda block_mins: np.nanmin(block_mins, axis=0),
-    sum=lambda block_sums: np.nansum(block_sums, axis=0),
-    count=lambda block_counts: np.nansum(block_counts, axis=0),
-    sum_squares=lambda block_sum_squares: np.nansum(block_sum_squares, axis=0),
-    squared_sum=lambda block_sums: np.nansum(block_sums, axis=0)**2,
+    max=lambda blocks: _nanreduce_preserve_allnan(blocks, np.nanmax),
+    min=lambda blocks: _nanreduce_preserve_allnan(blocks, np.nanmin),
+    sum=lambda blocks: _nanreduce_preserve_allnan(blocks, np.nansum),
+    count=lambda blocks: _nanreduce_preserve_allnan(blocks, np.nansum),
+    sum_squares=lambda blocks: _nanreduce_preserve_allnan(blocks, np.nansum),
 )
-def _dask_mean(sums, counts): return sums / counts  # noqa
-def _dask_std(sum_squares, squared_sum, n): return np.sqrt((sum_squares - squared_sum/n) / n)  # noqa
-def _dask_var(sum_squares, squared_sum, n): return (sum_squares - squared_sum/n) / n  # noqa
+
+
+def _dask_mean(sums, counts):  # noqa
+    return sums / counts
+
+
+def _parallel_variance(block_counts, block_sums, block_sum_squares):
+    """Population variance via Chan-Golub-LeVeque parallel merge.
+
+    Each input is (n_blocks, n_zones). Returns (n_zones,) variance,
+    with NaN for zones that have no valid values in any block.
+
+    This avoids the naive ``(Σx² − (Σx)²/n) / n`` formula whose
+    subtraction can lose most significant digits when the mean is
+    large relative to the standard deviation.
+    """
+    n_blocks = block_counts.shape[0]
+    n_zones = block_counts.shape[1]
+
+    n_acc = np.zeros(n_zones, dtype=np.float64)
+    mean_acc = np.zeros(n_zones, dtype=np.float64)
+    m2_acc = np.zeros(n_zones, dtype=np.float64)
+
+    for i in range(n_blocks):
+        nc = np.asarray(block_counts[i], dtype=np.float64)
+        sc = np.asarray(block_sums[i], dtype=np.float64)
+        sqc = np.asarray(block_sum_squares[i], dtype=np.float64)
+
+        has_data = np.isfinite(nc) & (nc > 0)
+        nc_safe = np.where(has_data, nc, 1.0)  # avoid /0
+
+        with np.errstate(invalid='ignore', divide='ignore'):
+            mean_b = sc / nc_safe
+            m2_b = sqc - sc ** 2 / nc_safe  # block-internal M2
+
+        nc = np.where(has_data, nc, 0.0)
+        n_ab = n_acc + nc
+
+        delta = mean_b - mean_acc
+        with np.errstate(invalid='ignore', divide='ignore'):
+            n_ab_safe = np.where(n_ab > 0, n_ab, 1.0)
+            correction = delta ** 2 * n_acc * nc / n_ab_safe
+            new_mean = mean_acc + delta * nc / n_ab_safe
+
+        m2_acc = np.where(has_data, m2_acc + m2_b + correction, m2_acc)
+        mean_acc = np.where(has_data, new_mean, mean_acc)
+        n_acc = np.where(has_data, n_ab, n_acc)
+
+    with np.errstate(invalid='ignore', divide='ignore'):
+        var = np.where(n_acc > 0, m2_acc / n_acc, np.nan)
+    return var
 
 
 @ngjit
@@ -269,7 +328,10 @@ def _calc_stats(
         if unique_zones[i] in zone_ids:
             zone_values = values_by_zones[start:end]
             # filter out non-finite and nodata_values
-            zone_values = zone_values[np.isfinite(zone_values) & (zone_values != nodata_values)]
+            mask = np.isfinite(zone_values)
+            if nodata_values is not None:
+                mask = mask & (zone_values != nodata_values)
+            zone_values = zone_values[mask]
             if len(zone_values) > 0:
                 results[i] = func(zone_values)
         start = end
@@ -342,9 +404,11 @@ def _stats_dask_numpy(
         sum=values.dtype,
         count=np.int64,
         sum_squares=values.dtype,
-        squared_sum=values.dtype,
     )
 
+    # Keep per-block stacked arrays for the parallel variance merge
+    stacked_blocks = {}
+
     for s in basis_stats:
         if s == 'sum_squares' and not compute_sum_squares:
             continue
@@ -358,6 +422,10 @@ def _stats_dask_numpy(
             for z, v in zip(zones_blocks, values_blocks)
         ]
         zonal_stats = da.stack(stats_by_block, allow_unknown_chunksizes=True)
+
+        if compute_sum_squares and s in ('count', 'sum', 'sum_squares'):
+            stacked_blocks[s] = zonal_stats
+
         stats_func_by_block = delayed(_DASK_STATS[s])
         stats_dict[s] = da.from_delayed(
             stats_func_by_block(zonal_stats), shape=(np.nan,), dtype=np.float64
@@ -365,14 +433,23 @@ def _stats_dask_numpy(
     if 'mean' in stats_funcs:
         stats_dict['mean'] = _dask_mean(stats_dict['sum'], stats_dict['count'])
-    if 'std' in stats_funcs:
-        stats_dict['std'] = _dask_std(
-            stats_dict['sum_squares'], stats_dict['sum'] ** 2, stats_dict['count']
-        )
-    if 'var' in stats_funcs:
-        stats_dict['var'] = _dask_var(
-            stats_dict['sum_squares'], stats_dict['sum'] ** 2, stats_dict['count']
+
+    if 'std' in stats_funcs or 'var' in stats_funcs:
+        var_result = da.from_delayed(
+            delayed(_parallel_variance)(
+                stacked_blocks['count'],
+                stacked_blocks['sum'],
+                stacked_blocks['sum_squares'],
+            ),
+            shape=(np.nan,), dtype=np.float64,
         )
+        if 'var' in stats_funcs:
+            stats_dict['var'] = var_result
+        if 'std' in stats_funcs:
+            stats_dict['std'] = da.from_delayed(
+                delayed(np.sqrt)(var_result),
+                shape=(np.nan,), dtype=np.float64,
+            )
 
     # generate dask dataframe
     stats_df = dd.concat([dd.from_dask_array(s) for s in stats_dict.values()],
                          axis=1, ignore_unknown_divisions=True)
@@ -846,9 +923,10 @@ def _single_zone_crosstab_2d(
 ):
     # 1D flatten zone_values, i.e, original data is 2D
     # filter out non-finite and nodata_values
-    zone_values = zone_values[
-        np.isfinite(zone_values) & (zone_values != nodata_values)
-    ]
+    mask = np.isfinite(zone_values)
+    if nodata_values is not None:
+        mask = mask & (zone_values != nodata_values)
+    zone_values = zone_values[mask]
 
     total_count = zone_values.shape[0]
     crosstab_dict[TOTAL_COUNT].append(total_count)
@@ -877,10 +955,10 @@ def _single_zone_crosstab_3d(
         if cat in cat_ids:
             zone_cat_data = zone_values[j]
             # filter out non-finite and nodata_values
-            zone_cat_data = zone_cat_data[
-                np.isfinite(zone_cat_data)
-                & (zone_cat_data != nodata_values)
-            ]
+            cat_mask = np.isfinite(zone_cat_data)
+            if nodata_values is not None:
+                cat_mask = cat_mask & (zone_cat_data != nodata_values)
+            zone_cat_data = zone_cat_data[cat_mask]
             crosstab_dict[cat].append(stats_func(zone_cat_data))
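[Illustration, not part of the patch series: a minimal standalone numpy
sketch of what patch 1 changes. The naive one-pass variance subtracts two
nearly equal ~6e16 quantities, so most digits of the true answer cancel
away, while the Chan-Golub-LeVeque merge combines per-block
(count, mean, M2) triples and stays well-conditioned. All names below are
local to this sketch.]

    import numpy as np

    x = 1e8 + np.arange(6, dtype=np.float64)   # large mean, spread of 1

    # Naive one-pass formula: (sum(x^2) - sum(x)^2 / n) / n.
    # Both terms are ~6e16 with a float64 ulp of 8, but their true
    # difference is only 17.5, so the subtraction loses most digits.
    n = x.size
    naive_var = (np.sum(x ** 2) - np.sum(x) ** 2 / n) / n

    # Chan-Golub-LeVeque merge of two (count, mean, M2) summaries.
    def merge(n_a, mean_a, m2_a, n_b, mean_b, m2_b):
        n_ab = n_a + n_b
        delta = mean_b - mean_a
        mean_ab = mean_a + delta * n_b / n_ab
        m2_ab = m2_a + m2_b + delta ** 2 * n_a * n_b / n_ab
        return n_ab, mean_ab, m2_ab

    n_acc, mean_acc, m2_acc = 0.0, 0.0, 0.0
    for b in (x[:3], x[3:]):                   # two "dask blocks"
        n_acc, mean_acc, m2_acc = merge(
            n_acc, mean_acc, m2_acc,
            b.size, b.mean(), np.sum((b - b.mean()) ** 2),
        )

    print(naive_var)        # inaccurate; can even come out negative
    print(m2_acc / n_acc)   # 2.91666..., matches np.var(x) = 35/12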
From a67aa8e2396865cd114caf769f0b3dc10dd2e769 Mon Sep 17 00:00:00 2001
From: Brendan Collins
Date: Mon, 30 Mar 2026 11:38:22 -0700
Subject: [PATCH 2/2] Add tests and fix block-level M2 for variance stability (#1090)
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

- Block-level sum_squares now computes M2 (the sum of squared deviations
  from the block mean) instead of raw sum(x²), avoiding float64 precision
  loss for large values.
- Updated test_stats_all_nan_zone and test_stats_nodata_wipes_zone to
  expect NaN from dask (no longer 0).
- Added test_stats_variance_numerical_stability_1090: values near 1e8
  with a spread of 1, verifying dask matches numpy to 1e-6.
- Added test_stats_nodata_none_no_warning_1090: confirms no FutureWarning
  is raised when nodata_values=None.
---
 xrspatial/tests/test_zonal.py | 87 +++++++++++++++++++++++++++--------
 xrspatial/zonal.py            | 15 +++---
 2 files changed, 73 insertions(+), 29 deletions(-)

diff --git a/xrspatial/tests/test_zonal.py b/xrspatial/tests/test_zonal.py
index 8ff150e5..011de940 100644
--- a/xrspatial/tests/test_zonal.py
+++ b/xrspatial/tests/test_zonal.py
@@ -680,18 +680,8 @@ def test_stats_all_nan_zone(backend):
             'sum': [12.0],
             'count': [2],
         }
-    elif 'dask' in backend:
-        # dask uses nansum reduction, so count/sum of all-NaN become 0
-        expected = {
-            'zone': [1, 2],
-            'mean': [np.nan, 6.0],
-            'max': [np.nan, 7.0],
-            'min': [np.nan, 5.0],
-            'sum': [0.0, 12.0],
-            'count': [0, 2],
-        }
     else:
-        # numpy keeps empty zone with NaN for every stat
+        # numpy and dask both return NaN for all-NaN zones
         expected = {
             'zone': [1, 2],
             'mean': [np.nan, 6.0],
             'max': [np.nan, 7.0],
             'min': [np.nan, 5.0],
@@ -798,16 +788,8 @@ def test_stats_nodata_wipes_zone(backend):
             'sum': [10.0],
             'count': [2],
         }
-    elif 'dask' in backend:
-        expected = {
-            'zone': [1, 2],
-            'mean': [np.nan, 5.0],
-            'max': [np.nan, 7.0],
-            'min': [np.nan, 3.0],
-            'sum': [0.0, 10.0],
-            'count': [0, 2],
-        }
     else:
+        # numpy and dask both return NaN for zones with no valid values
         expected = {
             'zone': [1, 2],
             'mean': [np.nan, 5.0],
@@ -868,6 +850,71 @@ def test_zonal_stats_inputs_unmodified(backend, data_zones, data_values_2d, resu
     assert_input_data_unmodified(data_values_2d, copied_data_values_2d)
 
 
+@pytest.mark.filterwarnings("ignore:All-NaN slice encountered:RuntimeWarning")
+@pytest.mark.filterwarnings("ignore:invalid value encountered in divide:RuntimeWarning")
+@pytest.mark.parametrize("backend", ['numpy', 'dask+numpy'])
+def test_stats_variance_numerical_stability_1090(backend):
+    """Dask std/var should match numpy for data with large mean, small spread.
+
+    Regression test for #1090: the naive one-pass formula
+    ``(Σx² − (Σx)²/n) / n`` loses precision through catastrophic
+    cancellation. The fix uses the Chan-Golub-LeVeque parallel merge.
+    """
+    if 'dask' in backend and not dask_array_available():
+        pytest.skip("Requires Dask")
+
+    # Values near 1e8 with a spread of 1: the naive formula would lose
+    # most of the significant digits in float64.
+    zones_data = np.array([[1, 1, 1, 1, 1, 1]])
+    values_data = np.array([[1e8, 1e8 + 1, 1e8 + 2,
+                             1e8 + 3, 1e8 + 4, 1e8 + 5]], dtype=np.float64)
+
+    zones = create_test_raster(zones_data, backend, chunks=(1, 3))
+    values = create_test_raster(values_data, backend, chunks=(1, 3))
+
+    df_result = stats(zones=zones, values=values,
+                      stats_funcs=['mean', 'std', 'var'])
+
+    if hasattr(df_result, 'compute'):
+        df_result = df_result.compute()
+
+    # Reference: population variance of [0,1,2,3,4,5] = 35/12 ≈ 2.9167
+    expected_var = np.var(np.arange(6, dtype=np.float64))
+    expected_std = np.std(np.arange(6, dtype=np.float64))
+
+    actual_var = float(df_result['var'].iloc[0])
+    actual_std = float(df_result['std'].iloc[0])
+
+    assert abs(actual_var - expected_var) < 1e-6, (
+        f"var={actual_var}, expected={expected_var}"
+    )
+    assert abs(actual_std - expected_std) < 1e-6, (
+        f"std={actual_std}, expected={expected_std}"
+    )
+
+
+def test_stats_nodata_none_no_warning_1090():
+    """Passing nodata_values=None (the default) should not trigger warnings.
+
+    Regression test for #1090: ``zone_values != None`` triggered a numpy
+    FutureWarning.
+    """
+    import warnings
+
+    zones_data = np.array([[1, 1], [2, 2]], dtype=float)
+    values_data = np.array([[1.0, 2.0], [3.0, 4.0]])
+    zones = xr.DataArray(zones_data)
+    values = xr.DataArray(values_data)
+
+    with warnings.catch_warnings():
+        warnings.simplefilter("error")
+        df = stats(zones=zones, values=values, nodata_values=None)
+
+    assert len(df) == 2
+    assert float(df['mean'].iloc[0]) == 1.5
+    assert float(df['mean'].iloc[1]) == 3.5
+
+
 @pytest.mark.filterwarnings("ignore:All-NaN slice encountered:RuntimeWarning")
 @pytest.mark.filterwarnings("ignore:invalid value encountered in divide:RuntimeWarning")
 @pytest.mark.parametrize("backend", ['numpy', 'dask+numpy'])
diff --git a/xrspatial/zonal.py b/xrspatial/zonal.py
index 8555c614..db5f8709 100644
--- a/xrspatial/zonal.py
+++ b/xrspatial/zonal.py
@@ -198,7 +198,7 @@ def _stats_majority(data):
     min=lambda z: z.min(),
     sum=lambda z: z.sum(),
     count=lambda z: _stats_count(z),
-    sum_squares=lambda z: (z**2).sum()
+    sum_squares=lambda z: ((z - z.mean()) ** 2).sum()  # block-level M2
 )
 
 
@@ -227,15 +227,13 @@ def _dask_mean(sums, counts):  # noqa
     return sums / counts
 
 
-def _parallel_variance(block_counts, block_sums, block_sum_squares):
+def _parallel_variance(block_counts, block_sums, block_m2s):
     """Population variance via Chan-Golub-LeVeque parallel merge.
 
-    Each input is (n_blocks, n_zones). Returns (n_zones,) variance,
+    Each input is (n_blocks, n_zones). ``block_m2s`` contains
+    per-block M2 values (sum of squared deviations from the block mean),
+    NOT raw sum-of-squares. Returns (n_zones,) population variance,
     with NaN for zones that have no valid values in any block.
-
-    This avoids the naive ``(Σx² − (Σx)²/n) / n`` formula whose
-    subtraction can lose most significant digits when the mean is
-    large relative to the standard deviation.
""" n_blocks = block_counts.shape[0] n_zones = block_counts.shape[1] @@ -247,14 +245,13 @@ def _parallel_variance(block_counts, block_sums, block_sum_squares): for i in range(n_blocks): nc = np.asarray(block_counts[i], dtype=np.float64) sc = np.asarray(block_sums[i], dtype=np.float64) - sqc = np.asarray(block_sum_squares[i], dtype=np.float64) + m2_b = np.asarray(block_m2s[i], dtype=np.float64) has_data = np.isfinite(nc) & (nc > 0) nc_safe = np.where(has_data, nc, 1.0) # avoid /0 with np.errstate(invalid='ignore', divide='ignore'): mean_b = sc / nc_safe - m2_b = sqc - sc ** 2 / nc_safe # block-internal M2 nc = np.where(has_data, nc, 0.0) n_ab = n_acc + nc